mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
feat(gateway): add Discord gateway install flow
This commit is contained in:
parent
05eaa46c3a
commit
7860714f74
2 changed files with 218 additions and 0 deletions
|
|
@ -33,11 +33,13 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.gateway.accounts import (
|
||||
get_discord_account_by_guild,
|
||||
get_or_create_system_telegram_account,
|
||||
get_or_create_system_whatsapp_account,
|
||||
get_slack_account_by_team,
|
||||
)
|
||||
from app.gateway.bindings import resume_binding, revoke_binding
|
||||
from app.gateway.discord.adapter import discord_user_peer_id
|
||||
from app.gateway.inbox import (
|
||||
persist_inbound_event,
|
||||
slack_event_dedupe_key,
|
||||
|
|
@ -57,6 +59,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
SLACK_AUTHORIZATION_URL = "https://slack.com/oauth/v2/authorize"
|
||||
SLACK_TOKEN_URL = "https://slack.com/api/oauth.v2.access"
|
||||
DISCORD_AUTHORIZATION_URL = "https://discord.com/api/oauth2/authorize"
|
||||
DISCORD_TOKEN_URL = "https://discord.com/api/oauth2/token"
|
||||
DISCORD_API = "https://discord.com/api/v10"
|
||||
SLACK_BOT_SCOPES = [
|
||||
"app_mentions:read",
|
||||
"chat:write",
|
||||
|
|
@ -66,6 +71,17 @@ SLACK_BOT_SCOPES = [
|
|||
"users:read",
|
||||
"team:read",
|
||||
]
|
||||
DISCORD_GATEWAY_SCOPES = ["identify", "guilds", "bot"]
|
||||
DISCORD_VIEW_CHANNEL = 1 << 10
|
||||
DISCORD_SEND_MESSAGES = 1 << 11
|
||||
DISCORD_READ_MESSAGE_HISTORY = 1 << 16
|
||||
DISCORD_SEND_MESSAGES_IN_THREADS = 1 << 38
|
||||
DISCORD_GATEWAY_PERMISSIONS = (
|
||||
DISCORD_VIEW_CHANNEL
|
||||
| DISCORD_SEND_MESSAGES
|
||||
| DISCORD_READ_MESSAGE_HISTORY
|
||||
| DISCORD_SEND_MESSAGES_IN_THREADS
|
||||
)
|
||||
_state_manager: OAuthStateManager | None = None
|
||||
_token_encryption: TokenEncryption | None = None
|
||||
|
||||
|
|
@ -95,6 +111,13 @@ def _slack_redirect_uri() -> str:
|
|||
return f"{base.rstrip('/')}/api/v1/gateway/slack/callback"
|
||||
|
||||
|
||||
def _discord_redirect_uri() -> str:
|
||||
if config.GATEWAY_DISCORD_REDIRECT_URI:
|
||||
return config.GATEWAY_DISCORD_REDIRECT_URI
|
||||
base = config.BACKEND_URL or ""
|
||||
return f"{base.rstrip('/')}/api/v1/gateway/discord/callback"
|
||||
|
||||
|
||||
def _slack_frontend_redirect(space_id: int, *, success: bool = False, error: str | None = None) -> RedirectResponse:
|
||||
qs = "slack_gateway=connected" if success else f"error={error or 'slack_gateway_failed'}"
|
||||
return RedirectResponse(
|
||||
|
|
@ -102,6 +125,13 @@ def _slack_frontend_redirect(space_id: int, *, success: bool = False, error: str
|
|||
)
|
||||
|
||||
|
||||
def _discord_frontend_redirect(space_id: int, *, success: bool = False, error: str | None = None) -> RedirectResponse:
|
||||
qs = "discord_gateway=connected" if success else f"error={error or 'discord_gateway_failed'}"
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/user-settings?{qs}"
|
||||
)
|
||||
|
||||
|
||||
def verify_slack_signature(*, signing_secret: str, timestamp: str | None, signature: str | None, body: bytes) -> bool:
|
||||
if not signing_secret or not timestamp or not signature:
|
||||
return False
|
||||
|
|
@ -295,6 +325,166 @@ async def slack_gateway_callback(
|
|||
return _slack_frontend_redirect(space_id, success=True)
|
||||
|
||||
|
||||
@router.get("/discord/install")
|
||||
async def install_discord_gateway(
|
||||
search_space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
) -> dict[str, str]:
|
||||
if not config.DISCORD_CLIENT_ID:
|
||||
raise HTTPException(status_code=500, detail="Discord gateway OAuth is not configured")
|
||||
state = _get_state_manager().generate_secure_state(search_space_id, user.id)
|
||||
auth_params = {
|
||||
"client_id": config.DISCORD_CLIENT_ID,
|
||||
"scope": " ".join(DISCORD_GATEWAY_SCOPES),
|
||||
"redirect_uri": _discord_redirect_uri(),
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
"permissions": str(DISCORD_GATEWAY_PERMISSIONS),
|
||||
}
|
||||
return {"auth_url": f"{DISCORD_AUTHORIZATION_URL}?{urlencode(auth_params)}"}
|
||||
|
||||
|
||||
@router.get("/discord/callback")
|
||||
async def discord_gateway_callback(
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> RedirectResponse:
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
state_data = _get_state_manager().validate_state(state)
|
||||
space_id = int(state_data["space_id"])
|
||||
except Exception:
|
||||
state_data = None
|
||||
else:
|
||||
state_data = None
|
||||
|
||||
if error:
|
||||
return _discord_frontend_redirect(space_id or 0, error="discord_gateway_oauth_denied")
|
||||
if not code or state_data is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid Discord gateway OAuth callback")
|
||||
if not config.DISCORD_CLIENT_ID or not config.DISCORD_CLIENT_SECRET:
|
||||
raise HTTPException(status_code=500, detail="Discord gateway OAuth is not configured")
|
||||
if not config.DISCORD_BOT_TOKEN:
|
||||
raise HTTPException(status_code=500, detail="Discord gateway bot token is not configured")
|
||||
|
||||
user_id = UUID(state_data["user_id"])
|
||||
token_payload = {
|
||||
"client_id": config.DISCORD_CLIENT_ID,
|
||||
"client_secret": config.DISCORD_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": _discord_redirect_uri(),
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
token_response = await client.post(
|
||||
DISCORD_TOKEN_URL,
|
||||
data=token_payload,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
token_response.raise_for_status()
|
||||
token_json = token_response.json()
|
||||
|
||||
oauth_access_token = token_json.get("access_token")
|
||||
guild = token_json.get("guild") or {}
|
||||
guild_id = guild.get("id")
|
||||
guild_name = guild.get("name")
|
||||
discord_user_id = None
|
||||
discord_username = None
|
||||
if oauth_access_token:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
user_response = await client.get(
|
||||
f"{DISCORD_API}/users/@me",
|
||||
headers={"Authorization": f"Bearer {oauth_access_token}"},
|
||||
)
|
||||
user_response.raise_for_status()
|
||||
user_json = user_response.json()
|
||||
discord_user_id = user_json.get("id")
|
||||
discord_username = user_json.get("username")
|
||||
|
||||
if not guild_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"Discord gateway OAuth did not return a guild. "
|
||||
"Choose a server during bot installation and try again."
|
||||
),
|
||||
)
|
||||
|
||||
enc = _get_token_encryption()
|
||||
credentials = {
|
||||
"bot_token": config.DISCORD_BOT_TOKEN,
|
||||
"token_type": "bot",
|
||||
"scope": token_json.get("scope"),
|
||||
}
|
||||
cursor_state = {
|
||||
"guild_id": guild_id,
|
||||
"guild_name": guild_name,
|
||||
"application_id": config.DISCORD_CLIENT_ID,
|
||||
"scope": token_json.get("scope"),
|
||||
"permissions": str(DISCORD_GATEWAY_PERMISSIONS),
|
||||
}
|
||||
|
||||
account = await get_discord_account_by_guild(session, guild_id=str(guild_id))
|
||||
if account is None:
|
||||
account = ExternalChatAccount(
|
||||
platform=ExternalChatPlatform.DISCORD,
|
||||
mode=ExternalChatAccountMode.CLOUD_SHARED,
|
||||
is_system_account=True,
|
||||
encrypted_credentials=enc.encrypt_token(json.dumps(credentials)),
|
||||
bot_username="SurfSense",
|
||||
cursor_state=cursor_state,
|
||||
health_status=ExternalChatHealthStatus.UNKNOWN,
|
||||
)
|
||||
session.add(account)
|
||||
await session.flush()
|
||||
else:
|
||||
account.encrypted_credentials = enc.encrypt_token(json.dumps(credentials))
|
||||
account.cursor_state = {**(account.cursor_state or {}), **cursor_state}
|
||||
account.health_status = ExternalChatHealthStatus.UNKNOWN
|
||||
|
||||
if discord_user_id:
|
||||
peer_id = discord_user_peer_id(str(guild_id), str(discord_user_id))
|
||||
existing_binding_result = await session.execute(
|
||||
select(ExternalChatBinding).where(
|
||||
ExternalChatBinding.account_id == account.id,
|
||||
ExternalChatBinding.external_peer_id == peer_id,
|
||||
ExternalChatBinding.state.in_(
|
||||
[ExternalChatBindingState.BOUND, ExternalChatBindingState.SUSPENDED]
|
||||
),
|
||||
)
|
||||
)
|
||||
binding = existing_binding_result.scalars().first()
|
||||
metadata = {
|
||||
"kind": "discord_user",
|
||||
"guild_id": guild_id,
|
||||
"guild_name": guild_name,
|
||||
"discord_user_id": discord_user_id,
|
||||
}
|
||||
if binding is None:
|
||||
session.add(
|
||||
ExternalChatBinding(
|
||||
account_id=account.id,
|
||||
user_id=user_id,
|
||||
search_space_id=space_id,
|
||||
state=ExternalChatBindingState.BOUND,
|
||||
external_peer_id=peer_id,
|
||||
external_peer_kind=ExternalChatPeerKind.DIRECT,
|
||||
external_username=discord_username or discord_user_id,
|
||||
external_metadata=metadata,
|
||||
)
|
||||
)
|
||||
elif binding.user_id == user_id:
|
||||
binding.search_space_id = space_id
|
||||
binding.external_username = discord_username or binding.external_username
|
||||
binding.external_metadata = {**(binding.external_metadata or {}), **metadata}
|
||||
|
||||
await session.commit()
|
||||
return _discord_frontend_redirect(space_id, success=True)
|
||||
|
||||
|
||||
@router.post("/webhooks/slack")
|
||||
async def slack_webhook(
|
||||
request: Request,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import hmac
|
|||
import inspect
|
||||
import json
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -272,3 +273,30 @@ async def test_slack_webhook_ignores_self_event(monkeypatch, mocker):
|
|||
assert response.status_code == 200
|
||||
persist.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_gateway_install_returns_oauth_url(monkeypatch):
|
||||
monkeypatch.setattr(routes.config, "DISCORD_CLIENT_ID", "discord-client")
|
||||
monkeypatch.setattr(
|
||||
routes.config,
|
||||
"GATEWAY_DISCORD_REDIRECT_URI",
|
||||
"http://localhost:8000/api/v1/gateway/discord/callback",
|
||||
)
|
||||
monkeypatch.setattr(routes.config, "SECRET_KEY", "test-secret")
|
||||
|
||||
response = await routes.install_discord_gateway(
|
||||
search_space_id=123,
|
||||
user=SimpleNamespace(id="00000000-0000-0000-0000-000000000001"),
|
||||
)
|
||||
|
||||
assert response["auth_url"].startswith("https://discord.com/api/oauth2/authorize?")
|
||||
assert "client_id=discord-client" in response["auth_url"]
|
||||
assert "gateway%2Fdiscord%2Fcallback" in response["auth_url"]
|
||||
assert "scope=identify+guilds+bot" in response["auth_url"]
|
||||
|
||||
|
||||
def test_discord_gateway_callback_does_not_create_search_source_connector():
|
||||
callback_source = inspect.getsource(routes.discord_gateway_callback)
|
||||
|
||||
assert "SearchSourceConnector" not in callback_source
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue