mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
feat(gateway): process inbound events through the agent
This commit is contained in:
parent
967ec099c8
commit
b8538655bb
4 changed files with 518 additions and 0 deletions
80
surfsense_backend/app/gateway/agent_invoke.py
Normal file
80
surfsense_backend/app/gateway/agent_invoke.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""Invoke SurfSense chat agent for gateway channels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import GatewayConversationBinding
|
||||
from app.gateway.auth_invariant import assert_authorization_invariant
|
||||
from app.gateway.base.translator import GatewayStreamEvent
|
||||
from app.gateway.bindings import get_or_create_thread_for_binding
|
||||
from app.gateway.hitl_filter import DEFAULT_HITL_TOOL_NAMES
|
||||
from app.gateway.telegram.translator import TelegramStreamTranslator
|
||||
from app.gateway.thread_lock import acquire_thread_lock, release_thread_lock
|
||||
from app.observability.metrics import record_gateway_turn_latency
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _events_from_sse(chunks: AsyncIterator[str]) -> AsyncIterator[GatewayStreamEvent]:
|
||||
async for chunk in chunks:
|
||||
for raw_line in chunk.splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line.startswith("data:"):
|
||||
continue
|
||||
payload = line.removeprefix("data:").strip()
|
||||
if payload == "[DONE]":
|
||||
yield GatewayStreamEvent(type="done")
|
||||
continue
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
event_type = str(data.get("type") or "")
|
||||
if event_type == "text-delta":
|
||||
yield GatewayStreamEvent(type="text-delta", data={"delta": data.get("delta", "")})
|
||||
elif event_type == "text-end":
|
||||
yield GatewayStreamEvent(type="text-end", data=data)
|
||||
elif event_type == "finish":
|
||||
yield GatewayStreamEvent(type="finish", data=data)
|
||||
elif event_type == "data-interrupt-request":
|
||||
yield GatewayStreamEvent(type="data-interrupt-request", data=data)
|
||||
|
||||
|
||||
async def call_agent_for_gateway(
|
||||
*,
|
||||
session: AsyncSession,
|
||||
binding: GatewayConversationBinding,
|
||||
user_text: str,
|
||||
translator: TelegramStreamTranslator,
|
||||
request_id: str | None = None,
|
||||
) -> None:
|
||||
user = await assert_authorization_invariant(session, binding)
|
||||
thread = await get_or_create_thread_for_binding(session, binding)
|
||||
await session.commit()
|
||||
|
||||
if not acquire_thread_lock(thread.id):
|
||||
raise RuntimeError("gateway_thread_busy")
|
||||
|
||||
try:
|
||||
stream = stream_new_chat(
|
||||
user_query=user_text,
|
||||
search_space_id=binding.search_space_id,
|
||||
chat_id=thread.id,
|
||||
user_id=str(user.id),
|
||||
needs_history_bootstrap=thread.needs_history_bootstrap,
|
||||
thread_visibility=thread.visibility,
|
||||
current_user_display_name=user.display_name or "A team member",
|
||||
disabled_tools=sorted(DEFAULT_HITL_TOOL_NAMES),
|
||||
request_id=request_id or "gateway",
|
||||
)
|
||||
await translator.translate(_events_from_sse(stream))
|
||||
record_gateway_turn_latency(0, platform="telegram")
|
||||
finally:
|
||||
release_thread_lock(thread.id)
|
||||
|
||||
262
surfsense_backend/app/gateway/inbox_processor.py
Normal file
262
surfsense_backend/app/gateway/inbox_processor.py
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
"""Long-lived gateway inbox processing.
|
||||
|
||||
This module owns the agent-turn execution path for messaging gateways. It is
|
||||
intentionally independent of Celery so LangGraph, async Postgres, Redis, and
|
||||
Telegram clients all run on one stable event loop in ``GatewayRunner``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
GatewayBindingState,
|
||||
GatewayConversationBinding,
|
||||
GatewayEventStatus,
|
||||
GatewayInboundEvent,
|
||||
GatewayPeerKind,
|
||||
GatewayPlatformAccount,
|
||||
NewChatThread,
|
||||
async_session_maker,
|
||||
)
|
||||
from app.gateway.accounts import account_token
|
||||
from app.gateway.agent_invoke import call_agent_for_gateway
|
||||
from app.gateway.bindings import get_or_create_thread_for_binding
|
||||
from app.gateway.telegram.adapter import TelegramAdapter
|
||||
from app.gateway.telegram.commands import (
|
||||
command_name,
|
||||
handle_help_command,
|
||||
handle_start_command,
|
||||
send_unbound_onboarding,
|
||||
)
|
||||
from app.gateway.telegram.translator import TelegramStreamTranslator
|
||||
from app.observability.metrics import record_gateway_inbox_processed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SessionMaker = async_sessionmaker[AsyncSession] | Callable[[], AsyncSession]
|
||||
|
||||
|
||||
def _dashboard_url() -> str:
|
||||
return config.NEXT_FRONTEND_URL or "/dashboard"
|
||||
|
||||
|
||||
async def claim_next_inbound_event(
|
||||
session_maker: SessionMaker = async_session_maker,
|
||||
) -> int | None:
|
||||
"""Claim the oldest received inbox event for processing."""
|
||||
|
||||
async with session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(GatewayInboundEvent)
|
||||
.where(GatewayInboundEvent.status == GatewayEventStatus.RECEIVED)
|
||||
.order_by(GatewayInboundEvent.received_at.asc())
|
||||
.with_for_update(skip_locked=True)
|
||||
.limit(1)
|
||||
)
|
||||
event = result.scalars().first()
|
||||
if event is None:
|
||||
return None
|
||||
event.status = GatewayEventStatus.PROCESSING
|
||||
event.attempt_count += 1
|
||||
await session.commit()
|
||||
return int(event.id)
|
||||
|
||||
|
||||
async def process_inbound_event(
|
||||
inbox_id: int,
|
||||
session_maker: SessionMaker = async_session_maker,
|
||||
) -> None:
|
||||
"""Process one gateway inbox row and mark its terminal status."""
|
||||
|
||||
async with session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(GatewayInboundEvent)
|
||||
.where(GatewayInboundEvent.id == inbox_id)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
event = result.scalars().first()
|
||||
if event is None or event.status in {
|
||||
GatewayEventStatus.PROCESSED,
|
||||
GatewayEventStatus.IGNORED,
|
||||
}:
|
||||
return
|
||||
if event.status == GatewayEventStatus.RECEIVED:
|
||||
event.status = GatewayEventStatus.PROCESSING
|
||||
event.attempt_count += 1
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
await _dispatch_inbound_event(inbox_id, session_maker)
|
||||
except RuntimeError as exc:
|
||||
if str(exc) == "gateway_thread_busy":
|
||||
async with session_maker() as session:
|
||||
await session.execute(
|
||||
update(GatewayInboundEvent)
|
||||
.where(GatewayInboundEvent.id == inbox_id)
|
||||
.values(
|
||||
status=GatewayEventStatus.RECEIVED,
|
||||
last_error="gateway_thread_busy",
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
return
|
||||
await _mark_failed(inbox_id, str(exc), session_maker)
|
||||
raise
|
||||
except Exception as exc:
|
||||
await _mark_failed(inbox_id, str(exc), session_maker)
|
||||
raise
|
||||
|
||||
async with session_maker() as session:
|
||||
event = await session.get(GatewayInboundEvent, inbox_id)
|
||||
if event is not None and event.status == GatewayEventStatus.PROCESSING:
|
||||
event.status = GatewayEventStatus.PROCESSED
|
||||
event.processed_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
record_gateway_inbox_processed(platform=event.platform.value, status="processed")
|
||||
|
||||
|
||||
async def _mark_failed(
|
||||
inbox_id: int,
|
||||
error: str,
|
||||
session_maker: SessionMaker,
|
||||
) -> None:
|
||||
async with session_maker() as session:
|
||||
await session.execute(
|
||||
update(GatewayInboundEvent)
|
||||
.where(GatewayInboundEvent.id == inbox_id)
|
||||
.values(status=GatewayEventStatus.FAILED, last_error=error)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _dispatch_inbound_event(
|
||||
inbox_id: int,
|
||||
session_maker: SessionMaker,
|
||||
) -> None:
|
||||
async with session_maker() as session:
|
||||
event = await session.get(GatewayInboundEvent, inbox_id)
|
||||
if event is None:
|
||||
return
|
||||
account = await session.get(GatewayPlatformAccount, event.account_id)
|
||||
if account is None:
|
||||
event.status = GatewayEventStatus.IGNORED
|
||||
event.last_error = "account_missing"
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
token = account_token(account)
|
||||
if not token:
|
||||
event.status = GatewayEventStatus.FAILED
|
||||
event.last_error = "missing_telegram_token"
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
adapter = TelegramAdapter(token)
|
||||
parsed = adapter.parse_inbound(event.raw_payload or {})
|
||||
if parsed.external_peer_id is None:
|
||||
event.status = GatewayEventStatus.IGNORED
|
||||
event.last_error = "missing_external_peer_id"
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
_update_account_cursor(account, parsed.metadata.get("update_id"))
|
||||
|
||||
result = await session.execute(
|
||||
select(GatewayConversationBinding).where(
|
||||
GatewayConversationBinding.account_id == account.id,
|
||||
GatewayConversationBinding.external_peer_id == parsed.external_peer_id,
|
||||
GatewayConversationBinding.state.in_(
|
||||
[GatewayBindingState.BOUND, GatewayBindingState.SUSPENDED]
|
||||
),
|
||||
)
|
||||
)
|
||||
binding = result.scalars().first()
|
||||
|
||||
if parsed.external_peer_kind != GatewayPeerKind.DIRECT.value:
|
||||
await adapter.leave_chat(external_peer_id=parsed.external_peer_id)
|
||||
event.status = GatewayEventStatus.IGNORED
|
||||
event.last_error = "group_rejected"
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
cmd = command_name(parsed.text)
|
||||
if cmd == "/start":
|
||||
handled = await handle_start_command(
|
||||
session=session, adapter=adapter, event=parsed
|
||||
)
|
||||
await session.commit()
|
||||
if handled:
|
||||
return
|
||||
|
||||
if binding is None:
|
||||
await send_unbound_onboarding(
|
||||
adapter=adapter,
|
||||
event=parsed,
|
||||
dashboard_url=_dashboard_url(),
|
||||
)
|
||||
event.status = GatewayEventStatus.IGNORED
|
||||
event.last_error = "unbound_chat"
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
event.binding_id = binding.id
|
||||
|
||||
if cmd == "/help":
|
||||
await handle_help_command(adapter=adapter, event=parsed)
|
||||
event.status = GatewayEventStatus.PROCESSED
|
||||
await session.commit()
|
||||
return
|
||||
if cmd == "/new":
|
||||
binding.active_thread_id = None
|
||||
await adapter.send_message(
|
||||
external_peer_id=parsed.external_peer_id,
|
||||
text="Started a new SurfSense conversation.",
|
||||
)
|
||||
event.status = GatewayEventStatus.PROCESSED
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
if not parsed.text:
|
||||
event.status = GatewayEventStatus.IGNORED
|
||||
event.last_error = "empty_message"
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
thread = await get_or_create_thread_for_binding(session, binding)
|
||||
await session.commit()
|
||||
|
||||
translator = TelegramStreamTranslator(
|
||||
adapter=adapter,
|
||||
external_peer_id=parsed.external_peer_id,
|
||||
)
|
||||
await call_agent_for_gateway(
|
||||
session=session,
|
||||
binding=binding,
|
||||
user_text=parsed.text,
|
||||
translator=translator,
|
||||
request_id=f"gateway:{inbox_id}",
|
||||
)
|
||||
|
||||
thread = await session.get(NewChatThread, thread.id)
|
||||
if thread is not None:
|
||||
thread.source = "telegram"
|
||||
await session.commit()
|
||||
|
||||
|
||||
def _update_account_cursor(account: GatewayPlatformAccount, update_id: object) -> None:
|
||||
if update_id is None:
|
||||
return
|
||||
account.cursor_state = {
|
||||
**(account.cursor_state or {}),
|
||||
"last_update_id": max(
|
||||
int((account.cursor_state or {}).get("last_update_id", 0)),
|
||||
int(update_id),
|
||||
),
|
||||
}
|
||||
136
surfsense_backend/app/gateway/ratelimit.py
Normal file
136
surfsense_backend/app/gateway/ratelimit.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""Redis token-bucket rate limiter for gateway outbound traffic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from app.config import config
|
||||
from app.observability.metrics import record_gateway_redis_fallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TOKEN_BUCKET_LUA = """
|
||||
local capacity = tonumber(ARGV[1])
|
||||
local refill_rate = tonumber(ARGV[2])
|
||||
local now = tonumber(ARGV[3])
|
||||
local consume = tonumber(ARGV[4])
|
||||
|
||||
local bucket = redis.call('HMGET', KEYS[1], 'tokens', 'last_refill')
|
||||
local tokens = tonumber(bucket[1]) or capacity
|
||||
local last_refill = tonumber(bucket[2]) or now
|
||||
|
||||
local elapsed = math.max(0, now - last_refill)
|
||||
tokens = math.min(capacity, tokens + (elapsed * refill_rate))
|
||||
|
||||
if tokens >= consume then
|
||||
tokens = tokens - consume
|
||||
redis.call('HMSET', KEYS[1], 'tokens', tokens, 'last_refill', now)
|
||||
redis.call('EXPIRE', KEYS[1], 3600)
|
||||
return 0
|
||||
else
|
||||
redis.call('HMSET', KEYS[1], 'tokens', tokens, 'last_refill', now)
|
||||
redis.call('EXPIRE', KEYS[1], 3600)
|
||||
local needed = consume - tokens
|
||||
return math.ceil((needed / refill_rate) * 1000)
|
||||
end
|
||||
"""
|
||||
|
||||
_redis_client: aioredis.Redis | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MemoryBucket:
|
||||
tokens: float
|
||||
last_refill: float
|
||||
|
||||
|
||||
_memory_buckets: dict[str, _MemoryBucket] = {}
|
||||
_memory_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _redis() -> aioredis.Redis:
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
async def _memory_fallback_acquire(
|
||||
scope: str,
|
||||
capacity: int,
|
||||
refill_per_sec: float,
|
||||
consume: float,
|
||||
) -> int:
|
||||
now = time.time()
|
||||
async with _memory_lock:
|
||||
bucket = _memory_buckets.get(scope)
|
||||
if bucket is None:
|
||||
bucket = _MemoryBucket(tokens=float(capacity), last_refill=now)
|
||||
_memory_buckets[scope] = bucket
|
||||
|
||||
elapsed = max(0.0, now - bucket.last_refill)
|
||||
bucket.tokens = min(float(capacity), bucket.tokens + elapsed * refill_per_sec)
|
||||
bucket.last_refill = now
|
||||
|
||||
if bucket.tokens >= consume:
|
||||
bucket.tokens -= consume
|
||||
return 0
|
||||
|
||||
needed = consume - bucket.tokens
|
||||
return int((needed / refill_per_sec) * 1000) if refill_per_sec > 0 else 1000
|
||||
|
||||
|
||||
async def acquire_token(
|
||||
scope: str,
|
||||
*,
|
||||
capacity: int,
|
||||
refill_per_sec: float,
|
||||
consume: float = 1.0,
|
||||
) -> int:
|
||||
"""Return 0 if allowed, otherwise milliseconds to wait.
|
||||
|
||||
Redis is the primary coordination mechanism. If Redis is unavailable,
|
||||
fall back to per-process memory so the gateway degrades instead of failing
|
||||
closed during a short Redis outage.
|
||||
"""
|
||||
|
||||
redis_key = f"gateway:bucket:{scope}"
|
||||
try:
|
||||
wait_ms = await _redis().eval(
|
||||
_TOKEN_BUCKET_LUA,
|
||||
1,
|
||||
redis_key,
|
||||
capacity,
|
||||
refill_per_sec,
|
||||
time.time(),
|
||||
consume,
|
||||
)
|
||||
return int(wait_ms)
|
||||
except (aioredis.RedisError, OSError) as exc:
|
||||
logger.warning("Redis rate limiter unavailable; using memory fallback: %s", exc)
|
||||
record_gateway_redis_fallback()
|
||||
return await _memory_fallback_acquire(scope, capacity, refill_per_sec, consume)
|
||||
|
||||
|
||||
async def wait_for_token(
|
||||
scope: str,
|
||||
*,
|
||||
capacity: int,
|
||||
refill_per_sec: float,
|
||||
consume: float = 1.0,
|
||||
) -> int:
|
||||
wait_ms = await acquire_token(
|
||||
scope,
|
||||
capacity=capacity,
|
||||
refill_per_sec=refill_per_sec,
|
||||
consume=consume,
|
||||
)
|
||||
if wait_ms > 0:
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
return wait_ms
|
||||
|
||||
40
surfsense_backend/app/gateway/thread_lock.py
Normal file
40
surfsense_backend/app/gateway/thread_lock.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Redis-backed distributed locks for gateway conversation turns."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import redis
|
||||
|
||||
from app.config import config
|
||||
from app.observability.metrics import record_gateway_thread_lock_contention
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
def _redis() -> redis.Redis:
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _lock_key(thread_id: int) -> str:
|
||||
return f"gateway:thread_lock:{thread_id}"
|
||||
|
||||
|
||||
def acquire_thread_lock(thread_id: int, ttl: int = 60) -> bool:
|
||||
acquired = bool(_redis().set(_lock_key(thread_id), "1", nx=True, ex=ttl))
|
||||
if not acquired:
|
||||
record_gateway_thread_lock_contention()
|
||||
return acquired
|
||||
|
||||
|
||||
def release_thread_lock(thread_id: int) -> None:
|
||||
try:
|
||||
_redis().delete(_lock_key(thread_id))
|
||||
except redis.RedisError as exc:
|
||||
logger.warning("Failed to release gateway thread lock for %s: %s", thread_id, exc)
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue