mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-04 20:05:16 +02:00
feat(gateway): add gateway domain primitives
This commit is contained in:
parent
ae3ce91465
commit
c9b7d7b572
13 changed files with 481 additions and 0 deletions
2
surfsense_backend/app/gateway/__init__.py
Normal file
2
surfsense_backend/app/gateway/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
"""Messaging gateway infrastructure for external chat channels."""
|
||||||
|
|
||||||
54
surfsense_backend/app/gateway/accounts.py
Normal file
54
surfsense_backend/app/gateway/accounts.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""Gateway account helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import (
|
||||||
|
GatewayAccountMode,
|
||||||
|
GatewayHealthStatus,
|
||||||
|
GatewayPlatform,
|
||||||
|
GatewayPlatformAccount,
|
||||||
|
)
|
||||||
|
from app.utils.oauth_security import TokenEncryption
|
||||||
|
|
||||||
|
|
||||||
|
def account_token(account: GatewayPlatformAccount) -> str | None:
|
||||||
|
if account.is_system_account and account.platform == GatewayPlatform.TELEGRAM:
|
||||||
|
return config.TELEGRAM_SHARED_BOT_TOKEN
|
||||||
|
if not account.encrypted_credentials:
|
||||||
|
return None
|
||||||
|
return TokenEncryption(config.SECRET_KEY or "").decrypt_token(
|
||||||
|
account.encrypted_credentials
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_system_telegram_account(
|
||||||
|
session: AsyncSession,
|
||||||
|
) -> GatewayPlatformAccount:
|
||||||
|
result = await session.execute(
|
||||||
|
select(GatewayPlatformAccount).where(
|
||||||
|
GatewayPlatformAccount.platform == GatewayPlatform.TELEGRAM,
|
||||||
|
GatewayPlatformAccount.is_system_account.is_(True),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
account = result.scalars().first()
|
||||||
|
if account is not None:
|
||||||
|
return account
|
||||||
|
account = GatewayPlatformAccount(
|
||||||
|
platform=GatewayPlatform.TELEGRAM,
|
||||||
|
mode=GatewayAccountMode.CLOUD_SHARED,
|
||||||
|
is_system_account=True,
|
||||||
|
account_metadata={
|
||||||
|
"bot_username": config.TELEGRAM_SHARED_BOT_USERNAME,
|
||||||
|
"webhook_secret": config.TELEGRAM_WEBHOOK_SECRET,
|
||||||
|
},
|
||||||
|
cursor_state={},
|
||||||
|
health_status=GatewayHealthStatus.UNKNOWN,
|
||||||
|
)
|
||||||
|
session.add(account)
|
||||||
|
await session.flush()
|
||||||
|
return account
|
||||||
|
|
||||||
55
surfsense_backend/app/gateway/auth_invariant.py
Normal file
55
surfsense_backend/app/gateway/auth_invariant.py
Normal file
|
|
@ -0,0 +1,55 @@
|
||||||
|
"""Authorization invariants for gateway-routed turns."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import GatewayConversationBinding, Permission, User
|
||||||
|
from app.gateway.bindings import suspend_binding
|
||||||
|
from app.observability.metrics import record_gateway_auth_invariant_failure
|
||||||
|
from app.utils.rbac import check_permission, check_search_space_access
|
||||||
|
|
||||||
|
|
||||||
|
class GatewaySuspendedError(RuntimeError):
|
||||||
|
def __init__(self, reason: str) -> None:
|
||||||
|
self.reason = reason
|
||||||
|
super().__init__(reason)
|
||||||
|
|
||||||
|
|
||||||
|
async def _fail(
|
||||||
|
session: AsyncSession,
|
||||||
|
binding: GatewayConversationBinding,
|
||||||
|
reason: str,
|
||||||
|
) -> None:
|
||||||
|
suspend_binding(binding, reason)
|
||||||
|
record_gateway_auth_invariant_failure(cause=reason)
|
||||||
|
await session.flush()
|
||||||
|
raise GatewaySuspendedError(reason)
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_authorization_invariant(
|
||||||
|
session: AsyncSession,
|
||||||
|
binding: GatewayConversationBinding,
|
||||||
|
) -> User:
|
||||||
|
if binding.state != "bound":
|
||||||
|
await _fail(session, binding, "binding_not_bound")
|
||||||
|
|
||||||
|
user = await session.get(User, binding.user_id)
|
||||||
|
if user is None:
|
||||||
|
await _fail(session, binding, "owner_missing")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await check_search_space_access(session, user, binding.search_space_id)
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
binding.search_space_id,
|
||||||
|
Permission.CHATS_CREATE.value,
|
||||||
|
"Gateway owner no longer has permission to chat in this search space",
|
||||||
|
)
|
||||||
|
except HTTPException as exc:
|
||||||
|
await _fail(session, binding, f"rbac_{exc.status_code}")
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
2
surfsense_backend/app/gateway/base/__init__.py
Normal file
2
surfsense_backend/app/gateway/base/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
"""Base gateway interfaces."""
|
||||||
|
|
||||||
70
surfsense_backend/app/gateway/base/adapter.py
Normal file
70
surfsense_backend/app/gateway/base/adapter.py
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
"""Platform adapter interfaces for messaging gateways."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ParsedInboundEvent:
|
||||||
|
platform: str
|
||||||
|
event_kind: str
|
||||||
|
external_peer_id: str | None
|
||||||
|
external_peer_kind: str
|
||||||
|
external_message_id: str | None
|
||||||
|
external_user_id: str | None
|
||||||
|
text: str | None
|
||||||
|
raw_payload: dict[str, Any]
|
||||||
|
display_name: str | None = None
|
||||||
|
username: str | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PlatformSendResult:
|
||||||
|
external_message_id: str
|
||||||
|
raw_response: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class BasePlatformAdapter(ABC):
|
||||||
|
platform: str
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_inbound(self, raw_payload: dict[str, Any]) -> ParsedInboundEvent:
|
||||||
|
"""Parse a provider webhook/update into the gateway's normalized shape."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def send_message(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
external_peer_id: str,
|
||||||
|
text: str,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
reply_to_message_id: str | None = None,
|
||||||
|
) -> PlatformSendResult:
|
||||||
|
"""Send a new platform message."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def edit_message(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
external_peer_id: str,
|
||||||
|
external_message_id: str,
|
||||||
|
text: str,
|
||||||
|
parse_mode: str | None = None,
|
||||||
|
) -> PlatformSendResult:
|
||||||
|
"""Edit an existing platform message."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def validate_credentials(self) -> dict[str, Any]:
|
||||||
|
"""Validate configured credentials and return account metadata."""
|
||||||
|
|
||||||
|
async def fetch_updates(self, *, offset: int | None) -> AsyncIterator[dict[str, Any]]:
|
||||||
|
"""Yield provider updates for long-polling adapters."""
|
||||||
|
if False:
|
||||||
|
yield {} # pragma: no cover
|
||||||
|
raise NotImplementedError("This adapter does not support long-polling")
|
||||||
|
|
||||||
19
surfsense_backend/app/gateway/base/identity.py
Normal file
19
surfsense_backend/app/gateway/base/identity.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""Gateway identity helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_external_peer_id(value: str | int | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return str(value).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def hash_external_id(value: str | int | None) -> str | None:
|
||||||
|
normalized = normalize_external_peer_id(value)
|
||||||
|
if not normalized:
|
||||||
|
return None
|
||||||
|
return hashlib.sha256(normalized.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
28
surfsense_backend/app/gateway/base/translator.py
Normal file
28
surfsense_backend/app/gateway/base/translator.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Base stream translator for platform-specific outbound UX."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GatewayStreamEvent:
|
||||||
|
"""Small provider-neutral event shape consumed by translators.
|
||||||
|
|
||||||
|
The existing chat stack emits Vercel/assistant-ui events. Gateway code
|
||||||
|
normalizes the subset it needs into this shape before handing it to the
|
||||||
|
platform translator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str
|
||||||
|
data: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStreamTranslator(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def translate(self, events: AsyncIterator[GatewayStreamEvent]) -> None:
|
||||||
|
"""Consume agent stream events and emit platform messages."""
|
||||||
|
|
||||||
62
surfsense_backend/app/gateway/bindings.py
Normal file
62
surfsense_backend/app/gateway/bindings.py
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""Gateway binding helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import (
|
||||||
|
ChatVisibility,
|
||||||
|
GatewayBindingState,
|
||||||
|
GatewayConversationBinding,
|
||||||
|
NewChatThread,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_thread_for_binding(
|
||||||
|
session: AsyncSession,
|
||||||
|
binding: GatewayConversationBinding,
|
||||||
|
) -> NewChatThread:
|
||||||
|
if binding.active_thread_id is not None:
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).where(NewChatThread.id == binding.active_thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
if thread is not None and not thread.archived:
|
||||||
|
return thread
|
||||||
|
|
||||||
|
thread = NewChatThread(
|
||||||
|
title="Telegram chat",
|
||||||
|
search_space_id=binding.search_space_id,
|
||||||
|
created_by_id=binding.user_id,
|
||||||
|
visibility=ChatVisibility.PRIVATE,
|
||||||
|
source="telegram",
|
||||||
|
binding_id=binding.id,
|
||||||
|
)
|
||||||
|
session.add(thread)
|
||||||
|
await session.flush()
|
||||||
|
binding.active_thread_id = thread.id
|
||||||
|
return thread
|
||||||
|
|
||||||
|
|
||||||
|
def suspend_binding(binding: GatewayConversationBinding, reason: str) -> None:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
binding.state = GatewayBindingState.SUSPENDED
|
||||||
|
binding.suspended_at = now
|
||||||
|
binding.suspended_reason = reason
|
||||||
|
|
||||||
|
|
||||||
|
def revoke_binding(binding: GatewayConversationBinding) -> None:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
binding.state = GatewayBindingState.REVOKED
|
||||||
|
binding.revoked_at = now
|
||||||
|
binding.active_thread_id = None
|
||||||
|
|
||||||
|
|
||||||
|
def resume_binding(binding: GatewayConversationBinding) -> None:
|
||||||
|
binding.state = GatewayBindingState.BOUND
|
||||||
|
binding.suspended_at = None
|
||||||
|
binding.suspended_reason = None
|
||||||
|
|
||||||
35
surfsense_backend/app/gateway/hitl_filter.py
Normal file
35
surfsense_backend/app/gateway/hitl_filter.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""Filter approval-required tools from gateway agent invocations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
DEFAULT_HITL_TOOL_NAMES = {
|
||||||
|
"delete_document",
|
||||||
|
"delete_folder",
|
||||||
|
"delete_note",
|
||||||
|
"delete_report",
|
||||||
|
"delete_connector",
|
||||||
|
"send_email",
|
||||||
|
"share_chat",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_name(tool: Any) -> str | None:
|
||||||
|
if isinstance(tool, str):
|
||||||
|
return tool
|
||||||
|
return getattr(tool, "name", None) or getattr(tool, "__name__", None)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_hitl_tools(
|
||||||
|
toolkit: Iterable[Any] | None,
|
||||||
|
*,
|
||||||
|
blocked_names: set[str] | None = None,
|
||||||
|
) -> list[Any] | None:
|
||||||
|
"""Return a toolkit with known approval-required tools removed."""
|
||||||
|
if toolkit is None:
|
||||||
|
return None
|
||||||
|
blocked = blocked_names or DEFAULT_HITL_TOOL_NAMES
|
||||||
|
return [tool for tool in toolkit if (_tool_name(tool) or "") not in blocked]
|
||||||
|
|
||||||
44
surfsense_backend/app/gateway/inbox.py
Normal file
44
surfsense_backend/app/gateway/inbox.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""Durable gateway inbox helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import GatewayInboundEvent, GatewayPlatform
|
||||||
|
|
||||||
|
|
||||||
|
def telegram_event_dedupe_key(update_id: int | str) -> str:
|
||||||
|
return f"update:{update_id}"
|
||||||
|
|
||||||
|
|
||||||
|
async def persist_inbound_event(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
account_id: int,
|
||||||
|
platform: GatewayPlatform,
|
||||||
|
event_dedupe_key: str,
|
||||||
|
event_kind: str,
|
||||||
|
raw_payload: dict,
|
||||||
|
external_event_id: str | None = None,
|
||||||
|
external_message_id: str | None = None,
|
||||||
|
) -> int | None:
|
||||||
|
stmt = (
|
||||||
|
insert(GatewayInboundEvent)
|
||||||
|
.values(
|
||||||
|
account_id=account_id,
|
||||||
|
platform=platform,
|
||||||
|
event_dedupe_key=event_dedupe_key,
|
||||||
|
external_event_id=external_event_id,
|
||||||
|
external_message_id=external_message_id,
|
||||||
|
event_kind=event_kind,
|
||||||
|
raw_payload=raw_payload,
|
||||||
|
)
|
||||||
|
.on_conflict_do_nothing(
|
||||||
|
index_elements=["account_id", "event_dedupe_key"],
|
||||||
|
)
|
||||||
|
.returning(GatewayInboundEvent.id)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
54
surfsense_backend/app/gateway/pairing.py
Normal file
54
surfsense_backend/app/gateway/pairing.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""Pairing code lifecycle for gateway bindings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import GatewayBindingState, GatewayConversationBinding
|
||||||
|
|
||||||
|
PAIRING_CODE_TTL = timedelta(minutes=10)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pairing_code() -> str:
|
||||||
|
return secrets.token_urlsafe(6)
|
||||||
|
|
||||||
|
|
||||||
|
def pairing_expires_at() -> datetime:
|
||||||
|
return datetime.now(UTC) + PAIRING_CODE_TTL
|
||||||
|
|
||||||
|
|
||||||
|
async def redeem_pairing_code(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
code: str,
|
||||||
|
external_peer_id: str,
|
||||||
|
external_peer_kind: str,
|
||||||
|
external_display_name: str | None,
|
||||||
|
external_username: str | None,
|
||||||
|
external_metadata: dict | None = None,
|
||||||
|
) -> GatewayConversationBinding | None:
|
||||||
|
result = await session.execute(
|
||||||
|
select(GatewayConversationBinding).where(
|
||||||
|
GatewayConversationBinding.pairing_code == code,
|
||||||
|
GatewayConversationBinding.state == GatewayBindingState.PENDING,
|
||||||
|
GatewayConversationBinding.pairing_code_expires_at > datetime.now(UTC),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
binding = result.scalars().first()
|
||||||
|
if binding is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
binding.state = GatewayBindingState.BOUND
|
||||||
|
binding.pairing_code = None
|
||||||
|
binding.pairing_code_expires_at = None
|
||||||
|
binding.external_peer_id = external_peer_id
|
||||||
|
binding.external_peer_kind = external_peer_kind
|
||||||
|
binding.external_display_name = external_display_name
|
||||||
|
binding.external_username = external_username
|
||||||
|
binding.external_metadata = external_metadata or {}
|
||||||
|
return binding
|
||||||
|
|
||||||
15
surfsense_backend/tests/unit/gateway/test_hitl_filter.py
Normal file
15
surfsense_backend/tests/unit/gateway/test_hitl_filter.py
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
from app.gateway.hitl_filter import filter_hitl_tools
|
||||||
|
|
||||||
|
|
||||||
|
class Tool:
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_hitl_tools_removes_known_approval_tools():
|
||||||
|
tools = [Tool("delete_document"), Tool("search"), "send_email", "summarize"]
|
||||||
|
|
||||||
|
filtered = filter_hitl_tools(tools)
|
||||||
|
|
||||||
|
assert [getattr(tool, "name", tool) for tool in filtered] == ["search", "summarize"]
|
||||||
|
|
||||||
41
surfsense_backend/tests/unit/gateway/test_pairing.py
Normal file
41
surfsense_backend/tests/unit/gateway/test_pairing.py
Normal file
|
|
@ -0,0 +1,41 @@
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.db import GatewayBindingState
|
||||||
|
from app.gateway.pairing import generate_pairing_code, redeem_pairing_code
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_pairing_code_is_short_display_token():
|
||||||
|
code = generate_pairing_code()
|
||||||
|
|
||||||
|
assert len(code) >= 8
|
||||||
|
assert "\n" not in code
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redeem_pairing_code_binds_pending_row(mocker):
|
||||||
|
binding = mocker.Mock()
|
||||||
|
binding.state = GatewayBindingState.PENDING
|
||||||
|
binding.pairing_code_expires_at = datetime.now(UTC) + timedelta(minutes=1)
|
||||||
|
scalars = mocker.Mock()
|
||||||
|
scalars.first.return_value = binding
|
||||||
|
result = mocker.Mock()
|
||||||
|
result.scalars.return_value = scalars
|
||||||
|
session = mocker.AsyncMock()
|
||||||
|
session.execute.return_value = result
|
||||||
|
|
||||||
|
redeemed = await redeem_pairing_code(
|
||||||
|
session,
|
||||||
|
code="abc",
|
||||||
|
external_peer_id="telegram:123",
|
||||||
|
external_peer_kind="direct",
|
||||||
|
external_display_name="Anish",
|
||||||
|
external_username="anish",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert redeemed is binding
|
||||||
|
assert binding.state == GatewayBindingState.BOUND
|
||||||
|
assert binding.external_peer_id == "telegram:123"
|
||||||
|
assert binding.pairing_code is None
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue