feat(gateway): process inbound events through the agent

This commit is contained in:
Anish Sarkar 2026-05-27 23:38:52 +05:30
parent 967ec099c8
commit b8538655bb
4 changed files with 518 additions and 0 deletions

View file

@ -0,0 +1,80 @@
"""Invoke SurfSense chat agent for gateway channels."""
from __future__ import annotations
import json
import logging
from collections.abc import AsyncIterator
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import GatewayConversationBinding
from app.gateway.auth_invariant import assert_authorization_invariant
from app.gateway.base.translator import GatewayStreamEvent
from app.gateway.bindings import get_or_create_thread_for_binding
from app.gateway.hitl_filter import DEFAULT_HITL_TOOL_NAMES
from app.gateway.telegram.translator import TelegramStreamTranslator
from app.gateway.thread_lock import acquire_thread_lock, release_thread_lock
from app.observability.metrics import record_gateway_turn_latency
from app.tasks.chat.stream_new_chat import stream_new_chat
logger = logging.getLogger(__name__)
async def _events_from_sse(chunks: AsyncIterator[str]) -> AsyncIterator[GatewayStreamEvent]:
    """Translate raw SSE text chunks into gateway stream events.

    Every chunk is split into lines and only ``data:`` lines are looked at.
    The ``[DONE]`` sentinel becomes a ``done`` event. JSON object payloads are
    mapped by their ``type`` field: ``text-delta`` is reduced to its ``delta``
    value, while ``text-end``, ``finish`` and ``data-interrupt-request`` carry
    the full payload. Anything else (other types, malformed JSON, JSON that is
    not an object) is skipped.

    Args:
        chunks: Async iterator of SSE-formatted text chunks.
            NOTE(review): lines are parsed per chunk, so this assumes a chunk
            never ends in the middle of a ``data:`` line — confirm against
            the producer.

    Yields:
        One ``GatewayStreamEvent`` per recognised ``data:`` line.
    """
    # Event types forwarded with their payload untouched.
    passthrough_types = ("text-end", "finish", "data-interrupt-request")

    async for chunk in chunks:
        for raw_line in chunk.splitlines():
            line = raw_line.strip()
            if not line.startswith("data:"):
                continue
            payload = line.removeprefix("data:").strip()
            if payload == "[DONE]":
                yield GatewayStreamEvent(type="done")
                continue
            try:
                data = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                # Valid JSON without a "type" field (list, number, string...):
                # skip it instead of crashing the stream on ``data.get``.
                continue
            event_type = str(data.get("type") or "")
            if event_type == "text-delta":
                yield GatewayStreamEvent(type="text-delta", data={"delta": data.get("delta", "")})
            elif event_type in passthrough_types:
                yield GatewayStreamEvent(type=event_type, data=data)
async def call_agent_for_gateway(
    *,
    session: AsyncSession,
    binding: GatewayConversationBinding,
    user_text: str,
    translator: TelegramStreamTranslator,
    request_id: str | None = None,
) -> None:
    """Run one agent turn for a gateway conversation and stream it out.

    The binding is checked with ``assert_authorization_invariant``, its thread
    is resolved (and committed), and the per-thread lock is taken for the
    duration of the turn. The agent's SSE output is converted with
    ``_events_from_sse`` and handed to ``translator.translate``.

    Args:
        session: Database session used for the authorization check and the
            thread lookup; it is committed before the agent starts.
        binding: Conversation binding the inbound message belongs to.
        user_text: Text sent by the user.
        translator: Receives the translated event stream.
        request_id: Passed to the agent; ``"gateway"`` when omitted or empty.

    Raises:
        RuntimeError: With message ``"gateway_thread_busy"`` when the thread
            lock could not be acquired.
    """
    owner = await assert_authorization_invariant(session, binding)
    chat_thread = await get_or_create_thread_for_binding(session, binding)
    await session.commit()

    lock_taken = acquire_thread_lock(chat_thread.id)
    if not lock_taken:
        raise RuntimeError("gateway_thread_busy")

    try:
        agent_kwargs = {
            "user_query": user_text,
            "search_space_id": binding.search_space_id,
            "chat_id": chat_thread.id,
            "user_id": str(owner.id),
            "needs_history_bootstrap": chat_thread.needs_history_bootstrap,
            "thread_visibility": chat_thread.visibility,
            "current_user_display_name": owner.display_name or "A team member",
            "disabled_tools": sorted(DEFAULT_HITL_TOOL_NAMES),
            "request_id": request_id or "gateway",
        }
        sse_chunks = stream_new_chat(**agent_kwargs)
        await translator.translate(_events_from_sse(sse_chunks))
        # NOTE(review): the latency recorded here is a constant 0, not a
        # measured duration — looks like a placeholder; confirm the expected
        # unit of record_gateway_turn_latency before wiring in a real timer.
        record_gateway_turn_latency(0, platform="telegram")
    finally:
        # Always give the lock back, including when the turn fails mid-stream.
        release_thread_lock(chat_thread.id)

View file

@ -0,0 +1,262 @@
"""Long-lived gateway inbox processing.
This module owns the agent-turn execution path for messaging gateways. It is
intentionally independent of Celery so LangGraph, async Postgres, Redis, and
Telegram clients all run on one stable event loop in ``GatewayRunner``.
"""
from __future__ import annotations
import logging
from collections.abc import Callable
from datetime import UTC, datetime
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.config import config
from app.db import (
GatewayBindingState,
GatewayConversationBinding,
GatewayEventStatus,
GatewayInboundEvent,
GatewayPeerKind,
GatewayPlatformAccount,
NewChatThread,
async_session_maker,
)
from app.gateway.accounts import account_token
from app.gateway.agent_invoke import call_agent_for_gateway
from app.gateway.bindings import get_or_create_thread_for_binding
from app.gateway.telegram.adapter import TelegramAdapter
from app.gateway.telegram.commands import (
command_name,
handle_help_command,
handle_start_command,
send_unbound_onboarding,
)
from app.gateway.telegram.translator import TelegramStreamTranslator
from app.observability.metrics import record_gateway_inbox_processed
logger = logging.getLogger(__name__)
# Session factory accepted by the inbox helpers below: either a SQLAlchemy
# ``async_sessionmaker`` or any zero-argument callable returning an
# ``AsyncSession``. Every helper uses it as ``async with session_maker() as session``.
SessionMaker = async_sessionmaker[AsyncSession] | Callable[[], AsyncSession]
def _dashboard_url() -> str:
    """Return the configured frontend URL, or ``"/dashboard"`` when it is unset or empty."""
    frontend_url = config.NEXT_FRONTEND_URL
    if frontend_url:
        return frontend_url
    return "/dashboard"
async def claim_next_inbound_event(
    session_maker: SessionMaker = async_session_maker,
) -> int | None:
    """Claim the oldest ``RECEIVED`` inbox event and return its id.

    The candidate row is selected with ``with_for_update(skip_locked=True)``
    so rows already locked by another claimer are skipped. The claimed row is
    moved to ``PROCESSING`` and its ``attempt_count`` is incremented before
    the transaction is committed.

    Args:
        session_maker: Factory for the session used to claim the row.

    Returns:
        The claimed event id, or ``None`` when no ``RECEIVED`` row is available.
    """
    oldest_received = (
        select(GatewayInboundEvent)
        .where(GatewayInboundEvent.status == GatewayEventStatus.RECEIVED)
        .order_by(GatewayInboundEvent.received_at.asc())
        .with_for_update(skip_locked=True)
        .limit(1)
    )
    async with session_maker() as session:
        claimed = (await session.execute(oldest_received)).scalars().first()
        if claimed is None:
            return None
        claimed.status = GatewayEventStatus.PROCESSING
        claimed.attempt_count += 1
        await session.commit()
        return int(claimed.id)
async def process_inbound_event(
    inbox_id: int,
    session_maker: SessionMaker = async_session_maker,
) -> None:
    """Process one gateway inbox row and record its outcome.

    Phases:
      1. Lock the row (skipping it if locked elsewhere or missing). Rows that
         are already ``PROCESSED`` or ``IGNORED`` are left alone; a
         ``RECEIVED`` row is moved to ``PROCESSING`` with ``attempt_count``
         incremented.
      2. Dispatch the event. A ``RuntimeError("gateway_thread_busy")`` puts
         the row back to ``RECEIVED`` and returns quietly; any other
         exception marks the row ``FAILED`` and is re-raised.
      3. If the row is still ``PROCESSING`` afterwards, mark it
         ``PROCESSED``, stamp ``processed_at`` and record the metric.

    Args:
        inbox_id: Primary key of the ``GatewayInboundEvent`` row.
        session_maker: Factory for the short-lived sessions used per phase.

    Raises:
        Exception: Whatever dispatch raised, except the thread-busy signal.
    """
    already_terminal = (GatewayEventStatus.PROCESSED, GatewayEventStatus.IGNORED)
    lock_row = (
        select(GatewayInboundEvent)
        .where(GatewayInboundEvent.id == inbox_id)
        .with_for_update(skip_locked=True)
    )
    async with session_maker() as session:
        inbox_row = (await session.execute(lock_row)).scalars().first()
        if inbox_row is None:
            return
        if inbox_row.status in already_terminal:
            return
        if inbox_row.status == GatewayEventStatus.RECEIVED:
            inbox_row.status = GatewayEventStatus.PROCESSING
            inbox_row.attempt_count += 1
            await session.commit()

    try:
        await _dispatch_inbound_event(inbox_id, session_maker)
    except Exception as exc:
        reason = str(exc)
        thread_busy = isinstance(exc, RuntimeError) and reason == "gateway_thread_busy"
        if not thread_busy:
            await _mark_failed(inbox_id, reason, session_maker)
            raise
        # Another turn holds the thread lock: hand the row back so it can be
        # claimed again later instead of counting it as a failure.
        requeue = (
            update(GatewayInboundEvent)
            .where(GatewayInboundEvent.id == inbox_id)
            .values(
                status=GatewayEventStatus.RECEIVED,
                last_error="gateway_thread_busy",
            )
        )
        async with session_maker() as session:
            await session.execute(requeue)
            await session.commit()
        return

    async with session_maker() as session:
        inbox_row = await session.get(GatewayInboundEvent, inbox_id)
        # Dispatch may already have written a terminal status; only rows it
        # left in PROCESSING are finalised here.
        if inbox_row is None or inbox_row.status != GatewayEventStatus.PROCESSING:
            return
        inbox_row.status = GatewayEventStatus.PROCESSED
        inbox_row.processed_at = datetime.now(UTC)
        await session.commit()
        record_gateway_inbox_processed(platform=inbox_row.platform.value, status="processed")
async def _mark_failed(
    inbox_id: int,
    error: str,
    session_maker: SessionMaker,
) -> None:
    """Write ``FAILED`` status and the error text to the inbox row.

    Uses its own session so the update does not depend on the state of the
    session in which the failure happened.

    Args:
        inbox_id: Primary key of the ``GatewayInboundEvent`` row.
        error: Text stored in ``last_error``.
        session_maker: Factory for the session used for the update.
    """
    mark_failed_stmt = (
        update(GatewayInboundEvent)
        .where(GatewayInboundEvent.id == inbox_id)
        .values(status=GatewayEventStatus.FAILED, last_error=error)
    )
    async with session_maker() as session:
        await session.execute(mark_failed_stmt)
        await session.commit()
async def _finish_event(
    session: AsyncSession,
    event: GatewayInboundEvent,
    status: GatewayEventStatus,
    reason: str | None = None,
) -> None:
    """Close ``event`` with a terminal ``status`` (and optional reason) and commit.

    ``processed_at`` is stamped whenever the status is ``PROCESSED`` so rows
    closed during dispatch carry the same fields as rows finalised by
    ``process_inbound_event``.
    """
    event.status = status
    if reason is not None:
        event.last_error = reason
    if status == GatewayEventStatus.PROCESSED:
        event.processed_at = datetime.now(UTC)
    await session.commit()


async def _dispatch_inbound_event(
    inbox_id: int,
    session_maker: SessionMaker,
) -> None:
    """Route one inbox row: reject it, answer a command, or run an agent turn.

    Steps, in order:
      1. Load the event and its platform account. A missing account is
         ``IGNORED``; a missing token is ``FAILED``.
      2. Parse the raw payload. Payloads without a peer id are ``IGNORED``.
      3. Advance the account's update cursor and look up a ``BOUND`` or
         ``SUSPENDED`` binding for the peer.
      4. Non-direct chats are left and ``IGNORED`` (``group_rejected``).
      5. ``/start`` is delegated to ``handle_start_command``; if it reports
         the message as handled, dispatch stops there.
      6. Unbound chats get the onboarding message and are ``IGNORED``.
      7. ``/help`` and ``/new`` are answered inline and marked ``PROCESSED``.
      8. Empty messages are ``IGNORED``; anything else runs an agent turn and
         tags the thread's ``source`` as ``"telegram"``.

    When the ``/start`` or agent path finishes, the event's status is left
    untouched so the caller can finalise it.

    Args:
        inbox_id: Primary key of the ``GatewayInboundEvent`` row.
        session_maker: Factory for the session used for the whole dispatch.

    Raises:
        RuntimeError: ``"gateway_thread_busy"`` propagated from
            ``call_agent_for_gateway`` when the thread lock is held.
    """
    async with session_maker() as session:
        event = await session.get(GatewayInboundEvent, inbox_id)
        if event is None:
            return

        account = await session.get(GatewayPlatformAccount, event.account_id)
        if account is None:
            await _finish_event(session, event, GatewayEventStatus.IGNORED, "account_missing")
            return

        token = account_token(account)
        if not token:
            await _finish_event(
                session, event, GatewayEventStatus.FAILED, "missing_telegram_token"
            )
            return

        adapter = TelegramAdapter(token)
        parsed = adapter.parse_inbound(event.raw_payload or {})
        if parsed.external_peer_id is None:
            await _finish_event(
                session, event, GatewayEventStatus.IGNORED, "missing_external_peer_id"
            )
            return

        # The cursor change is persisted by whichever commit happens next.
        _update_account_cursor(account, parsed.metadata.get("update_id"))

        result = await session.execute(
            select(GatewayConversationBinding).where(
                GatewayConversationBinding.account_id == account.id,
                GatewayConversationBinding.external_peer_id == parsed.external_peer_id,
                GatewayConversationBinding.state.in_(
                    [GatewayBindingState.BOUND, GatewayBindingState.SUSPENDED]
                ),
            )
        )
        binding = result.scalars().first()

        if parsed.external_peer_kind != GatewayPeerKind.DIRECT.value:
            # Only direct chats are served; leave anything else.
            await adapter.leave_chat(external_peer_id=parsed.external_peer_id)
            await _finish_event(session, event, GatewayEventStatus.IGNORED, "group_rejected")
            return

        cmd = command_name(parsed.text)
        if cmd == "/start":
            handled = await handle_start_command(
                session=session, adapter=adapter, event=parsed
            )
            await session.commit()
            if handled:
                return
            # Not handled: fall through and treat it like any other message.

        if binding is None:
            await send_unbound_onboarding(
                adapter=adapter,
                event=parsed,
                dashboard_url=_dashboard_url(),
            )
            await _finish_event(session, event, GatewayEventStatus.IGNORED, "unbound_chat")
            return

        event.binding_id = binding.id

        if cmd == "/help":
            await handle_help_command(adapter=adapter, event=parsed)
            await _finish_event(session, event, GatewayEventStatus.PROCESSED)
            return

        if cmd == "/new":
            # Dropping the active thread makes the next message start a new one.
            binding.active_thread_id = None
            await adapter.send_message(
                external_peer_id=parsed.external_peer_id,
                text="Started a new SurfSense conversation.",
            )
            await _finish_event(session, event, GatewayEventStatus.PROCESSED)
            return

        if not parsed.text:
            await _finish_event(session, event, GatewayEventStatus.IGNORED, "empty_message")
            return

        thread = await get_or_create_thread_for_binding(session, binding)
        await session.commit()

        translator = TelegramStreamTranslator(
            adapter=adapter,
            external_peer_id=parsed.external_peer_id,
        )
        await call_agent_for_gateway(
            session=session,
            binding=binding,
            user_text=parsed.text,
            translator=translator,
            request_id=f"gateway:{inbox_id}",
        )

        # Tag the thread with its origin once the turn has completed.
        thread = await session.get(NewChatThread, thread.id)
        if thread is not None:
            thread.source = "telegram"
        await session.commit()
def _update_account_cursor(account: GatewayPlatformAccount, update_id: object) -> None:
    """Advance ``account.cursor_state["last_update_id"]`` to ``update_id``.

    The cursor only moves forward: the stored value becomes the maximum of
    the previous value and ``update_id``. A fresh dict is assigned to
    ``account.cursor_state`` rather than mutating the existing one, and all
    other keys are preserved.

    Args:
        account: Object whose ``cursor_state`` mapping (or ``None``) is updated.
        update_id: New update id, anything ``int()`` accepts; ``None`` is a no-op.

    Raises:
        ValueError, TypeError: If ``update_id`` or a stored non-null cursor
            value cannot be converted with ``int()``.
    """
    if update_id is None:
        return
    state = dict(account.cursor_state or {})
    previous = state.get("last_update_id")
    # A stored null means "no cursor yet"; int(None) would raise TypeError.
    previous_id = int(previous) if previous is not None else 0
    state["last_update_id"] = max(previous_id, int(update_id))
    account.cursor_state = state

View file

@ -0,0 +1,136 @@
"""Redis token-bucket rate limiter for gateway outbound traffic."""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass
import redis.asyncio as aioredis
from app.config import config
from app.observability.metrics import record_gateway_redis_fallback
logger = logging.getLogger(__name__)
# Token-bucket script evaluated in Redis by ``acquire_token``.
# KEYS[1] is the bucket hash; ARGV is (capacity, refill rate per second, current
# time in seconds, tokens to consume). Missing hash fields default to a full
# bucket last refilled "now". The script refills by elapsed * refill_rate
# (capped at capacity), then either consumes the tokens and returns 0, or
# leaves them in place and returns the wait in milliseconds, rounded up, until
# enough tokens will have accrued. Both branches persist the refreshed state
# and reset the key's expiry to 3600 seconds.
# NOTE(review): the denied branch divides by refill_rate with no guard for 0,
# unlike the in-memory fallback which returns 1000 in that case — confirm
# callers never pass a zero refill rate.
# NOTE(review): HMSET is documented as deprecated in favour of HSET on current
# Redis versions — verify against the deployed server before changing.
_TOKEN_BUCKET_LUA = """
local capacity = tonumber(ARGV[1])
local refill_rate = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local consume = tonumber(ARGV[4])
local bucket = redis.call('HMGET', KEYS[1], 'tokens', 'last_refill')
local tokens = tonumber(bucket[1]) or capacity
local last_refill = tonumber(bucket[2]) or now
local elapsed = math.max(0, now - last_refill)
tokens = math.min(capacity, tokens + (elapsed * refill_rate))
if tokens >= consume then
tokens = tokens - consume
redis.call('HMSET', KEYS[1], 'tokens', tokens, 'last_refill', now)
redis.call('EXPIRE', KEYS[1], 3600)
return 0
else
redis.call('HMSET', KEYS[1], 'tokens', tokens, 'last_refill', now)
redis.call('EXPIRE', KEYS[1], 3600)
local needed = consume - tokens
return math.ceil((needed / refill_rate) * 1000)
end
"""
# Shared async Redis client; stays None until _redis() creates it on first use.
_redis_client: aioredis.Redis | None = None
@dataclass
class _MemoryBucket:
    """State of one in-process token bucket used by the memory fallback."""

    # Tokens currently available; refilled lazily on each acquire attempt.
    tokens: float
    # ``time.time()`` value at which ``tokens`` was last brought up to date.
    last_refill: float
# Per-process buckets keyed by scope; only touched by the memory fallback that
# runs when the Redis call fails.
_memory_buckets: dict[str, _MemoryBucket] = {}
# Held while a fallback bucket is looked up, refilled and consumed.
_memory_lock = asyncio.Lock()
def _redis() -> aioredis.Redis:
    """Return the module-wide async Redis client, creating it on first use.

    The client is built from ``config.REDIS_APP_URL`` with
    ``decode_responses=True`` and cached in ``_redis_client``.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = aioredis.from_url(config.REDIS_APP_URL, decode_responses=True)
    return _redis_client
async def _memory_fallback_acquire(
    scope: str,
    capacity: int,
    refill_per_sec: float,
    consume: float,
) -> int:
    """Try to take ``consume`` tokens from the in-process bucket for ``scope``.

    The bucket is created full on first use, refilled by the time elapsed
    since the previous call (capped at ``capacity``), and then charged if it
    holds enough tokens.

    Args:
        scope: Bucket name.
        capacity: Maximum number of tokens the bucket can hold.
        refill_per_sec: Tokens added per second.
        consume: Tokens required for this call.

    Returns:
        ``0`` when the tokens were consumed. Otherwise the number of
        milliseconds to wait, rounded up and never less than 1, or ``1000``
        when ``refill_per_sec`` is not positive.
    """
    import math  # local import: the module does not otherwise need ``math``

    now = time.time()
    async with _memory_lock:
        bucket = _memory_buckets.get(scope)
        if bucket is None:
            bucket = _MemoryBucket(tokens=float(capacity), last_refill=now)
            _memory_buckets[scope] = bucket
        # max() keeps a backwards clock step from draining the bucket.
        elapsed = max(0.0, now - bucket.last_refill)
        bucket.tokens = min(float(capacity), bucket.tokens + elapsed * refill_per_sec)
        bucket.last_refill = now
        if bucket.tokens >= consume:
            bucket.tokens -= consume
            return 0
        if refill_per_sec <= 0:
            return 1000
        shortfall = consume - bucket.tokens
        # Round up and never report 0 here: 0 means "allowed", but nothing
        # was consumed. Rounding up also matches the Redis script's ceil.
        return max(1, math.ceil((shortfall / refill_per_sec) * 1000))
async def acquire_token(
    scope: str,
    *,
    capacity: int,
    refill_per_sec: float,
    consume: float = 1.0,
) -> int:
    """Return 0 if allowed, otherwise milliseconds to wait.

    Redis is the primary coordination mechanism: the token-bucket script runs
    against the key ``gateway:bucket:<scope>``. If the Redis call fails with a
    Redis or OS-level error, a warning is logged, the fallback metric is
    recorded, and a per-process in-memory bucket is used so the gateway
    degrades instead of failing closed during a short Redis outage.

    Args:
        scope: Bucket name.
        capacity: Maximum number of tokens the bucket can hold.
        refill_per_sec: Tokens added per second.
        consume: Tokens required for this call.
    """
    bucket_key = f"gateway:bucket:{scope}"
    try:
        reply = await _redis().eval(
            _TOKEN_BUCKET_LUA,
            1,
            bucket_key,
            capacity,
            refill_per_sec,
            time.time(),
            consume,
        )
    except (aioredis.RedisError, OSError) as exc:
        logger.warning("Redis rate limiter unavailable; using memory fallback: %s", exc)
        record_gateway_redis_fallback()
        return await _memory_fallback_acquire(scope, capacity, refill_per_sec, consume)
    else:
        return int(reply)
async def wait_for_token(
    scope: str,
    *,
    capacity: int,
    refill_per_sec: float,
    consume: float = 1.0,
) -> int:
    """Wait until a token for ``scope`` has been consumed.

    A denied ``acquire_token`` call does not consume anything, so sleeping
    once and returning would let every waiter through without ever charging
    the bucket. After each sleep the acquire is therefore retried. Attempts
    are bounded so a bucket that can never be satisfied (for example
    ``consume`` above ``capacity``) cannot hang the caller; once they are
    used up the call returns and the caller proceeds best-effort.

    Args:
        scope: Bucket name.
        capacity: Maximum number of tokens the bucket can hold.
        refill_per_sec: Tokens added per second.
        consume: Tokens required for this call.

    Returns:
        Total milliseconds slept; ``0`` when the first attempt was allowed.
    """
    max_attempts = 5
    waited_ms = 0
    for _ in range(max_attempts):
        wait_ms = await acquire_token(
            scope,
            capacity=capacity,
            refill_per_sec=refill_per_sec,
            consume=consume,
        )
        if wait_ms <= 0:
            break
        await asyncio.sleep(wait_ms / 1000)
        waited_ms += wait_ms
    return waited_ms

View file

@ -0,0 +1,40 @@
"""Redis-backed distributed locks for gateway conversation turns."""
from __future__ import annotations
import logging
import redis
from app.config import config
from app.observability.metrics import record_gateway_thread_lock_contention
logger = logging.getLogger(__name__)
# Shared synchronous Redis client; stays None until _redis() first runs.
_redis_client: redis.Redis | None = None


def _redis() -> redis.Redis:
    """Return the module-wide Redis client, creating it on first use.

    The client is built from ``config.REDIS_APP_URL`` with
    ``decode_responses=True`` and cached in ``_redis_client``.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client
    _redis_client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
    return _redis_client
def _lock_key(thread_id: int) -> str:
    """Build the Redis key that guards the given conversation thread."""
    return "gateway:thread_lock:{}".format(thread_id)
def acquire_thread_lock(thread_id: int, ttl: int = 60) -> bool:
    """Try once to take the lock for ``thread_id``.

    The lock is a Redis key set only if absent (``nx=True``) with an expiry
    of ``ttl`` seconds. A failed attempt records the contention metric.

    NOTE(review): this uses the synchronous client and is called from async
    code, so the Redis round trip blocks the event loop — confirm that is
    acceptable. The key also expires after ``ttl`` seconds regardless of
    whether the holder has finished.

    Args:
        thread_id: Conversation thread to lock.
        ttl: Expiry of the lock key in seconds.

    Returns:
        ``True`` if the lock was acquired, ``False`` if it is already held.
    """
    set_reply = _redis().set(_lock_key(thread_id), "1", nx=True, ex=ttl)
    if set_reply:
        return True
    record_gateway_thread_lock_contention()
    return False
def release_thread_lock(thread_id: int) -> None:
    """Delete the lock key for ``thread_id``; Redis errors are logged, not raised.

    NOTE(review): the key is deleted unconditionally. If the lock expired and
    was taken by another holder in the meantime, this removes their lock —
    consider an owner token with compare-and-delete.

    Args:
        thread_id: Conversation thread whose lock is released.
    """
    key = _lock_key(thread_id)
    try:
        _redis().delete(key)
    except redis.RedisError as exc:
        logger.warning("Failed to release gateway thread lock for %s: %s", thread_id, exc)