fix(gateway): preserve request context during inbox processing

This commit is contained in:
Anish Sarkar 2026-05-28 04:38:20 +05:30
parent 08bf3cc023
commit afcadfb4bf
2 changed files with 60 additions and 54 deletions

View file

@ -1,4 +1,4 @@
"""Invoke SurfSense chat agent for gateway channels.""" """Invoke SurfSense chat agent for external chat surfaces."""
from __future__ import annotations from __future__ import annotations
@ -6,9 +6,10 @@ import json
import logging import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import GatewayConversationBinding from app.db import ExternalChatBinding, NewChatMessage
from app.gateway.auth_invariant import assert_authorization_invariant from app.gateway.auth_invariant import assert_authorization_invariant
from app.gateway.base.translator import GatewayStreamEvent from app.gateway.base.translator import GatewayStreamEvent
from app.gateway.bindings import get_or_create_thread_for_binding from app.gateway.bindings import get_or_create_thread_for_binding
@ -55,7 +56,7 @@ async def _events_from_sse(chunks: AsyncIterator[str]) -> AsyncIterator[GatewayS
async def call_agent_for_gateway( async def call_agent_for_gateway(
*, *,
session: AsyncSession, session: AsyncSession,
binding: GatewayConversationBinding, binding: ExternalChatBinding,
user_text: str, user_text: str,
translator: TelegramStreamTranslator, translator: TelegramStreamTranslator,
request_id: str | None = None, request_id: str | None = None,
@ -85,6 +86,12 @@ async def call_agent_for_gateway(
finally: finally:
await events.aclose() await events.aclose()
await stream.aclose() await stream.aclose()
await session.execute(
update(NewChatMessage)
.where(NewChatMessage.thread_id == thread.id, NewChatMessage.source == "web")
.values(source="telegram")
)
await session.commit()
record_gateway_turn_latency(0, platform="telegram") record_gateway_turn_latency(0, platform="telegram")
finally: finally:
release_thread_lock(thread.id) release_thread_lock(thread.id)

View file

@ -1,8 +1,7 @@
"""Long-lived gateway inbox processing. """Long-lived external chat inbox processing.
This module owns the agent-turn execution path for messaging gateways. It is This module owns the agent-turn execution path for external chat surfaces.
intentionally independent of Celery so LangGraph, async Postgres, Redis, and FastAPI calls into it after webhook and BYO long-poll intake persist inbox rows.
Telegram clients all run on one stable event loop in ``GatewayRunner``.
""" """
from __future__ import annotations from __future__ import annotations
@ -16,12 +15,12 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.config import config from app.config import config
from app.db import ( from app.db import (
GatewayBindingState, ExternalChatBindingState,
GatewayConversationBinding, ExternalChatBinding,
GatewayEventStatus, ExternalChatEventStatus,
GatewayInboundEvent, ExternalChatInboundEvent,
GatewayPeerKind, ExternalChatPeerKind,
GatewayPlatformAccount, ExternalChatAccount,
NewChatThread, NewChatThread,
async_session_maker, async_session_maker,
) )
@ -54,16 +53,16 @@ async def claim_next_inbound_event(
async with session_maker() as session: async with session_maker() as session:
result = await session.execute( result = await session.execute(
select(GatewayInboundEvent) select(ExternalChatInboundEvent)
.where(GatewayInboundEvent.status == GatewayEventStatus.RECEIVED) .where(ExternalChatInboundEvent.status == ExternalChatEventStatus.RECEIVED)
.order_by(GatewayInboundEvent.received_at.asc()) .order_by(ExternalChatInboundEvent.received_at.asc())
.with_for_update(skip_locked=True) .with_for_update(skip_locked=True)
.limit(1) .limit(1)
) )
event = result.scalars().first() event = result.scalars().first()
if event is None: if event is None:
return None return None
event.status = GatewayEventStatus.PROCESSING event.status = ExternalChatEventStatus.PROCESSING
event.attempt_count += 1 event.attempt_count += 1
await session.commit() await session.commit()
return int(event.id) return int(event.id)
@ -73,22 +72,22 @@ async def process_inbound_event(
inbox_id: int, inbox_id: int,
session_maker: SessionMaker = async_session_maker, session_maker: SessionMaker = async_session_maker,
) -> None: ) -> None:
"""Process one gateway inbox row and mark its terminal status.""" """Process one external chat inbox row and mark its terminal status."""
async with session_maker() as session: async with session_maker() as session:
result = await session.execute( result = await session.execute(
select(GatewayInboundEvent) select(ExternalChatInboundEvent)
.where(GatewayInboundEvent.id == inbox_id) .where(ExternalChatInboundEvent.id == inbox_id)
.with_for_update(skip_locked=True) .with_for_update(skip_locked=True)
) )
event = result.scalars().first() event = result.scalars().first()
if event is None or event.status in { if event is None or event.status in {
GatewayEventStatus.PROCESSED, ExternalChatEventStatus.PROCESSED,
GatewayEventStatus.IGNORED, ExternalChatEventStatus.IGNORED,
}: }:
return return
if event.status == GatewayEventStatus.RECEIVED: if event.status == ExternalChatEventStatus.RECEIVED:
event.status = GatewayEventStatus.PROCESSING event.status = ExternalChatEventStatus.PROCESSING
event.attempt_count += 1 event.attempt_count += 1
await session.commit() await session.commit()
@ -98,15 +97,15 @@ async def process_inbound_event(
if str(exc) == "gateway_thread_busy": if str(exc) == "gateway_thread_busy":
async with session_maker() as session: async with session_maker() as session:
await session.execute( await session.execute(
update(GatewayInboundEvent) update(ExternalChatInboundEvent)
.where(GatewayInboundEvent.id == inbox_id) .where(ExternalChatInboundEvent.id == inbox_id)
.values( .values(
status=GatewayEventStatus.RECEIVED, status=ExternalChatEventStatus.RECEIVED,
last_error="gateway_thread_busy", last_error="gateway_thread_busy",
) )
) )
await session.commit() await session.commit()
return raise
await _mark_failed(inbox_id, str(exc), session_maker) await _mark_failed(inbox_id, str(exc), session_maker)
raise raise
except Exception as exc: except Exception as exc:
@ -114,9 +113,9 @@ async def process_inbound_event(
raise raise
async with session_maker() as session: async with session_maker() as session:
event = await session.get(GatewayInboundEvent, inbox_id) event = await session.get(ExternalChatInboundEvent, inbox_id)
if event is not None and event.status == GatewayEventStatus.PROCESSING: if event is not None and event.status == ExternalChatEventStatus.PROCESSING:
event.status = GatewayEventStatus.PROCESSED event.status = ExternalChatEventStatus.PROCESSED
event.processed_at = datetime.now(UTC) event.processed_at = datetime.now(UTC)
await session.commit() await session.commit()
record_gateway_inbox_processed(platform=event.platform.value, status="processed") record_gateway_inbox_processed(platform=event.platform.value, status="processed")
@ -129,9 +128,9 @@ async def _mark_failed(
) -> None: ) -> None:
async with session_maker() as session: async with session_maker() as session:
await session.execute( await session.execute(
update(GatewayInboundEvent) update(ExternalChatInboundEvent)
.where(GatewayInboundEvent.id == inbox_id) .where(ExternalChatInboundEvent.id == inbox_id)
.values(status=GatewayEventStatus.FAILED, last_error=error) .values(status=ExternalChatEventStatus.FAILED, last_error=error)
) )
await session.commit() await session.commit()
@ -141,19 +140,19 @@ async def _dispatch_inbound_event(
session_maker: SessionMaker, session_maker: SessionMaker,
) -> None: ) -> None:
async with session_maker() as session: async with session_maker() as session:
event = await session.get(GatewayInboundEvent, inbox_id) event = await session.get(ExternalChatInboundEvent, inbox_id)
if event is None: if event is None:
return return
account = await session.get(GatewayPlatformAccount, event.account_id) account = await session.get(ExternalChatAccount, event.account_id)
if account is None: if account is None:
event.status = GatewayEventStatus.IGNORED event.status = ExternalChatEventStatus.IGNORED
event.last_error = "account_missing" event.last_error = "account_missing"
await session.commit() await session.commit()
return return
token = account_token(account) token = account_token(account)
if not token: if not token:
event.status = GatewayEventStatus.FAILED event.status = ExternalChatEventStatus.FAILED
event.last_error = "missing_telegram_token" event.last_error = "missing_telegram_token"
await session.commit() await session.commit()
return return
@ -161,7 +160,7 @@ async def _dispatch_inbound_event(
adapter = TelegramAdapter(token) adapter = TelegramAdapter(token)
parsed = adapter.parse_inbound(event.raw_payload or {}) parsed = adapter.parse_inbound(event.raw_payload or {})
if parsed.external_peer_id is None: if parsed.external_peer_id is None:
event.status = GatewayEventStatus.IGNORED event.status = ExternalChatEventStatus.IGNORED
event.last_error = "missing_external_peer_id" event.last_error = "missing_external_peer_id"
await session.commit() await session.commit()
return return
@ -169,19 +168,19 @@ async def _dispatch_inbound_event(
_update_account_cursor(account, parsed.metadata.get("update_id")) _update_account_cursor(account, parsed.metadata.get("update_id"))
result = await session.execute( result = await session.execute(
select(GatewayConversationBinding).where( select(ExternalChatBinding).where(
GatewayConversationBinding.account_id == account.id, ExternalChatBinding.account_id == account.id,
GatewayConversationBinding.external_peer_id == parsed.external_peer_id, ExternalChatBinding.external_peer_id == parsed.external_peer_id,
GatewayConversationBinding.state.in_( ExternalChatBinding.state.in_(
[GatewayBindingState.BOUND, GatewayBindingState.SUSPENDED] [ExternalChatBindingState.BOUND, ExternalChatBindingState.SUSPENDED]
), ),
) )
) )
binding = result.scalars().first() binding = result.scalars().first()
if parsed.external_peer_kind != GatewayPeerKind.DIRECT.value: if parsed.external_peer_kind != ExternalChatPeerKind.DIRECT.value:
await adapter.leave_chat(external_peer_id=parsed.external_peer_id) await adapter.leave_chat(external_peer_id=parsed.external_peer_id)
event.status = GatewayEventStatus.IGNORED event.status = ExternalChatEventStatus.IGNORED
event.last_error = "group_rejected" event.last_error = "group_rejected"
await session.commit() await session.commit()
return return
@ -201,30 +200,30 @@ async def _dispatch_inbound_event(
event=parsed, event=parsed,
dashboard_url=_dashboard_url(), dashboard_url=_dashboard_url(),
) )
event.status = GatewayEventStatus.IGNORED event.status = ExternalChatEventStatus.IGNORED
event.last_error = "unbound_chat" event.last_error = "unbound_chat"
await session.commit() await session.commit()
return return
event.binding_id = binding.id event.external_chat_binding_id = binding.id
if cmd == "/help": if cmd == "/help":
await handle_help_command(adapter=adapter, event=parsed) await handle_help_command(adapter=adapter, event=parsed)
event.status = GatewayEventStatus.PROCESSED event.status = ExternalChatEventStatus.PROCESSED
await session.commit() await session.commit()
return return
if cmd == "/new": if cmd == "/new":
binding.active_thread_id = None binding.new_chat_thread_id = None
await adapter.send_message( await adapter.send_message(
external_peer_id=parsed.external_peer_id, external_peer_id=parsed.external_peer_id,
text="Started a new SurfSense conversation.", text="Started a new SurfSense conversation.",
) )
event.status = GatewayEventStatus.PROCESSED event.status = ExternalChatEventStatus.PROCESSED
await session.commit() await session.commit()
return return
if not parsed.text: if not parsed.text:
event.status = GatewayEventStatus.IGNORED event.status = ExternalChatEventStatus.IGNORED
event.last_error = "empty_message" event.last_error = "empty_message"
await session.commit() await session.commit()
return return
@ -241,7 +240,7 @@ async def _dispatch_inbound_event(
binding=binding, binding=binding,
user_text=parsed.text, user_text=parsed.text,
translator=translator, translator=translator,
request_id=f"gateway:{inbox_id}", request_id=event.request_id or f"gateway:{inbox_id}",
) )
thread = await session.get(NewChatThread, thread.id) thread = await session.get(NewChatThread, thread.id)
@ -250,7 +249,7 @@ async def _dispatch_inbound_event(
await session.commit() await session.commit()
def _update_account_cursor(account: GatewayPlatformAccount, update_id: object) -> None: def _update_account_cursor(account: ExternalChatAccount, update_id: object) -> None:
if update_id is None: if update_id is None:
return return
account.cursor_state = { account.cursor_state = {