mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-04 20:05:16 +02:00
feat(gateway): add Discord gateway install flow
This commit is contained in:
parent
05eaa46c3a
commit
7860714f74
2 changed files with 218 additions and 0 deletions
|
|
@ -33,11 +33,13 @@ from app.db import (
|
||||||
get_async_session,
|
get_async_session,
|
||||||
)
|
)
|
||||||
from app.gateway.accounts import (
|
from app.gateway.accounts import (
|
||||||
|
get_discord_account_by_guild,
|
||||||
get_or_create_system_telegram_account,
|
get_or_create_system_telegram_account,
|
||||||
get_or_create_system_whatsapp_account,
|
get_or_create_system_whatsapp_account,
|
||||||
get_slack_account_by_team,
|
get_slack_account_by_team,
|
||||||
)
|
)
|
||||||
from app.gateway.bindings import resume_binding, revoke_binding
|
from app.gateway.bindings import resume_binding, revoke_binding
|
||||||
|
from app.gateway.discord.adapter import discord_user_peer_id
|
||||||
from app.gateway.inbox import (
|
from app.gateway.inbox import (
|
||||||
persist_inbound_event,
|
persist_inbound_event,
|
||||||
slack_event_dedupe_key,
|
slack_event_dedupe_key,
|
||||||
|
|
@ -57,6 +59,9 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SLACK_AUTHORIZATION_URL = "https://slack.com/oauth/v2/authorize"
|
SLACK_AUTHORIZATION_URL = "https://slack.com/oauth/v2/authorize"
|
||||||
SLACK_TOKEN_URL = "https://slack.com/api/oauth.v2.access"
|
SLACK_TOKEN_URL = "https://slack.com/api/oauth.v2.access"
|
||||||
|
DISCORD_AUTHORIZATION_URL = "https://discord.com/api/oauth2/authorize"
|
||||||
|
DISCORD_TOKEN_URL = "https://discord.com/api/oauth2/token"
|
||||||
|
DISCORD_API = "https://discord.com/api/v10"
|
||||||
SLACK_BOT_SCOPES = [
|
SLACK_BOT_SCOPES = [
|
||||||
"app_mentions:read",
|
"app_mentions:read",
|
||||||
"chat:write",
|
"chat:write",
|
||||||
|
|
@ -66,6 +71,17 @@ SLACK_BOT_SCOPES = [
|
||||||
"users:read",
|
"users:read",
|
||||||
"team:read",
|
"team:read",
|
||||||
]
|
]
|
||||||
|
DISCORD_GATEWAY_SCOPES = ["identify", "guilds", "bot"]
|
||||||
|
DISCORD_VIEW_CHANNEL = 1 << 10
|
||||||
|
DISCORD_SEND_MESSAGES = 1 << 11
|
||||||
|
DISCORD_READ_MESSAGE_HISTORY = 1 << 16
|
||||||
|
DISCORD_SEND_MESSAGES_IN_THREADS = 1 << 38
|
||||||
|
DISCORD_GATEWAY_PERMISSIONS = (
|
||||||
|
DISCORD_VIEW_CHANNEL
|
||||||
|
| DISCORD_SEND_MESSAGES
|
||||||
|
| DISCORD_READ_MESSAGE_HISTORY
|
||||||
|
| DISCORD_SEND_MESSAGES_IN_THREADS
|
||||||
|
)
|
||||||
_state_manager: OAuthStateManager | None = None
|
_state_manager: OAuthStateManager | None = None
|
||||||
_token_encryption: TokenEncryption | None = None
|
_token_encryption: TokenEncryption | None = None
|
||||||
|
|
||||||
|
|
@ -95,6 +111,13 @@ def _slack_redirect_uri() -> str:
|
||||||
return f"{base.rstrip('/')}/api/v1/gateway/slack/callback"
|
return f"{base.rstrip('/')}/api/v1/gateway/slack/callback"
|
||||||
|
|
||||||
|
|
||||||
|
def _discord_redirect_uri() -> str:
|
||||||
|
if config.GATEWAY_DISCORD_REDIRECT_URI:
|
||||||
|
return config.GATEWAY_DISCORD_REDIRECT_URI
|
||||||
|
base = config.BACKEND_URL or ""
|
||||||
|
return f"{base.rstrip('/')}/api/v1/gateway/discord/callback"
|
||||||
|
|
||||||
|
|
||||||
def _slack_frontend_redirect(space_id: int, *, success: bool = False, error: str | None = None) -> RedirectResponse:
|
def _slack_frontend_redirect(space_id: int, *, success: bool = False, error: str | None = None) -> RedirectResponse:
|
||||||
qs = "slack_gateway=connected" if success else f"error={error or 'slack_gateway_failed'}"
|
qs = "slack_gateway=connected" if success else f"error={error or 'slack_gateway_failed'}"
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
|
|
@ -102,6 +125,13 @@ def _slack_frontend_redirect(space_id: int, *, success: bool = False, error: str
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _discord_frontend_redirect(space_id: int, *, success: bool = False, error: str | None = None) -> RedirectResponse:
|
||||||
|
qs = "discord_gateway=connected" if success else f"error={error or 'discord_gateway_failed'}"
|
||||||
|
return RedirectResponse(
|
||||||
|
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/user-settings?{qs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def verify_slack_signature(*, signing_secret: str, timestamp: str | None, signature: str | None, body: bytes) -> bool:
|
def verify_slack_signature(*, signing_secret: str, timestamp: str | None, signature: str | None, body: bytes) -> bool:
|
||||||
if not signing_secret or not timestamp or not signature:
|
if not signing_secret or not timestamp or not signature:
|
||||||
return False
|
return False
|
||||||
|
|
@ -295,6 +325,166 @@ async def slack_gateway_callback(
|
||||||
return _slack_frontend_redirect(space_id, success=True)
|
return _slack_frontend_redirect(space_id, success=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/discord/install")
|
||||||
|
async def install_discord_gateway(
|
||||||
|
search_space_id: int,
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
if not config.DISCORD_CLIENT_ID:
|
||||||
|
raise HTTPException(status_code=500, detail="Discord gateway OAuth is not configured")
|
||||||
|
state = _get_state_manager().generate_secure_state(search_space_id, user.id)
|
||||||
|
auth_params = {
|
||||||
|
"client_id": config.DISCORD_CLIENT_ID,
|
||||||
|
"scope": " ".join(DISCORD_GATEWAY_SCOPES),
|
||||||
|
"redirect_uri": _discord_redirect_uri(),
|
||||||
|
"response_type": "code",
|
||||||
|
"state": state,
|
||||||
|
"permissions": str(DISCORD_GATEWAY_PERMISSIONS),
|
||||||
|
}
|
||||||
|
return {"auth_url": f"{DISCORD_AUTHORIZATION_URL}?{urlencode(auth_params)}"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/discord/callback")
|
||||||
|
async def discord_gateway_callback(
|
||||||
|
code: str | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
state: str | None = None,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
) -> RedirectResponse:
|
||||||
|
space_id = None
|
||||||
|
if state:
|
||||||
|
try:
|
||||||
|
state_data = _get_state_manager().validate_state(state)
|
||||||
|
space_id = int(state_data["space_id"])
|
||||||
|
except Exception:
|
||||||
|
state_data = None
|
||||||
|
else:
|
||||||
|
state_data = None
|
||||||
|
|
||||||
|
if error:
|
||||||
|
return _discord_frontend_redirect(space_id or 0, error="discord_gateway_oauth_denied")
|
||||||
|
if not code or state_data is None:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid Discord gateway OAuth callback")
|
||||||
|
if not config.DISCORD_CLIENT_ID or not config.DISCORD_CLIENT_SECRET:
|
||||||
|
raise HTTPException(status_code=500, detail="Discord gateway OAuth is not configured")
|
||||||
|
if not config.DISCORD_BOT_TOKEN:
|
||||||
|
raise HTTPException(status_code=500, detail="Discord gateway bot token is not configured")
|
||||||
|
|
||||||
|
user_id = UUID(state_data["user_id"])
|
||||||
|
token_payload = {
|
||||||
|
"client_id": config.DISCORD_CLIENT_ID,
|
||||||
|
"client_secret": config.DISCORD_CLIENT_SECRET,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": _discord_redirect_uri(),
|
||||||
|
}
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
token_response = await client.post(
|
||||||
|
DISCORD_TOKEN_URL,
|
||||||
|
data=token_payload,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
token_response.raise_for_status()
|
||||||
|
token_json = token_response.json()
|
||||||
|
|
||||||
|
oauth_access_token = token_json.get("access_token")
|
||||||
|
guild = token_json.get("guild") or {}
|
||||||
|
guild_id = guild.get("id")
|
||||||
|
guild_name = guild.get("name")
|
||||||
|
discord_user_id = None
|
||||||
|
discord_username = None
|
||||||
|
if oauth_access_token:
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
user_response = await client.get(
|
||||||
|
f"{DISCORD_API}/users/@me",
|
||||||
|
headers={"Authorization": f"Bearer {oauth_access_token}"},
|
||||||
|
)
|
||||||
|
user_response.raise_for_status()
|
||||||
|
user_json = user_response.json()
|
||||||
|
discord_user_id = user_json.get("id")
|
||||||
|
discord_username = user_json.get("username")
|
||||||
|
|
||||||
|
if not guild_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=(
|
||||||
|
"Discord gateway OAuth did not return a guild. "
|
||||||
|
"Choose a server during bot installation and try again."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
enc = _get_token_encryption()
|
||||||
|
credentials = {
|
||||||
|
"bot_token": config.DISCORD_BOT_TOKEN,
|
||||||
|
"token_type": "bot",
|
||||||
|
"scope": token_json.get("scope"),
|
||||||
|
}
|
||||||
|
cursor_state = {
|
||||||
|
"guild_id": guild_id,
|
||||||
|
"guild_name": guild_name,
|
||||||
|
"application_id": config.DISCORD_CLIENT_ID,
|
||||||
|
"scope": token_json.get("scope"),
|
||||||
|
"permissions": str(DISCORD_GATEWAY_PERMISSIONS),
|
||||||
|
}
|
||||||
|
|
||||||
|
account = await get_discord_account_by_guild(session, guild_id=str(guild_id))
|
||||||
|
if account is None:
|
||||||
|
account = ExternalChatAccount(
|
||||||
|
platform=ExternalChatPlatform.DISCORD,
|
||||||
|
mode=ExternalChatAccountMode.CLOUD_SHARED,
|
||||||
|
is_system_account=True,
|
||||||
|
encrypted_credentials=enc.encrypt_token(json.dumps(credentials)),
|
||||||
|
bot_username="SurfSense",
|
||||||
|
cursor_state=cursor_state,
|
||||||
|
health_status=ExternalChatHealthStatus.UNKNOWN,
|
||||||
|
)
|
||||||
|
session.add(account)
|
||||||
|
await session.flush()
|
||||||
|
else:
|
||||||
|
account.encrypted_credentials = enc.encrypt_token(json.dumps(credentials))
|
||||||
|
account.cursor_state = {**(account.cursor_state or {}), **cursor_state}
|
||||||
|
account.health_status = ExternalChatHealthStatus.UNKNOWN
|
||||||
|
|
||||||
|
if discord_user_id:
|
||||||
|
peer_id = discord_user_peer_id(str(guild_id), str(discord_user_id))
|
||||||
|
existing_binding_result = await session.execute(
|
||||||
|
select(ExternalChatBinding).where(
|
||||||
|
ExternalChatBinding.account_id == account.id,
|
||||||
|
ExternalChatBinding.external_peer_id == peer_id,
|
||||||
|
ExternalChatBinding.state.in_(
|
||||||
|
[ExternalChatBindingState.BOUND, ExternalChatBindingState.SUSPENDED]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
binding = existing_binding_result.scalars().first()
|
||||||
|
metadata = {
|
||||||
|
"kind": "discord_user",
|
||||||
|
"guild_id": guild_id,
|
||||||
|
"guild_name": guild_name,
|
||||||
|
"discord_user_id": discord_user_id,
|
||||||
|
}
|
||||||
|
if binding is None:
|
||||||
|
session.add(
|
||||||
|
ExternalChatBinding(
|
||||||
|
account_id=account.id,
|
||||||
|
user_id=user_id,
|
||||||
|
search_space_id=space_id,
|
||||||
|
state=ExternalChatBindingState.BOUND,
|
||||||
|
external_peer_id=peer_id,
|
||||||
|
external_peer_kind=ExternalChatPeerKind.DIRECT,
|
||||||
|
external_username=discord_username or discord_user_id,
|
||||||
|
external_metadata=metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif binding.user_id == user_id:
|
||||||
|
binding.search_space_id = space_id
|
||||||
|
binding.external_username = discord_username or binding.external_username
|
||||||
|
binding.external_metadata = {**(binding.external_metadata or {}), **metadata}
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
return _discord_frontend_redirect(space_id, success=True)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/webhooks/slack")
|
@router.post("/webhooks/slack")
|
||||||
async def slack_webhook(
|
async def slack_webhook(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import hmac
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -272,3 +273,30 @@ async def test_slack_webhook_ignores_self_event(monkeypatch, mocker):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
persist.assert_not_awaited()
|
persist.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discord_gateway_install_returns_oauth_url(monkeypatch):
|
||||||
|
monkeypatch.setattr(routes.config, "DISCORD_CLIENT_ID", "discord-client")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
routes.config,
|
||||||
|
"GATEWAY_DISCORD_REDIRECT_URI",
|
||||||
|
"http://localhost:8000/api/v1/gateway/discord/callback",
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(routes.config, "SECRET_KEY", "test-secret")
|
||||||
|
|
||||||
|
response = await routes.install_discord_gateway(
|
||||||
|
search_space_id=123,
|
||||||
|
user=SimpleNamespace(id="00000000-0000-0000-0000-000000000001"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response["auth_url"].startswith("https://discord.com/api/oauth2/authorize?")
|
||||||
|
assert "client_id=discord-client" in response["auth_url"]
|
||||||
|
assert "gateway%2Fdiscord%2Fcallback" in response["auth_url"]
|
||||||
|
assert "scope=identify+guilds+bot" in response["auth_url"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_discord_gateway_callback_does_not_create_search_source_connector():
|
||||||
|
callback_source = inspect.getsource(routes.discord_gateway_callback)
|
||||||
|
|
||||||
|
assert "SearchSourceConnector" not in callback_source
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue