mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-10 20:35:17 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/changelogs
This commit is contained in:
commit
3f21d5fdd6
196 changed files with 8405 additions and 5757 deletions
|
|
@ -185,7 +185,6 @@ STRIPE_PAGE_BUYING_ENABLED=FALSE
|
|||
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# DEPRECATED: STRIPE_TOKENS_PER_UNIT=1000000
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||
|
|
@ -340,9 +339,6 @@ STT_SERVICE=local/base
|
|||
# External API Keys (optional)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Firecrawl (web scraping)
|
||||
# FIRECRAWL_API_KEY=
|
||||
|
||||
# Unstructured (if ETL_SERVICE=UNSTRUCTURED)
|
||||
# UNSTRUCTURED_API_KEY=
|
||||
|
||||
|
|
@ -418,7 +414,6 @@ SURFSENSE_ENABLE_DOOM_LOOP=true
|
|||
# Premium turns are debited at the actual per-call provider cost reported
|
||||
# by LiteLLM. Only applies to models with billing_tier=premium.
|
||||
# PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
# DEPRECATED: PREMIUM_TOKEN_LIMIT=5000000
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
|
||||
# QUOTA_MAX_RESERVE_MICROS=1000000
|
||||
|
|
|
|||
|
|
@ -3,18 +3,25 @@ DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
|
|||
# Deployment environment: dev or production
|
||||
SURFSENSE_ENV=dev
|
||||
|
||||
#Celery Config
|
||||
CELERY_BROKER_URL=redis://localhost:6379/0
|
||||
CELERY_RESULT_BACKEND=redis://localhost:6379/0
|
||||
# Redis (single endpoint for Celery broker/result backend + app features)
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
# Optional: override individually only to split Redis across instances.
|
||||
# Each defaults to REDIS_URL when unset.
|
||||
# CELERY_BROKER_URL=redis://localhost:6379/0
|
||||
# CELERY_RESULT_BACKEND=redis://localhost:6379/0
|
||||
# REDIS_APP_URL=redis://localhost:6379/0
|
||||
# Optional: isolate queues when sharing Redis with other apps
|
||||
CELERY_TASK_DEFAULT_QUEUE=surfsense
|
||||
|
||||
# Redis for app-level features (heartbeats, podcast markers)
|
||||
# Defaults to CELERY_BROKER_URL when not set
|
||||
REDIS_APP_URL=redis://localhost:6379/0
|
||||
# Optional: TTL in seconds for connector indexing lock key
|
||||
# CONNECTOR_INDEXING_LOCK_TTL_SECONDS=28800
|
||||
|
||||
# Messaging Gateway (global)
|
||||
# GATEWAY_ENABLED: master switch for ALL messaging gateway channels (Telegram, WhatsApp,
|
||||
# Slack, Discord). When FALSE, no gateway background workers/supervisors start and all
|
||||
# gateway HTTP routes (webhooks, OAuth callbacks, pairing) return 404. Set per-channel
|
||||
# flags below to control individual platforms once the gateway is enabled.
|
||||
GATEWAY_ENABLED=TRUE
|
||||
|
||||
# Telegram Gateway
|
||||
# TELEGRAM_WEBHOOK_SECRET must be 1-256 chars and contain only A-Z, a-z, 0-9, _ or -
|
||||
# GATEWAY_TELEGRAM_INTAKE_MODE: `webhook` for production, `longpoll` for single-replica self-host fallback, `disabled` to skip Telegram intake
|
||||
|
|
@ -85,8 +92,6 @@ STRIPE_PAGE_BUYING_ENABLED=TRUE
|
|||
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# DEPRECATED — use STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping):
|
||||
# STRIPE_TOKENS_PER_UNIT=1000000
|
||||
|
||||
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
|
|
@ -225,8 +230,6 @@ PAGES_LIMIT=500
|
|||
# models bill proportionally. Applies only to models with
|
||||
# billing_tier=premium in global_llm_config.yaml.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
# DEPRECATED — use PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping):
|
||||
# PREMIUM_TOKEN_LIMIT=5000000
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD.
|
||||
# stream_new_chat estimates an upper-bound cost from the model's
|
||||
|
|
@ -274,17 +277,19 @@ TURNSTILE_ENABLED=FALSE
|
|||
TURNSTILE_SECRET_KEY=
|
||||
|
||||
|
||||
# Proxy provider selection. Selects a ProxyProvider implementation registered in
|
||||
# app/utils/proxy/registry.py. Default: "anonymous_proxies". Add new vendors there.
|
||||
# PROXY_PROVIDER=anonymous_proxies
|
||||
|
||||
# Residential Proxy Configuration (anonymous-proxies.net)
|
||||
# Used for web crawling, link previews, and YouTube transcript fetching to avoid IP bans.
|
||||
# Leave commented out to disable proxying.
|
||||
# Consumed by the "anonymous_proxies" provider. Leave commented out to disable proxying.
|
||||
# RESIDENTIAL_PROXY_USERNAME=your_proxy_username
|
||||
# RESIDENTIAL_PROXY_PASSWORD=your_proxy_password
|
||||
# RESIDENTIAL_PROXY_HOSTNAME=rotating.dnsproxifier.com:31230
|
||||
# RESIDENTIAL_PROXY_LOCATION=
|
||||
# RESIDENTIAL_PROXY_TYPE=1
|
||||
|
||||
FIRECRAWL_API_KEY=fcr-01J0000000000000000000000
|
||||
|
||||
# File Parser Service
|
||||
ETL_SERVICE=UNSTRUCTURED or LLAMACLOUD or DOCLING
|
||||
UNSTRUCTURED_API_KEY=Tpu3P0U8iy
|
||||
|
|
@ -357,6 +362,13 @@ LANGSMITH_PROJECT=surfsense
|
|||
# SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=false
|
||||
# SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=false
|
||||
|
||||
# KB retrieval mode (default OFF = lazy). When OFF, the main agent retrieves
|
||||
# KB content on demand via the `search_knowledge_base` tool and skips the
|
||||
# expensive per-turn pre-injection (planner LLM + embed + hybrid search,
|
||||
# ~2.3s); explicit @-mentions are still surfaced cheaply. Set to true to
|
||||
# restore the original eager `<priority_documents>` pre-injection.
|
||||
# SURFSENSE_ENABLE_KB_PRIORITY_PREINJECTION=false
|
||||
|
||||
# Snapshot / revert
|
||||
# SURFSENSE_ENABLE_ACTION_LOG=false
|
||||
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
||||
|
|
@ -377,6 +389,15 @@ LANGSMITH_PROJECT=surfsense
|
|||
# rollback if you suspect cache-related staleness.
|
||||
# SURFSENSE_ENABLE_AGENT_CACHE=true
|
||||
|
||||
# Cross-thread reuse (default ON). Drops thread_id from the cache key so a
|
||||
# returning user's NEW chats (same user + search space + config + visibility)
|
||||
# hit the already-compiled graph instead of paying a fresh ~4-5s compile —
|
||||
# turning a cold first turn into a warm one. Safe because ActionLog,
|
||||
# KB-persistence, and the deliverables tools now resolve the chat thread from
|
||||
# the live RunnableConfig at call time rather than a build-time closure. Flip
|
||||
# OFF to fall back to a per-thread cache key (instant rollback).
|
||||
# SURFSENSE_ENABLE_CROSS_THREAD_AGENT_CACHE=true
|
||||
|
||||
# Cache capacity (max number of compiled-agent entries kept in memory)
|
||||
# and TTL per entry (seconds). Working set is typically one entry per
|
||||
# active thread on this replica; tune up for very large deployments.
|
||||
|
|
|
|||
|
|
@ -109,8 +109,10 @@ RUN --mount=type=secret,id=HF_TOKEN \
|
|||
HF_TOKEN="$(cat /run/secrets/HF_TOKEN 2>/dev/null || true)" \
|
||||
python -c "from chonkie import AutoEmbeddings; AutoEmbeddings.get_embeddings('${EMBEDDING_MODEL}')"
|
||||
|
||||
# Install Playwright browsers (the playwright python package itself is in deps)
|
||||
RUN playwright install chromium --with-deps
|
||||
# Install Scrapling's browser engines (patchright Chromium + Camoufox).
|
||||
# Scrapling pulls playwright/patchright via the `fetchers` extra; `scrapling install`
|
||||
# downloads the matching browser binaries used by DynamicFetcher/StealthyFetcher.
|
||||
RUN scrapling install
|
||||
|
||||
# Shared temp directory for file uploads between API and Worker containers.
|
||||
# Python's tempfile module uses TMPDIR, so uploaded files land here.
|
||||
|
|
|
|||
|
|
@ -165,9 +165,7 @@ def downgrade() -> None:
|
|||
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||
with tx:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'pre-148-downgrade'"
|
||||
)
|
||||
sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'pre-148-downgrade'")
|
||||
)
|
||||
conn.execute(sa.text(ddl))
|
||||
conn.execute(
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ AUTOMATION_RUN_COLS = [
|
|||
"created_at",
|
||||
]
|
||||
|
||||
|
||||
def _has_zero_version(conn, table: str) -> bool:
|
||||
return (
|
||||
conn.execute(
|
||||
|
|
@ -190,7 +191,8 @@ def upgrade() -> None:
|
|||
"external_chat_peer_kind", ("direct", "group", "channel", "unknown")
|
||||
)
|
||||
external_chat_event_kind_enum = _create_enum(
|
||||
"external_chat_event_kind", ("message", "edited_message", "callback_query", "other")
|
||||
"external_chat_event_kind",
|
||||
("message", "edited_message", "callback_query", "other"),
|
||||
)
|
||||
external_chat_event_status_enum = _create_enum(
|
||||
"external_chat_event_status",
|
||||
|
|
@ -205,7 +207,12 @@ def upgrade() -> None:
|
|||
sa.Column("mode", external_chat_account_mode_enum, nullable=False),
|
||||
sa.Column("owner_user_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("owner_search_space_id", sa.Integer(), nullable=True),
|
||||
sa.Column("is_system_account", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column(
|
||||
"is_system_account",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
sa.Column("encrypted_credentials", sa.Text(), nullable=True),
|
||||
sa.Column("bot_username", sa.String(255), nullable=True),
|
||||
sa.Column("webhook_secret", sa.String(64), nullable=True),
|
||||
|
|
@ -221,7 +228,9 @@ def upgrade() -> None:
|
|||
nullable=False,
|
||||
server_default="unknown",
|
||||
),
|
||||
sa.Column("last_health_check_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"last_health_check_at", sa.TIMESTAMP(timezone=True), nullable=True
|
||||
),
|
||||
sa.Column("suspended_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
sa.Column("suspended_reason", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
|
|
@ -285,7 +294,9 @@ def upgrade() -> None:
|
|||
server_default="pending",
|
||||
),
|
||||
sa.Column("pairing_code", sa.Text(), nullable=True),
|
||||
sa.Column("pairing_code_expires_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"pairing_code_expires_at", sa.TIMESTAMP(timezone=True), nullable=True
|
||||
),
|
||||
sa.Column("external_peer_id", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"external_peer_kind",
|
||||
|
|
@ -327,7 +338,9 @@ def upgrade() -> None:
|
|||
["account_id"], ["external_chat_accounts.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_space_id"], ["searchspaces.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["new_chat_thread_id"], ["new_chat_threads.id"], ondelete="SET NULL"
|
||||
),
|
||||
|
|
@ -386,7 +399,9 @@ def upgrade() -> None:
|
|||
nullable=False,
|
||||
server_default="received",
|
||||
),
|
||||
sa.Column("attempt_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"attempt_count", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
sa.Column("last_error", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"received_at",
|
||||
|
|
@ -405,7 +420,9 @@ def upgrade() -> None:
|
|||
["account_id"], ["external_chat_accounts.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["external_chat_binding_id"], ["external_chat_bindings.id"], ondelete="SET NULL"
|
||||
["external_chat_binding_id"],
|
||||
["external_chat_bindings.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"account_id",
|
||||
|
|
@ -445,7 +462,9 @@ def upgrade() -> None:
|
|||
sa.Column("external_chat_binding_id", sa.BigInteger(), nullable=True),
|
||||
)
|
||||
if not _constraint_exists(
|
||||
conn, "new_chat_threads", "fk_new_chat_threads_external_chat_external_chat_binding_id"
|
||||
conn,
|
||||
"new_chat_threads",
|
||||
"fk_new_chat_threads_external_chat_external_chat_binding_id",
|
||||
):
|
||||
op.create_foreign_key(
|
||||
"fk_new_chat_threads_external_chat_external_chat_binding_id",
|
||||
|
|
@ -455,7 +474,9 @@ def upgrade() -> None:
|
|||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.create_index("ix_new_chat_threads_source", "new_chat_threads", ["source"], if_not_exists=True)
|
||||
op.create_index(
|
||||
"ix_new_chat_threads_source", "new_chat_threads", ["source"], if_not_exists=True
|
||||
)
|
||||
op.create_index(
|
||||
"ix_new_chat_threads_external_chat_binding_id",
|
||||
"new_chat_threads",
|
||||
|
|
@ -472,7 +493,11 @@ def upgrade() -> None:
|
|||
if not _column_exists(conn, "new_chat_messages", "platform_metadata"):
|
||||
op.add_column(
|
||||
"new_chat_messages",
|
||||
sa.Column("platform_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column(
|
||||
"platform_metadata",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_new_chat_messages_source",
|
||||
|
|
@ -553,11 +578,15 @@ def downgrade() -> None:
|
|||
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||
with tx:
|
||||
conn.execute(
|
||||
sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'pre-144-downgrade'")
|
||||
sa.text(
|
||||
f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'pre-144-downgrade'"
|
||||
)
|
||||
)
|
||||
conn.execute(sa.text(ddl))
|
||||
conn.execute(
|
||||
sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'post-144-downgrade'")
|
||||
sa.text(
|
||||
f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'post-144-downgrade'"
|
||||
)
|
||||
)
|
||||
|
||||
if _column_exists(conn, "new_chat_messages", "source"):
|
||||
|
|
@ -567,10 +596,14 @@ def downgrade() -> None:
|
|||
_drop_column_if_exists("new_chat_messages", "platform_metadata")
|
||||
_drop_column_if_exists("new_chat_messages", "source")
|
||||
|
||||
_drop_index_if_exists("ix_new_chat_threads_external_chat_binding_id", "new_chat_threads")
|
||||
_drop_index_if_exists(
|
||||
"ix_new_chat_threads_external_chat_binding_id", "new_chat_threads"
|
||||
)
|
||||
_drop_index_if_exists("ix_new_chat_threads_source", "new_chat_threads")
|
||||
if _constraint_exists(
|
||||
conn, "new_chat_threads", "fk_new_chat_threads_external_chat_external_chat_binding_id"
|
||||
conn,
|
||||
"new_chat_threads",
|
||||
"fk_new_chat_threads_external_chat_external_chat_binding_id",
|
||||
):
|
||||
op.drop_constraint(
|
||||
"fk_new_chat_threads_external_chat_external_chat_binding_id",
|
||||
|
|
@ -583,8 +616,12 @@ def downgrade() -> None:
|
|||
_drop_index_if_exists(
|
||||
"ix_external_chat_inbound_binding_received_at", "external_chat_inbound_events"
|
||||
)
|
||||
_drop_index_if_exists("ix_external_chat_inbound_request_id", "external_chat_inbound_events")
|
||||
_drop_index_if_exists("ix_external_chat_inbound_status_received_at", "external_chat_inbound_events")
|
||||
_drop_index_if_exists(
|
||||
"ix_external_chat_inbound_request_id", "external_chat_inbound_events"
|
||||
)
|
||||
_drop_index_if_exists(
|
||||
"ix_external_chat_inbound_status_received_at", "external_chat_inbound_events"
|
||||
)
|
||||
if _table_exists(conn, "external_chat_inbound_events"):
|
||||
op.drop_table("external_chat_inbound_events")
|
||||
|
||||
|
|
@ -606,9 +643,15 @@ def downgrade() -> None:
|
|||
if _table_exists(conn, "external_chat_bindings"):
|
||||
op.drop_table("external_chat_bindings")
|
||||
|
||||
_drop_index_if_exists("uq_external_chat_accounts_system_platform", "external_chat_accounts")
|
||||
_drop_index_if_exists("uq_external_chat_accounts_owner_platform", "external_chat_accounts")
|
||||
_drop_index_if_exists("uq_external_chat_accounts_webhook_secret", "external_chat_accounts")
|
||||
_drop_index_if_exists(
|
||||
"uq_external_chat_accounts_system_platform", "external_chat_accounts"
|
||||
)
|
||||
_drop_index_if_exists(
|
||||
"uq_external_chat_accounts_owner_platform", "external_chat_accounts"
|
||||
)
|
||||
_drop_index_if_exists(
|
||||
"uq_external_chat_accounts_webhook_secret", "external_chat_accounts"
|
||||
)
|
||||
if _table_exists(conn, "external_chat_accounts"):
|
||||
op.drop_table("external_chat_accounts")
|
||||
|
||||
|
|
|
|||
|
|
@ -63,8 +63,7 @@ def upgrade() -> None:
|
|||
"ON document_files(search_space_id);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_document_files_kind "
|
||||
"ON document_files(kind);"
|
||||
"CREATE INDEX IF NOT EXISTS ix_document_files_kind ON document_files(kind);"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_document_files_created_by_id "
|
||||
|
|
|
|||
|
|
@ -68,8 +68,12 @@ def _has_zero_version(conn, table: str) -> bool:
|
|||
|
||||
|
||||
def _set_table_ddl(*, with_automation_runs: bool, conn) -> str:
|
||||
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if _has_zero_version(conn, "documents") else [])
|
||||
user_cols = USER_COLS + (['"_0_version"'] if _has_zero_version(conn, "user") else [])
|
||||
doc_cols = DOCUMENT_COLS + (
|
||||
['"_0_version"'] if _has_zero_version(conn, "documents") else []
|
||||
)
|
||||
user_cols = USER_COLS + (
|
||||
['"_0_version"'] if _has_zero_version(conn, "user") else []
|
||||
)
|
||||
tables = [
|
||||
"notifications",
|
||||
f"documents ({', '.join(doc_cols)})",
|
||||
|
|
@ -96,9 +100,17 @@ def _resync(*, with_automation_runs: bool, tag: str) -> None:
|
|||
|
||||
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||
with tx:
|
||||
conn.execute(sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'pre-{tag}'"))
|
||||
conn.execute(sa.text(_set_table_ddl(with_automation_runs=with_automation_runs, conn=conn)))
|
||||
conn.execute(sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'post-{tag}'"))
|
||||
conn.execute(
|
||||
sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'pre-{tag}'")
|
||||
)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
_set_table_ddl(with_automation_runs=with_automation_runs, conn=conn)
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'post-{tag}'")
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
|
|
|||
|
|
@ -67,8 +67,12 @@ def _has_zero_version(conn, table: str) -> bool:
|
|||
|
||||
|
||||
def _set_table_ddl(conn) -> str:
|
||||
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if _has_zero_version(conn, "documents") else [])
|
||||
user_cols = USER_COLS + (['"_0_version"'] if _has_zero_version(conn, "user") else [])
|
||||
doc_cols = DOCUMENT_COLS + (
|
||||
['"_0_version"'] if _has_zero_version(conn, "documents") else []
|
||||
)
|
||||
user_cols = USER_COLS + (
|
||||
['"_0_version"'] if _has_zero_version(conn, "user") else []
|
||||
)
|
||||
tables = [
|
||||
"notifications",
|
||||
f"documents ({', '.join(doc_cols)})",
|
||||
|
|
@ -94,9 +98,13 @@ def _resync_zero_publication(tag: str) -> None:
|
|||
|
||||
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||
with tx:
|
||||
conn.execute(sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'pre-{tag}'"))
|
||||
conn.execute(
|
||||
sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'pre-{tag}'")
|
||||
)
|
||||
conn.execute(sa.text(_set_table_ddl(conn)))
|
||||
conn.execute(sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'post-{tag}'"))
|
||||
conn.execute(
|
||||
sa.text(f"COMMENT ON PUBLICATION {PUBLICATION_NAME} IS 'post-{tag}'")
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
|
@ -117,7 +125,12 @@ def downgrade() -> None:
|
|||
if not _column_exists(conn, "searchspaces", "document_summary_llm_id"):
|
||||
op.add_column(
|
||||
"searchspaces",
|
||||
sa.Column("document_summary_llm_id", sa.Integer(), nullable=True, server_default="0"),
|
||||
sa.Column(
|
||||
"document_summary_llm_id",
|
||||
sa.Integer(),
|
||||
nullable=True,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
|
||||
if not _column_exists(conn, "search_source_connectors", "enable_summary"):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -18,6 +19,9 @@ from app.agents.chat.multi_agent_chat.shared.feature_flags import AgentFeatureFl
|
|||
from app.agents.chat.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.chat.shared.context import SurfSenseContextSchema
|
||||
from app.db import ChatVisibility
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
def build_compiled_agent_graph_sync(
|
||||
|
|
@ -43,6 +47,7 @@ def build_compiled_agent_graph_sync(
|
|||
disabled_tools: list[str] | None = None,
|
||||
):
|
||||
"""Sync compile: middleware + ``create_agent`` (run via ``asyncio.to_thread``)."""
|
||||
mw_start = time.perf_counter()
|
||||
main_agent_middleware = build_main_agent_deepagent_middleware(
|
||||
llm=llm,
|
||||
tools=tools,
|
||||
|
|
@ -63,7 +68,9 @@ def build_compiled_agent_graph_sync(
|
|||
mcp_tools_by_agent=mcp_tools_by_agent,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
mw_elapsed = time.perf_counter() - mw_start
|
||||
|
||||
create_start = time.perf_counter()
|
||||
agent = create_agent(
|
||||
llm,
|
||||
system_prompt=final_system_prompt,
|
||||
|
|
@ -72,6 +79,15 @@ def build_compiled_agent_graph_sync(
|
|||
context_schema=SurfSenseContextSchema,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
create_elapsed = time.perf_counter() - create_start
|
||||
_perf_log.info(
|
||||
"[graph_compile] middleware_build=%.3fs main_create_agent=%.3fs "
|
||||
"total=%.3fs mw_count=%d",
|
||||
mw_elapsed,
|
||||
create_elapsed,
|
||||
mw_elapsed + create_elapsed,
|
||||
len(main_agent_middleware),
|
||||
)
|
||||
return agent.with_config(
|
||||
{
|
||||
"recursion_limit": 10_000,
|
||||
|
|
|
|||
|
|
@ -108,18 +108,32 @@ class ActionLogMiddleware(AgentMiddleware):
|
|||
self._user_id = user_id
|
||||
self._tool_definitions = dict(tool_definitions or {})
|
||||
|
||||
def _enabled(self) -> bool:
|
||||
def _enabled(self, thread_id: int | None) -> bool:
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack:
|
||||
return False
|
||||
return bool(flags.enable_action_log) and self._thread_id is not None
|
||||
return bool(flags.enable_action_log) and thread_id is not None
|
||||
|
||||
def _resolve_thread_id(self, request: ToolCallRequest) -> int | None:
    """Resolve the live thread id, preferring the runtime config.

    Reading ``configurable.thread_id`` from the active ``RunnableConfig``
    (rather than the value captured at ``__init__``) lets a single cached
    compiled graph safely serve many threads — without it, a cache hit
    would attribute action-log rows to whichever thread first built the
    graph. Falls back to the constructor value for legacy/test runtimes
    that don't surface a config.

    Args:
        request: The tool-call request being wrapped; handed unchanged to
            the module-level ``_resolve_thread_id`` helper.

    Returns:
        The id reported by the module-level helper when it is not ``None``,
        otherwise ``self._thread_id`` (which may itself be ``None``).
    """
    # NOTE: inside the method body this name refers to the module-level
    # function of the same name, not to this method (no ``self.`` prefix).
    resolved = _resolve_thread_id(request)
    return resolved if resolved is not None else self._thread_id
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not self._enabled():
|
||||
thread_id = self._resolve_thread_id(request)
|
||||
if not self._enabled(thread_id):
|
||||
return await handler(request)
|
||||
|
||||
result: ToolMessage | Command[Any]
|
||||
|
|
@ -134,10 +148,16 @@ class ActionLogMiddleware(AgentMiddleware):
|
|||
request=request,
|
||||
result=None,
|
||||
error_payload=error_payload,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
raise
|
||||
|
||||
await self._record(request=request, result=result, error_payload=None)
|
||||
await self._record(
|
||||
request=request,
|
||||
result=result,
|
||||
error_payload=None,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
return result
|
||||
|
||||
async def _record(
|
||||
|
|
@ -146,6 +166,7 @@ class ActionLogMiddleware(AgentMiddleware):
|
|||
request: ToolCallRequest,
|
||||
result: ToolMessage | Command[Any] | None,
|
||||
error_payload: dict[str, Any] | None,
|
||||
thread_id: int | None,
|
||||
) -> None:
|
||||
"""Persist one ``agent_action_log`` row. Defensive: never raises."""
|
||||
try:
|
||||
|
|
@ -164,7 +185,7 @@ class ActionLogMiddleware(AgentMiddleware):
|
|||
chat_turn_id = _resolve_chat_turn_id(request)
|
||||
|
||||
row = AgentActionLog(
|
||||
thread_id=self._thread_id,
|
||||
thread_id=thread_id,
|
||||
user_id=self._user_id,
|
||||
search_space_id=self._search_space_id,
|
||||
# ``turn_id`` is the deprecated alias of ``tool_call_id``
|
||||
|
|
@ -350,6 +371,36 @@ def _resolve_chat_turn_id(request: Any) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def _resolve_thread_id(request: Any) -> int | None:
    """Return ``configurable.thread_id`` (as int) for this request, if accessible.

    Mirrors :func:`_resolve_chat_turn_id`: ``ToolRuntime.config`` is exposed by
    LangGraph at ``request.runtime.config``, and the chat thread id lives at
    ``configurable.thread_id`` (a stringified ``chat_id`` at the main-graph
    level). Returns ``None`` when absent or unparseable so the caller can fall
    back to the constructor value.

    Args:
        request: Any object; only ``request.runtime.config`` is inspected.

    Returns:
        ``int(thread_id)`` when ``config`` and ``config["configurable"]`` are
        both dicts and the value converts cleanly; ``None`` in every other
        case. Never raises.
    """
    try:
        runtime = getattr(request, "runtime", None)
        if runtime is None:
            return None
        config = getattr(runtime, "config", None)
        if not isinstance(config, dict):
            return None
        configurable = config.get("configurable")
        if not isinstance(configurable, dict):
            return None
        value = configurable.get("thread_id")
        if value is None:
            return None
        try:
            # The id normally arrives as a string; anything ``int()`` rejects
            # is treated the same as "no thread id available".
            return int(value)
        except (TypeError, ValueError):
            return None
    except Exception:  # pragma: no cover - defensive
        # Attribute access on exotic request/runtime objects may raise; this
        # resolver is best-effort and must not break the tool call.
        return None
|
||||
|
||||
|
||||
def _resolve_message_id(request: Any) -> str | None:
    """Tool-call IDs serve as best-available message correlator at this layer.

    Args:
        request: The tool-call request; forwarded unchanged.

    Returns:
        Whatever ``_resolve_tool_call_id(request)`` returns.
    """
    return _resolve_tool_call_id(request)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from deepagents.backends.protocol import BackendFactory, BackendProtocol
|
||||
|
|
@ -14,10 +15,12 @@ from deepagents.middleware.subagents import (
|
|||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.runnables import Runnable
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
|
||||
SURF_CONTEXT_HINT_PROVIDER_KEY,
|
||||
SURF_LAZY_SPEC_FACTORY_KEY,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
|
@ -52,15 +55,32 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
|||
# switch keys on it so an operator can quarantine one workspace
|
||||
# without affecting the rest of the deployment.
|
||||
self._search_space_id = search_space_id
|
||||
subagent_specs = self._surf_compile_subagent_graphs()
|
||||
|
||||
# Lazy subagent compilation. Compiling a subagent graph via
|
||||
# ``create_agent`` is expensive (~250-400ms each) and there can be up
|
||||
# to ~17 of them. Doing it all in ``__init__`` put the full cost on
|
||||
# every cold ``agent_cache`` miss (i.e. on time-to-first-token), even
|
||||
# though a turn usually invokes zero or one subagent. We instead index
|
||||
# the raw specs here and compile each graph on first ``task(name)``
|
||||
# use, memoizing the result for the life of this (cached) instance.
|
||||
self._compiled: dict[str, Runnable] = {}
|
||||
self._lazy_specs: dict[str, dict[str, Any]] = {}
|
||||
# Subagents whose *spec itself* is built lazily (not just compiled).
|
||||
# Keyed by name → zero-arg factory returning the full spec dict. Used
|
||||
# for the write knowledge_base subagent, whose filesystem middleware
|
||||
# builds ~13 tool schemas (~2s) that almost never matter on turn 1.
|
||||
self._lazy_spec_factories: dict[str, Callable[[], dict[str, Any]]] = {}
|
||||
descriptors = self._build_subagent_registry()
|
||||
|
||||
task_tool = build_task_tool_with_parent_config(
|
||||
subagent_specs,
|
||||
descriptors,
|
||||
task_description,
|
||||
search_space_id=search_space_id,
|
||||
resolve_subagent=self._resolve_subagent,
|
||||
)
|
||||
if system_prompt and subagent_specs:
|
||||
if system_prompt and descriptors:
|
||||
agents_desc = "\n".join(
|
||||
f"- {s['name']}: {s['description']}" for s in subagent_specs
|
||||
f"- {s['name']}: {s['description']}" for s in descriptors
|
||||
)
|
||||
self.system_prompt = (
|
||||
system_prompt + "\n\nAvailable subagent types:\n" + agents_desc
|
||||
|
|
@ -69,84 +89,100 @@ class SurfSenseCheckpointedSubAgentMiddleware(SubAgentMiddleware):
|
|||
self.system_prompt = system_prompt
|
||||
self.tools = [task_tool]
|
||||
|
||||
def _surf_compile_subagent_graphs(self) -> list[dict[str, Any]]:
|
||||
"""Mirror of ``SubAgentMiddleware._get_subagents`` that threads the parent checkpointer."""
|
||||
specs: list[dict[str, Any]] = []
|
||||
loop_start = time.perf_counter()
|
||||
timings: list[tuple[str, float, str]] = [] # (name, elapsed, source)
|
||||
def _build_subagent_registry(self) -> list[dict[str, Any]]:
|
||||
"""Index subagents for lazy compilation; return lightweight descriptors.
|
||||
|
||||
Pre-compiled specs (those carrying a ``runnable``) are seeded directly
|
||||
into the memo. Lazy specs are stashed by name and compiled on first
|
||||
``task(...)`` use via :meth:`_resolve_subagent`. The returned
|
||||
descriptors carry only ``name``/``description`` plus the optional
|
||||
context-hint provider — everything the ``task`` tool needs to validate
|
||||
names, render its catalog, and run hints, without paying the
|
||||
``create_agent`` cost up front.
|
||||
"""
|
||||
descriptors: list[dict[str, Any]] = []
|
||||
for spec in self._subagents:
|
||||
spec_start = time.perf_counter()
|
||||
# Provider may be ``None`` (no hint), in which case task_tool
|
||||
# skips the prepend step. We forward the key unconditionally so
|
||||
# the registry shape is uniform.
|
||||
# Provider may be ``None`` (no hint), in which case task_tool skips
|
||||
# the prepend step. We forward the key unconditionally so the
|
||||
# descriptor shape is uniform.
|
||||
hint_provider = cast(dict, spec).get(SURF_CONTEXT_HINT_PROVIDER_KEY)
|
||||
if "runnable" in spec:
|
||||
name = spec["name"]
|
||||
spec_factory = cast(dict, spec).get(SURF_LAZY_SPEC_FACTORY_KEY)
|
||||
if spec_factory is not None:
|
||||
# Descriptor-only entry: the spec dict is built on first use.
|
||||
self._lazy_spec_factories[name] = spec_factory
|
||||
elif "runnable" in spec:
|
||||
compiled = cast(CompiledSubAgent, spec)
|
||||
specs.append(
|
||||
{
|
||||
"name": compiled["name"],
|
||||
"description": compiled["description"],
|
||||
"runnable": compiled["runnable"],
|
||||
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
|
||||
}
|
||||
)
|
||||
timings.append(
|
||||
(compiled["name"], time.perf_counter() - spec_start, "precompiled")
|
||||
)
|
||||
continue
|
||||
|
||||
if "model" not in spec:
|
||||
msg = f"SubAgent '{spec['name']}' must specify 'model'"
|
||||
raise ValueError(msg)
|
||||
if "tools" not in spec:
|
||||
msg = f"SubAgent '{spec['name']}' must specify 'tools'"
|
||||
raise ValueError(msg)
|
||||
|
||||
model = spec["model"]
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
middleware: list[Any] = list(spec.get("middleware", []))
|
||||
tools_count = len(spec.get("tools") or [])
|
||||
mw_count = len(middleware)
|
||||
|
||||
compile_start = time.perf_counter()
|
||||
runnable = create_agent(
|
||||
model,
|
||||
system_prompt=spec["system_prompt"],
|
||||
tools=spec["tools"],
|
||||
middleware=middleware,
|
||||
name=spec["name"],
|
||||
checkpointer=self._surf_checkpointer,
|
||||
)
|
||||
compile_elapsed = time.perf_counter() - compile_start
|
||||
specs.append(
|
||||
self._compiled[name] = compiled["runnable"]
|
||||
else:
|
||||
if "model" not in spec:
|
||||
msg = f"SubAgent '{name}' must specify 'model'"
|
||||
raise ValueError(msg)
|
||||
if "tools" not in spec:
|
||||
msg = f"SubAgent '{name}' must specify 'tools'"
|
||||
raise ValueError(msg)
|
||||
self._lazy_specs[name] = cast(dict, spec)
|
||||
descriptors.append(
|
||||
{
|
||||
"name": spec["name"],
|
||||
"name": name,
|
||||
"description": spec["description"],
|
||||
"runnable": runnable,
|
||||
SURF_CONTEXT_HINT_PROVIDER_KEY: hint_provider,
|
||||
}
|
||||
)
|
||||
timings.append(
|
||||
(
|
||||
spec["name"],
|
||||
compile_elapsed,
|
||||
f"compiled tools={tools_count} mw={mw_count}",
|
||||
)
|
||||
)
|
||||
return descriptors
|
||||
|
||||
total_elapsed = time.perf_counter() - loop_start
|
||||
per_subagent = ", ".join(
|
||||
f"{name}={elapsed * 1000:.0f}ms[{source}]"
|
||||
for name, elapsed, source in timings
|
||||
def _resolve_subagent(self, name: str) -> Runnable:
    """Fetch the compiled subagent graph for ``name``, building it on demand.

    Lookup order:

    1. ``self._compiled`` - graphs that are already compiled (memo table).
    2. ``self._lazy_specs`` - a ready spec that still has to be compiled.
    3. ``self._lazy_spec_factories`` - a zero-argument callable that builds
       the spec itself on first use.

    Whatever gets compiled is stored back into ``self._compiled`` so later
    calls for the same name return the same object without recompiling.

    Raises:
        KeyError: ``name`` is not present in any of the three registries.
    """
    if (ready := self._compiled.get(name)) is not None:
        return ready

    pending_spec = self._lazy_specs.get(name)
    if pending_spec is None:
        spec_factory = self._lazy_spec_factories.get(name)
        if spec_factory is None:
            raise KeyError(name)
        # The spec itself was deferred: build it now and log how long the
        # deferred construction took before moving on to compilation.
        factory_started = time.perf_counter()
        pending_spec = spec_factory()
        _perf_log.info(
            "[subagent_spec_lazy] name=%s (deferred spec build) in %.3fs",
            name,
            time.perf_counter() - factory_started,
        )

    graph = self._compile_one(pending_spec)
    self._compiled[name] = graph
    return graph
|
||||
|
||||
def _compile_one(self, spec: dict[str, Any]) -> Runnable:
|
||||
"""Compile a single subagent graph against the parent checkpointer."""
|
||||
model = spec["model"]
|
||||
if isinstance(model, str):
|
||||
model = init_chat_model(model)
|
||||
|
||||
middleware: list[Any] = list(spec.get("middleware", []))
|
||||
tools_count = len(spec.get("tools") or [])
|
||||
mw_count = len(middleware)
|
||||
|
||||
compile_start = time.perf_counter()
|
||||
runnable = create_agent(
|
||||
model,
|
||||
system_prompt=spec["system_prompt"],
|
||||
tools=spec["tools"],
|
||||
middleware=middleware,
|
||||
name=spec["name"],
|
||||
checkpointer=self._surf_checkpointer,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[subagent_compile] total=%.3fs count=%d details=[%s]",
|
||||
total_elapsed,
|
||||
len(timings),
|
||||
per_subagent,
|
||||
"[subagent_compile_lazy] name=%s in %.3fs tools=%d mw=%d",
|
||||
spec["name"],
|
||||
time.perf_counter() - compile_start,
|
||||
tools_count,
|
||||
mw_count,
|
||||
)
|
||||
|
||||
return specs
|
||||
return runnable
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, NoReturn, TypeVar
|
||||
|
||||
from deepagents.middleware.subagents import TASK_TOOL_DESCRIPTION
|
||||
|
|
@ -143,11 +143,28 @@ def build_task_tool_with_parent_config(
|
|||
task_description: str | None = None,
|
||||
*,
|
||||
search_space_id: int | None = None,
|
||||
resolve_subagent: Callable[[str], Runnable] | None = None,
|
||||
) -> BaseTool:
|
||||
"""Upstream ``_build_task_tool`` + parent ``runtime.config`` propagation + resume bridging."""
|
||||
subagent_graphs: dict[str, Runnable] = {
|
||||
spec["name"]: spec["runnable"] for spec in subagents
|
||||
}
|
||||
"""Upstream ``_build_task_tool`` + parent ``runtime.config`` propagation + resume bridging.
|
||||
|
||||
``subagents`` are lightweight descriptors (``name``/``description`` + the
|
||||
optional context-hint provider); the actual compiled graph is fetched
|
||||
lazily via ``resolve_subagent(name)`` so subagent ``create_agent`` cost is
|
||||
paid on first ``task(name)`` use rather than at graph-build time.
|
||||
|
||||
For backward compatibility (and tests), ``resolve_subagent`` may be omitted
|
||||
when every descriptor already carries a pre-compiled ``runnable``; in that
|
||||
case a trivial dict-backed resolver is used.
|
||||
"""
|
||||
subagent_names: set[str] = {spec["name"] for spec in subagents}
|
||||
if resolve_subagent is None:
|
||||
_eager_graphs: dict[str, Runnable] = {
|
||||
spec["name"]: spec["runnable"] for spec in subagents if "runnable" in spec
|
||||
}
|
||||
|
||||
def resolve_subagent(name: str) -> Runnable:
|
||||
return _eager_graphs[name]
|
||||
|
||||
# Sparse map of opt-in context-hint providers; each runs once per task()
|
||||
# call to prepend a string to the subagent's first HumanMessage. Failures
|
||||
# are swallowed so a broken hint never blocks the task.
|
||||
|
|
@ -329,7 +346,7 @@ def build_task_tool_with_parent_config(
|
|||
def _validate_and_prepare_state(
|
||||
subagent_type: str, description: str, runtime: ToolRuntime
|
||||
) -> tuple[Runnable, dict]:
|
||||
subagent = subagent_graphs[subagent_type]
|
||||
subagent = resolve_subagent(subagent_type)
|
||||
subagent_state = {
|
||||
k: v for k, v in runtime.state.items() if k not in EXCLUDED_STATE_KEYS
|
||||
}
|
||||
|
|
@ -442,8 +459,8 @@ def build_task_tool_with_parent_config(
|
|||
batched HITL is intentionally out of scope.
|
||||
"""
|
||||
async with semaphore:
|
||||
if subagent_type not in subagent_graphs:
|
||||
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
|
||||
if subagent_type not in subagent_names:
|
||||
allowed_types = ", ".join([f"`{k}`" for k in subagent_names])
|
||||
return (
|
||||
task_index,
|
||||
subagent_type,
|
||||
|
|
@ -618,8 +635,8 @@ def build_task_tool_with_parent_config(
|
|||
"task: must provide either single-mode (`description`+`subagent_type`) "
|
||||
"or batch-mode (`tasks`)."
|
||||
)
|
||||
if subagent_type not in subagent_graphs:
|
||||
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
|
||||
if subagent_type not in subagent_names:
|
||||
allowed_types = ", ".join([f"`{k}`" for k in subagent_names])
|
||||
return (
|
||||
f"We cannot invoke subagent {subagent_type} because it does not exist, "
|
||||
f"the only allowed types are {allowed_types}"
|
||||
|
|
@ -827,8 +844,8 @@ def build_task_tool_with_parent_config(
|
|||
subagent_type,
|
||||
runtime.tool_call_id,
|
||||
)
|
||||
if subagent_type not in subagent_graphs:
|
||||
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
|
||||
if subagent_type not in subagent_names:
|
||||
allowed_types = ", ".join([f"`{k}`" for k in subagent_names])
|
||||
return (
|
||||
f"We cannot invoke subagent {subagent_type} because it does not exist, "
|
||||
f"the only allowed types are {allowed_types}"
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from typing import Any
|
|||
from fractional_indexing import generate_key_between
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
|
@ -1436,9 +1437,33 @@ class KnowledgeBasePersistenceMiddleware(AgentMiddleware): # type: ignore[type-
|
|||
search_space_id=self.search_space_id,
|
||||
created_by_id=self.created_by_id,
|
||||
filesystem_mode=self.filesystem_mode,
|
||||
thread_id=self.thread_id,
|
||||
thread_id=self._resolve_thread_id(),
|
||||
)
|
||||
|
||||
def _resolve_thread_id(self) -> int | None:
    """Work out which thread id the current run belongs to.

    The id is read from ``configurable.thread_id`` of the config returned by
    :func:`langgraph.config.get_config` at call time, so an instance that is
    shared between chats does not have to rely on the value stored on
    ``self.thread_id`` when it was constructed.

    Returns:
        * ``int(configurable.thread_id)`` when the live config provides one;
        * ``None`` when the live value is present but cannot be converted to
          an ``int`` (the constructor value is deliberately *not* used then);
        * ``self.thread_id`` when no config can be obtained, the config is
          not a dict, or it carries no ``thread_id``.
    """
    try:
        live_config = get_config()
    except Exception:
        # Any failure to obtain a config (presumably: no active runnable
        # context, e.g. legacy/test runtimes) means the constructor value.
        return self.thread_id

    if not isinstance(live_config, dict):
        return self.thread_id

    raw_thread_id = (live_config.get("configurable") or {}).get("thread_id")
    if raw_thread_id is None:
        return self.thread_id

    try:
        return int(raw_thread_id)
    except (TypeError, ValueError):
        return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
|
|
|
|||
|
|
@ -19,7 +19,16 @@ def build_knowledge_priority_mw(
|
|||
available_connectors: list[str] | None,
|
||||
available_document_types: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
preinjection_enabled: bool = True,
|
||||
) -> KnowledgePriorityMiddleware:
|
||||
"""Build the KB priority middleware.
|
||||
|
||||
When ``preinjection_enabled`` is False (the lazy default), the middleware
|
||||
runs in mentions-only mode: it skips the expensive planner LLM + embedding
|
||||
+ hybrid search and only surfaces explicit @-mentions. The main agent is
|
||||
expected to pull relevant KB content on demand via the
|
||||
``search_knowledge_base`` tool instead.
|
||||
"""
|
||||
return KnowledgePriorityMiddleware(
|
||||
llm=llm,
|
||||
planner_llm=get_planner_llm(),
|
||||
|
|
@ -29,4 +38,5 @@ def build_knowledge_priority_mw(
|
|||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
inject_system_message=False,
|
||||
mentions_only=not preinjection_enabled,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from langchain.agents.middleware import AgentMiddleware
|
|||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.observability import metrics as ot_metrics, otel as ot
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover — type-only
|
||||
from langchain.agents.middleware.types import (
|
||||
|
|
@ -34,6 +35,7 @@ if TYPE_CHECKING: # pragma: no cover — type-only
|
|||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
class OtelSpanMiddleware(AgentMiddleware):
|
||||
|
|
@ -60,7 +62,23 @@ class OtelSpanMiddleware(AgentMiddleware):
|
|||
handler: Callable[[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]],
|
||||
) -> ModelResponse | AIMessage | Any:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
# Always emit a [PERF] line for the model step even when OTel is
|
||||
# disabled. This isolates provider/model latency from the agent's
|
||||
# pre-flight (before_agent KB-priority/memory/tree) work, which is
|
||||
# the usual culprit when the multi-agent path feels slow to start.
|
||||
# ``perf_counter`` at entry doubles as the "before_agent finished /
|
||||
# model call started" marker on the first step of a turn.
|
||||
model_id, _provider = _resolve_model_attrs(request)
|
||||
_t0 = time.perf_counter()
|
||||
_perf_log.info("[model_call] start model=%s", model_id)
|
||||
try:
|
||||
return await handler(request)
|
||||
finally:
|
||||
_perf_log.info(
|
||||
"[model_call] done model=%s elapsed=%.3fs",
|
||||
model_id,
|
||||
time.perf_counter() - _t0,
|
||||
)
|
||||
|
||||
model_id, provider = _resolve_model_attrs(request)
|
||||
t0 = time.perf_counter()
|
||||
|
|
|
|||
|
|
@ -10,13 +10,15 @@ turn (cloud mode).
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from deepagents import SubAgent
|
||||
from deepagents.backends import StateBackend
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
|
|
@ -49,16 +51,25 @@ from app.agents.chat.multi_agent_chat.subagents import (
|
|||
get_subagents_to_exclude,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.agent import (
|
||||
NAME as KB_WRITE_NAME,
|
||||
READONLY_NAME as KB_READONLY_NAME,
|
||||
build_readonly_subagent as build_kb_readonly_subagent,
|
||||
build_subagent as build_kb_write_subagent,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge_base_tool import (
|
||||
build_ask_knowledge_base_tool,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.knowledge_base.prompts import (
|
||||
load_description as load_kb_write_description,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.middleware_stack import (
|
||||
build_subagent_middleware_stack,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
|
||||
SURF_LAZY_SPEC_FACTORY_KEY,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
from .action_log import build_action_log_mw
|
||||
from .anonymous_document import build_anonymous_doc_mw
|
||||
|
|
@ -81,6 +92,8 @@ from .plugins import build_plugin_middlewares
|
|||
from .skills import build_skills_mw
|
||||
from .tool_call_repair import build_repair_mw
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
def build_main_agent_deepagent_middleware(
|
||||
*,
|
||||
|
|
@ -104,6 +117,7 @@ def build_main_agent_deepagent_middleware(
|
|||
disabled_tools: list[str] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Ordered middleware for ``create_agent`` (None entries already stripped)."""
|
||||
stack_build_start = time.perf_counter()
|
||||
resilience = build_resilience_middlewares(flags)
|
||||
|
||||
memory_mw = build_memory_mw(
|
||||
|
|
@ -118,38 +132,98 @@ def build_main_agent_deepagent_middleware(
|
|||
"filesystem_mode": filesystem_mode,
|
||||
"flags": flags,
|
||||
}
|
||||
shared_mw_start = time.perf_counter()
|
||||
shared_subagent_middleware = build_subagent_middleware_stack(
|
||||
resilience=resilience,
|
||||
flags=flags,
|
||||
)
|
||||
shared_mw_elapsed = time.perf_counter() - shared_mw_start
|
||||
|
||||
kb_readonly = build_kb_readonly_subagent(
|
||||
dependencies=subagent_dependencies,
|
||||
model=llm,
|
||||
middleware_stack=shared_subagent_middleware,
|
||||
)
|
||||
kb_readonly_spec = kb_readonly.spec
|
||||
kb_readonly_runnable = create_agent(
|
||||
llm,
|
||||
system_prompt=kb_readonly_spec["system_prompt"],
|
||||
tools=kb_readonly_spec["tools"],
|
||||
middleware=kb_readonly_spec["middleware"],
|
||||
name=KB_READONLY_NAME,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
ask_kb_tool = build_ask_knowledge_base_tool(kb_readonly_runnable)
|
||||
def _compile_kb_readonly() -> Runnable:
    """Build the read-only knowledge-base spec and compile it into a graph.

    Nothing here runs until the callable is invoked: it is handed to
    ``build_ask_knowledge_base_tool`` instead of a pre-compiled runnable, so
    both the spec construction (``build_kb_readonly_subagent``) and the
    ``create_agent`` compile are kept off the initial agent-build path.
    NOTE(review): memoization of the returned graph is expected to be done by
    ``build_ask_knowledge_base_tool``; that is not visible from this block.

    Uses ``subagent_dependencies``, ``llm``, ``shared_subagent_middleware``
    and ``checkpointer`` from the enclosing scope.

    Returns:
        The compiled agent, named ``KB_READONLY_NAME``.
    """
    started_at = time.perf_counter()

    readonly_subagent = build_kb_readonly_subagent(
        dependencies=subagent_dependencies,
        model=llm,
        middleware_stack=shared_subagent_middleware,
    )
    readonly_spec = readonly_subagent.spec

    compiled_graph = create_agent(
        llm,
        system_prompt=readonly_spec["system_prompt"],
        tools=readonly_spec["tools"],
        middleware=readonly_spec["middleware"],
        name=KB_READONLY_NAME,
        checkpointer=checkpointer,
    )

    # One timing line covering spec build + compile together.
    _perf_log.info(
        "[subagent_compile_lazy] name=%s (spec+compile) in %.3fs",
        KB_READONLY_NAME,
        time.perf_counter() - started_at,
    )
    return compiled_graph
|
||||
|
||||
ask_kb_tool = build_ask_knowledge_base_tool(_compile_kb_readonly)
|
||||
|
||||
def _build_kb_write_spec() -> dict[str, Any]:
    """Construct the write-capable ``knowledge_base`` subagent spec on demand.

    Registered as a lazy spec factory so the spec (and the tool schemas its
    middleware builds) is only constructed when the subagent is first
    requested, rather than during the initial agent build.

    Uses ``subagent_dependencies``, ``llm``, ``shared_subagent_middleware``
    and ``disabled_tools`` from the enclosing scope. When ``disabled_tools``
    is non-empty and the spec's ``tools`` entry is a list, every tool whose
    ``name`` attribute appears in ``disabled_tools`` is removed from it.

    Returns:
        The (possibly filtered) spec, cast to a plain ``dict``.
    """
    write_subagent = build_kb_write_subagent(
        dependencies=subagent_dependencies,
        model=llm,
        middleware_stack=shared_subagent_middleware,
    )
    write_spec = write_subagent.spec

    if not disabled_tools:
        return cast(dict[str, Any], write_spec)

    blocked_names = frozenset(disabled_tools)
    current_tools = write_spec.get("tools")  # type: ignore[typeddict-item]
    if isinstance(current_tools, list):
        # Tools without a ``name`` attribute compare as ``None`` and are kept
        # unless ``None`` itself is listed as disabled.
        kept_tools = [
            tool
            for tool in current_tools
            if getattr(tool, "name", None) not in blocked_names
        ]
        write_spec["tools"] = kept_tools  # type: ignore[typeddict-unknown-key]
    return cast(dict[str, Any], write_spec)
|
||||
|
||||
subagents_start = time.perf_counter()
|
||||
# The write knowledge_base subagent is excluded from the eager build and
|
||||
# registered as a lazy descriptor (name + description cheap; spec built on
|
||||
# first ``task("knowledge_base")`` use) — see ``_build_kb_write_spec``.
|
||||
exclude_names = [*get_subagents_to_exclude(available_connectors), KB_WRITE_NAME]
|
||||
subagents: list[SubAgent] = build_subagents(
|
||||
dependencies=subagent_dependencies,
|
||||
model=llm,
|
||||
middleware_stack=shared_subagent_middleware,
|
||||
mcp_tools_by_agent=mcp_tools_by_agent or {},
|
||||
exclude=get_subagents_to_exclude(available_connectors),
|
||||
exclude=exclude_names,
|
||||
disabled_tools=disabled_tools,
|
||||
ask_kb_tool=ask_kb_tool,
|
||||
)
|
||||
kb_write_descriptor = cast(
|
||||
SubAgent,
|
||||
{
|
||||
"name": KB_WRITE_NAME,
|
||||
"description": load_kb_write_description(),
|
||||
SURF_LAZY_SPEC_FACTORY_KEY: _build_kb_write_spec,
|
||||
},
|
||||
)
|
||||
subagents.append(kb_write_descriptor)
|
||||
subagents_elapsed = time.perf_counter() - subagents_start
|
||||
logging.debug("Subagents registry: %s", [s["name"] for s in subagents])
|
||||
|
||||
assembly_start = time.perf_counter()
|
||||
stack: list[Any] = [
|
||||
build_busy_mutex_mw(flags),
|
||||
build_otel_mw(flags),
|
||||
|
|
@ -170,6 +244,7 @@ def build_main_agent_deepagent_middleware(
|
|||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
preinjection_enabled=flags.enable_kb_priority_preinjection,
|
||||
),
|
||||
build_kb_context_projection_mw(),
|
||||
build_kb_persistence_mw(
|
||||
|
|
@ -223,4 +298,17 @@ def build_main_agent_deepagent_middleware(
|
|||
),
|
||||
build_anthropic_cache_mw(),
|
||||
]
|
||||
return [m for m in stack if m is not None]
|
||||
result = [m for m in stack if m is not None]
|
||||
assembly_elapsed = time.perf_counter() - assembly_start
|
||||
_perf_log.info(
|
||||
"[stack_build] total=%.3fs shared_subagent_mw=%.3fs "
|
||||
"build_subagents=%.3fs stack_assembly=%.3fs subagents=%d mw=%d "
|
||||
"(kb_readonly deferred to first ask_knowledge_base)",
|
||||
time.perf_counter() - stack_build_start,
|
||||
shared_mw_elapsed,
|
||||
subagents_elapsed,
|
||||
assembly_elapsed,
|
||||
len(subagents),
|
||||
len(result),
|
||||
)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -91,10 +91,18 @@ async def build_agent_with_cache(
|
|||
# Every per-request value any middleware closes over at __init__ must be in
|
||||
# the key, otherwise a hit will leak state across threads. Bump the schema
|
||||
# version when the component list changes shape.
|
||||
#
|
||||
# Cross-thread reuse: when enabled, ``thread_id`` is dropped from the key so
|
||||
# one compiled graph serves all of a user's (same space/config/visibility)
|
||||
# chats. This is only safe because ActionLog, KB-persistence, and the
|
||||
# deliverables tools now resolve the chat thread from the live
|
||||
# RunnableConfig instead of a constructor closure; the schema tag is bumped
|
||||
# so v2 (per-thread) entries are never confused with v3 (shared) ones.
|
||||
cross_thread = flags.enable_cross_thread_agent_cache
|
||||
cache_key = stable_hash(
|
||||
"multi-agent-v2",
|
||||
"multi-agent-v3" if cross_thread else "multi-agent-v2",
|
||||
config_id,
|
||||
thread_id,
|
||||
None if cross_thread else thread_id,
|
||||
user_id,
|
||||
search_space_id,
|
||||
visibility,
|
||||
|
|
|
|||
|
|
@ -67,13 +67,13 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.config import config
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -328,8 +328,8 @@ def _short(key: str, n: int = 16) -> str:
|
|||
# Module-level singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DEFAULT_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
|
||||
_DEFAULT_TTL = float(os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800"))
|
||||
_DEFAULT_MAXSIZE = config.AGENT_CACHE_MAXSIZE
|
||||
_DEFAULT_TTL = config.AGENT_CACHE_TTL_SECONDS
|
||||
|
||||
_cache: _AgentCache = _AgentCache(maxsize=_DEFAULT_MAXSIZE, ttl_seconds=_DEFAULT_TTL)
|
||||
|
||||
|
|
|
|||
|
|
@ -209,9 +209,6 @@ async def create_multi_agent_chat_deep_agent(
|
|||
|
||||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||
|
||||
if "search_knowledge_base" not in modified_disabled_tools:
|
||||
modified_disabled_tools.append("search_knowledge_base")
|
||||
|
||||
if enabled_tools is not None:
|
||||
main_agent_enabled_tools = [
|
||||
n for n in enabled_tools if n in MAIN_AGENT_SURFSENSE_TOOL_NAMES
|
||||
|
|
|
|||
|
|
@ -1,9 +1,17 @@
|
|||
<knowledge_base_first>
|
||||
CRITICAL — ground factual answers in what you actually receive this turn:
|
||||
- the user's knowledge base via `search_knowledge_base` (your PRIMARY source
|
||||
for anything about their documents, notes, or connected data — the
|
||||
`<workspace_tree>` only lists what exists, so call the tool to read the
|
||||
actual content before answering),
|
||||
- injected workspace context (see `<dynamic_context>`),
|
||||
- results from your own tool calls (`web_search`, `scrape_webpage`),
|
||||
- results from your other tool calls (`web_search`, `scrape_webpage`),
|
||||
- or substantive summaries returned by a `task` specialist you invoked.
|
||||
|
||||
For questions about the user's own workspace, call `search_knowledge_base`
|
||||
first rather than answering from the tree or from memory. Use
|
||||
`task(knowledge_base)` when you need a document's full text or deeper reads.
|
||||
|
||||
Do **not** answer factual or informational questions from general knowledge
|
||||
unless the user explicitly authorises it after you say you couldn't find
|
||||
enough in those sources. The flow when nothing is found:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
- `search_knowledge_base` — Search the user's own knowledge base (their
|
||||
indexed documents, notes, files, and connected sources) with hybrid
|
||||
semantic + keyword retrieval.
|
||||
- This is your PRIMARY way to ground factual answers about the user's
|
||||
workspace. The `<workspace_tree>` shows what files exist; this tool pulls
|
||||
the actual relevant content. Call it BEFORE answering any question about
|
||||
the user's documents, notes, or connected data — don't answer from the
|
||||
tree alone or from memory.
|
||||
- Each hit returns the document's virtual path, a relevance score, and the
|
||||
matched snippets. The snippets are often enough to answer directly with a
|
||||
citation.
|
||||
- When you need a document's full text (not just snippets), delegate a read
|
||||
to the `knowledge_base` specialist via `task`, passing the path from the
|
||||
results.
|
||||
- Args: `query` (focused; include concrete entities, acronyms, people,
|
||||
projects, or terms), `top_k` (default 5, max 20).
|
||||
- If nothing relevant comes back, tell the user you couldn't find it in
|
||||
their workspace before offering to search the web or answer from general
|
||||
knowledge.
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
<example>
|
||||
user: "What did our Q3 planning doc say about hiring?"
|
||||
→ search_knowledge_base(query="Q3 planning hiring headcount plan")
|
||||
(Answer from the returned snippets with a citation; if you need the full
|
||||
document, task the knowledge_base specialist with the returned path.)
|
||||
</example>
|
||||
|
||||
<example>
|
||||
user: "Summarize my notes on the Acme migration."
|
||||
→ search_knowledge_base(query="Acme migration notes")
|
||||
→ task(subagent_type="knowledge_base", description="Read <path> and return a
|
||||
detailed summary of the Acme migration plan, risks, and timeline.")
|
||||
</example>
|
||||
|
|
@ -6,6 +6,7 @@ Connector integrations, MCP, deliverables, etc. are delegated via ``task`` subag
|
|||
from __future__ import annotations
|
||||
|
||||
MAIN_AGENT_SURFSENSE_TOOL_NAMES_ORDERED: tuple[str, ...] = (
|
||||
"search_knowledge_base",
|
||||
"web_search",
|
||||
"scrape_webpage",
|
||||
"update_memory",
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from app.agents.chat.shared.tools.web_search import create_web_search_tool
|
|||
from app.db import ChatVisibility
|
||||
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_knowledge_base import create_search_knowledge_base_tool
|
||||
from .update_memory import (
|
||||
create_update_memory_tool,
|
||||
create_update_team_memory_tool,
|
||||
|
|
@ -35,6 +36,14 @@ def _build_scrape_webpage_tool(deps: dict[str, Any]) -> BaseTool:
|
|||
return create_scrape_webpage_tool(firecrawl_api_key=deps.get("firecrawl_api_key"))
|
||||
|
||||
|
||||
def _build_search_knowledge_base_tool(deps: dict[str, Any]) -> BaseTool:
    """Create the ``search_knowledge_base`` tool from the dependency mapping.

    ``search_space_id`` is mandatory (a missing key raises ``KeyError``);
    ``available_connectors`` and ``available_document_types`` are optional and
    are forwarded as ``None`` when absent.
    """
    space_id = deps["search_space_id"]
    connectors = deps.get("available_connectors")
    document_types = deps.get("available_document_types")
    return create_search_knowledge_base_tool(
        search_space_id=space_id,
        available_connectors=connectors,
        available_document_types=document_types,
    )
|
||||
|
||||
|
||||
def _build_web_search_tool(deps: dict[str, Any]) -> BaseTool:
|
||||
return create_web_search_tool(
|
||||
search_space_id=deps.get("search_space_id"),
|
||||
|
|
@ -75,6 +84,10 @@ def _build_update_memory_tool(deps: dict[str, Any]) -> BaseTool:
|
|||
_MAIN_AGENT_TOOL_FACTORIES: dict[
|
||||
str, tuple[Callable[[dict[str, Any]], BaseTool], tuple[str, ...]]
|
||||
] = {
|
||||
"search_knowledge_base": (
|
||||
_build_search_knowledge_base_tool,
|
||||
("search_space_id",),
|
||||
),
|
||||
"scrape_webpage": (_build_scrape_webpage_tool, ()),
|
||||
"web_search": (_build_web_search_tool, ()),
|
||||
"create_automation": (
|
||||
|
|
|
|||
|
|
@ -8,18 +8,19 @@ transcript directly via the YouTubeTranscriptApi instead of crawling the page.
|
|||
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from fake_useragent import UserAgent
|
||||
from langchain_core.tools import tool
|
||||
from requests import Session
|
||||
from scrapling.fetchers import AsyncFetcher
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
||||
from app.connectors.webcrawler_connector import WebCrawlerConnector
|
||||
from app.tasks.document_processors.youtube_processor import get_youtube_video_id
|
||||
from app.utils.proxy_config import get_requests_proxies
|
||||
from app.utils.proxy import get_proxy_url, get_requests_proxies
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -85,15 +86,20 @@ async def _scrape_youtube_video(
|
|||
oembed_url = "https://www.youtube.com/oembed"
|
||||
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as http_session,
|
||||
http_session.get(
|
||||
oembed_url,
|
||||
params=params,
|
||||
proxy=residential_proxies["http"] if residential_proxies else None,
|
||||
) as response,
|
||||
):
|
||||
video_data = await response.json()
|
||||
oembed_fetch_start = time.perf_counter()
|
||||
oembed_page = await AsyncFetcher.get(
|
||||
oembed_url,
|
||||
params=params,
|
||||
proxy=get_proxy_url(),
|
||||
stealthy_headers=True,
|
||||
)
|
||||
logger.info(
|
||||
"[scrape_webpage][perf] source=oembed video=%s status=%s fetch_ms=%.1f",
|
||||
video_id,
|
||||
getattr(oembed_page, "status", None),
|
||||
(time.perf_counter() - oembed_fetch_start) * 1000,
|
||||
)
|
||||
video_data = oembed_page.json()
|
||||
except Exception:
|
||||
video_data = {}
|
||||
|
||||
|
|
@ -102,6 +108,7 @@ async def _scrape_youtube_video(
|
|||
|
||||
# --- Transcript via YouTubeTranscriptApi ---
|
||||
try:
|
||||
transcript_fetch_start = time.perf_counter()
|
||||
ua = UserAgent()
|
||||
http_client = Session()
|
||||
http_client.headers.update({"User-Agent": ua.random})
|
||||
|
|
@ -115,6 +122,11 @@ async def _scrape_youtube_video(
|
|||
transcript = next(iter(transcript_list))
|
||||
captions = transcript.fetch()
|
||||
|
||||
logger.info(
|
||||
"[scrape_webpage][perf] source=transcript video=%s fetch_ms=%.1f",
|
||||
video_id,
|
||||
(time.perf_counter() - transcript_fetch_start) * 1000,
|
||||
)
|
||||
logger.info(
|
||||
f"[scrape_webpage] Fetched transcript for {video_id} "
|
||||
f"in {transcript.language} ({transcript.language_code})"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,232 @@
|
|||
"""On-demand ``search_knowledge_base`` main-agent tool (OpenCode-style lazy RAG).
|
||||
|
||||
The main agent no longer receives eagerly pre-injected KB context on every
|
||||
turn (see :class:`KnowledgePriorityMiddleware`, now gated off by default).
|
||||
Instead it calls this tool only when it decides it needs knowledge-base
|
||||
content. The tool runs a single hybrid search (embed + DB search, ~0.5s),
|
||||
formats the top matches for the model, and writes ``kb_matched_chunk_ids``
|
||||
into graph state so matched-section highlighting is preserved when the agent
|
||||
later reads a document via ``task(knowledge_base)``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Annotated, Any
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langgraph.types import Command
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.middleware.knowledge_search import (
|
||||
search_knowledge_base as _hybrid_search_kb,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.chat.runtime.path_resolver import (
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, shielded_async_session
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
_DEFAULT_TOP_K = 5
|
||||
_MAX_TOP_K = 20
|
||||
_PER_DOC_SNIPPET_CHARS = 1200
|
||||
_MAX_TOTAL_CHARS = 16_000
|
||||
|
||||
_TOOL_DESCRIPTION = (
|
||||
"Search the user's knowledge base (their indexed documents, files, and "
|
||||
"connector content) for passages relevant to a query, using hybrid "
|
||||
"semantic + keyword retrieval.\n\n"
|
||||
"Use this FIRST to ground any factual or informational answer about the "
|
||||
"user's own documents, notes, or connected sources. The workspace tree "
|
||||
"shows which files exist; this tool pulls the actual relevant content. "
|
||||
"Each hit returns the document's virtual path, a relevance score, and the "
|
||||
"matched snippets. If you need a document's full text, delegate a read to "
|
||||
"the knowledge_base specialist via `task` using the returned path.\n\n"
|
||||
"Write a focused, specific query containing the concrete entities, "
|
||||
"acronyms, people, projects, or terms you are looking for."
|
||||
)
|
||||
|
||||
|
||||
async def _resolve_virtual_paths(
    results: list[dict[str, Any]],
    *,
    search_space_id: int,
) -> dict[int, str]:
    """Resolve ``Document.id`` -> canonical virtual path for the search hits.

    Args:
        results: Search hits. Entries that are not dicts, or whose
            ``"document"`` metadata has no integer ``"id"``, are ignored.
        search_space_id: Search space used to build the path index and to
            scope the folder lookup.

    Returns:
        A mapping from document id to the path produced by
        ``doc_to_virtual_path``. Empty when no hit carries a usable id, in
        which case no database session is opened.
    """
    # Filter once and reuse the filtered hits for both the id lookup and the
    # path rendering, so a malformed entry is skipped consistently instead of
    # being ignored in one pass and dereferenced in the other.
    usable: list[tuple[int, dict[str, Any]]] = []
    for doc in results:
        if not isinstance(doc, dict):
            continue
        doc_meta = doc.get("document") or {}
        doc_id = doc_meta.get("id")
        if isinstance(doc_id, int):
            usable.append((doc_id, doc_meta))
    if not usable:
        return {}

    # De-duplicate (order-preserving) so the IN clause carries each id once.
    doc_ids = list(dict.fromkeys(doc_id for doc_id, _ in usable))

    async with shielded_async_session() as session:
        index: PathIndex = await build_path_index(session, search_space_id)
        folder_rows = await session.execute(
            select(Document.id, Document.folder_id).where(
                Document.search_space_id == search_space_id,
                Document.id.in_(doc_ids),
            )
        )
        folder_by_doc_id = {row.id: row.folder_id for row in folder_rows.all()}

    paths: dict[int, str] = {}
    for doc_id, doc_meta in usable:
        # Prefer the folder recorded in the database; fall back to whatever
        # the hit's own metadata carries when the row was not returned.
        folder_id = folder_by_doc_id.get(doc_id, doc_meta.get("folder_id"))
        paths[doc_id] = doc_to_virtual_path(
            doc_id=doc_id,
            title=str(doc_meta.get("title") or "untitled"),
            folder_id=folder_id if isinstance(folder_id, int) else None,
            index=index,
        )
    return paths
|
||||
|
||||
|
||||
def _format_hits(
    results: list[dict[str, Any]],
    *,
    paths: dict[int, str],
    query: str,
) -> str:
    """Render search hits as a compact, model-readable block.

    Args:
        results: Search hits, rendered in the order given.
        paths: Document id -> virtual path; a ``path:`` line is emitted only
            for hits whose id has a non-empty entry here.
        query: The query text, echoed in the wrapper tag or the no-match
            message.

    Returns:
        A no-match message when ``results`` is empty; otherwise a
        ``<knowledge_base_results>`` block with one numbered entry per hit,
        cut short with a truncation marker once the entry budget is spent.
    """
    if not results:
        return (
            f"No knowledge-base matches found for query: {query!r}.\n"
            "Tell the user nothing relevant was found in their workspace, or "
            "try a different query."
        )

    opening = f"<knowledge_base_results query={query!r}>"
    parts: list[str] = [opening]
    used = len(opening)

    for position, hit in enumerate(results, start=1):
        meta = hit.get("document") or {}
        ident = meta.get("id")
        heading = str(meta.get("title") or "untitled")
        kind = meta.get("document_type") or hit.get("source") or "document"

        raw_score = hit.get("score")
        if isinstance(raw_score, int | float):
            shown_score = f"{raw_score:.3f}"
        else:
            shown_score = "n/a"

        location = paths.get(ident) if isinstance(ident, int) else None

        block = f"\n{position}. {heading} (type={kind}, score={shown_score})"
        if location:
            block += f"\n path: {location}"

        text = (hit.get("content") or "").strip()
        if not text:
            block += "\n (no preview available; read the document for details)"
        else:
            preview = text[:_PER_DOC_SNIPPET_CHARS].strip()
            if len(text) > _PER_DOC_SNIPPET_CHARS:
                # Signal that the snippet is only a prefix of the content.
                preview += " ..."
            block += "\n " + preview.replace("\n", "\n ")

        # Stop before the entry that would push the output past the budget.
        if used + len(block) > _MAX_TOTAL_CHARS:
            parts.append("\n<!-- additional matches truncated to fit context -->")
            break
        parts.append(block)
        used += len(block)

    parts.extend(
        (
            "\n\nTo read a full document, delegate to the knowledge_base specialist "
            "with `task`, referencing the path above.",
            "\n</knowledge_base_results>",
        )
    )
    return "".join(parts)
|
||||
|
||||
|
||||
def _matched_chunk_ids(results: list[dict[str, Any]]) -> dict[int, list[int]]:
|
||||
"""Extract ``Document.id`` -> matched chunk ids for state hand-off."""
|
||||
matched: dict[int, list[int]] = {}
|
||||
for doc in results:
|
||||
doc_id = (doc.get("document") or {}).get("id")
|
||||
if not isinstance(doc_id, int):
|
||||
continue
|
||||
chunk_ids = doc.get("matched_chunk_ids") or []
|
||||
normalized = [int(cid) for cid in chunk_ids if isinstance(cid, int | str)]
|
||||
if normalized:
|
||||
matched[doc_id] = normalized
|
||||
return matched
|
||||
|
||||
|
||||
def create_search_knowledge_base_tool(
    *,
    search_space_id: int,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
) -> BaseTool:
    """Factory for the on-demand ``search_knowledge_base`` tool.

    Args:
        search_space_id: Search space every call of the tool is scoped to.
        available_connectors: Forwarded unchanged to the hybrid search.
        available_document_types: Forwarded unchanged to the hybrid search.

    Returns:
        A ``StructuredTool`` named ``search_knowledge_base`` whose coroutine
        runs the search, renders the hits, and returns a ``Command`` carrying
        the rendered ``ToolMessage`` (plus matched chunk ids when present).
    """

    # NOTE: no docstring on ``_impl`` on purpose; the tool description is
    # supplied explicitly below and the callable is left exactly as-is for
    # schema inference.
    async def _impl(
        query: Annotated[
            str,
            "Focused search query with the concrete entities/terms to look for.",
        ],
        runtime: ToolRuntime[None, SurfSenseFilesystemState],
        top_k: Annotated[
            int,
            "Maximum number of documents to return (default 5).",
        ] = _DEFAULT_TOP_K,
    ) -> Command | str:
        needle = (query or "").strip()
        if not needle:
            return "Error: provide a non-empty search query."

        # Keep the requested count within [1, _MAX_TOP_K].
        limit = min(max(1, top_k), _MAX_TOP_K)
        started = time.perf_counter()

        hits = await _hybrid_search_kb(
            query=needle,
            search_space_id=search_space_id,
            available_connectors=available_connectors,
            available_document_types=available_document_types,
            top_k=limit,
        )
        virtual_paths = await _resolve_virtual_paths(
            hits, search_space_id=search_space_id
        )
        text = _format_hits(hits, paths=virtual_paths, query=needle)
        chunk_map = _matched_chunk_ids(hits)

        _perf_log.info(
            "[search_knowledge_base] tool query=%r results=%d chars=%d in %.3fs",
            needle[:60],
            len(hits),
            len(text),
            time.perf_counter() - started,
        )

        reply = ToolMessage(content=text, tool_call_id=runtime.tool_call_id)
        if not chunk_map:
            return Command(update={"messages": [reply]})
        # Hand the matched chunk ids to graph state alongside the message.
        return Command(
            update={"messages": [reply], "kb_matched_chunk_ids": chunk_map}
        )

    return StructuredTool.from_function(
        name="search_knowledge_base",
        description=_TOOL_DESCRIPTION,
        coroutine=_impl,
    )
|
||||
|
|
@ -55,6 +55,13 @@ class AgentFeatureFlags:
|
|||
enable_specialized_subagents: bool = True
|
||||
enable_kb_planner_runnable: bool = True
|
||||
|
||||
# KB retrieval mode — when False (default), the main agent retrieves KB
|
||||
# content lazily via the on-demand ``search_knowledge_base`` tool and the
|
||||
# expensive per-turn pre-injection (planner LLM + embed + hybrid search,
|
||||
# ~2.3s) is skipped; explicit @-mentions are still surfaced cheaply. Set
|
||||
# True to restore the original eager ``<priority_documents>`` pre-injection.
|
||||
enable_kb_priority_preinjection: bool = False
|
||||
|
||||
# Snapshot / revert
|
||||
enable_action_log: bool = True
|
||||
enable_revert_route: bool = True
|
||||
|
|
@ -71,6 +78,14 @@ class AgentFeatureFlags:
|
|||
# is read from runtime.context, not the constructor closure. Rollback via
|
||||
# SURFSENSE_ENABLE_AGENT_CACHE=false.
|
||||
enable_agent_cache: bool = True
|
||||
# Reuse one compiled graph across a returning user's *new* chats by dropping
|
||||
# ``thread_id`` from the agent_cache key. Safe because every middleware/tool
|
||||
# that needs the chat thread now resolves it from the live RunnableConfig
|
||||
# (ActionLog, KB-persistence, deliverables) rather than a constructor
|
||||
# closure, and mutation tools open fresh per-call sessions. Turns a
|
||||
# returning user's cold first turn into a cache hit (cold == warm).
|
||||
# Rollback via SURFSENSE_ENABLE_CROSS_THREAD_AGENT_CACHE=false.
|
||||
enable_cross_thread_agent_cache: bool = True
|
||||
# Deferred: only helps on outer-cache MISSES, so off until data shows cold
|
||||
# misses are frequent enough to justify the extra global state.
|
||||
enable_agent_cache_share_gp_subagent: bool = False
|
||||
|
|
@ -104,11 +119,14 @@ class AgentFeatureFlags:
|
|||
enable_skills=False,
|
||||
enable_specialized_subagents=False,
|
||||
enable_kb_planner_runnable=False,
|
||||
# Full rollback restores the original eager KB pre-injection.
|
||||
enable_kb_priority_preinjection=True,
|
||||
enable_action_log=False,
|
||||
enable_revert_route=False,
|
||||
enable_plugin_loader=False,
|
||||
enable_otel=False,
|
||||
enable_agent_cache=False,
|
||||
enable_cross_thread_agent_cache=False,
|
||||
enable_agent_cache_share_gp_subagent=False,
|
||||
)
|
||||
|
||||
|
|
@ -141,6 +159,9 @@ class AgentFeatureFlags:
|
|||
enable_kb_planner_runnable=_env_bool(
|
||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
|
||||
),
|
||||
enable_kb_priority_preinjection=_env_bool(
|
||||
"SURFSENSE_ENABLE_KB_PRIORITY_PREINJECTION", False
|
||||
),
|
||||
# Snapshot / revert
|
||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
|
||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
|
||||
|
|
@ -150,6 +171,9 @@ class AgentFeatureFlags:
|
|||
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
|
||||
# Performance
|
||||
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
|
||||
enable_cross_thread_agent_cache=_env_bool(
|
||||
"SURFSENSE_ENABLE_CROSS_THREAD_AGENT_CACHE", True
|
||||
),
|
||||
enable_agent_cache_share_gp_subagent=_env_bool(
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
|
||||
),
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import time as _perf_time
|
||||
from typing import Any
|
||||
|
||||
from deepagents import FilesystemMiddleware
|
||||
|
|
@ -14,6 +15,7 @@ from app.agents.chat.multi_agent_chat.shared.middleware.filesystem.sandbox impor
|
|||
from app.agents.chat.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
from ..system_prompt import build_system_prompt
|
||||
from ..tools import (
|
||||
|
|
@ -34,6 +36,8 @@ from ..tools.glob.description import select_description as glob_description
|
|||
from ..tools.grep.description import select_description as grep_description
|
||||
from .read_only_policy import READ_ONLY_TOOL_NAMES
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
|
||||
|
|
@ -60,16 +64,22 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
is_sandbox_enabled() and thread_id is not None and not read_only
|
||||
)
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
system_prompt = build_system_prompt(
|
||||
filesystem_mode,
|
||||
sandbox_available=self._sandbox_available,
|
||||
)
|
||||
_t_prompt = _perf_time.perf_counter() - _t0
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
super().__init__(
|
||||
backend=backend,
|
||||
system_prompt=system_prompt,
|
||||
tool_token_limit_before_evict=tool_token_limit_before_evict,
|
||||
)
|
||||
_t_super = _perf_time.perf_counter() - _t0
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
self.tools = [t for t in self.tools if t.name != "execute"]
|
||||
self.tools.append(create_mkdir_tool(self))
|
||||
self.tools.append(create_cd_tool(self))
|
||||
|
|
@ -83,6 +93,15 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
|
||||
if read_only:
|
||||
self.tools = [t for t in self.tools if t.name in READ_ONLY_TOOL_NAMES]
|
||||
_t_tools = _perf_time.perf_counter() - _t0
|
||||
_perf_log.info(
|
||||
"[fs_middleware_init] ro=%s system_prompt=%.3fs super_init=%.3fs "
|
||||
"surf_tools=%.3fs",
|
||||
read_only,
|
||||
_t_prompt,
|
||||
_t_super,
|
||||
_t_tools,
|
||||
)
|
||||
|
||||
# ----------------------------------------- base-class tool overrides
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
|
@ -29,6 +28,10 @@ from daytona.common.errors import DaytonaError
|
|||
from deepagents.backends.protocol import ExecuteResponse
|
||||
from langchain_daytona import DaytonaSandbox
|
||||
|
||||
# Aliased to avoid clashing with the local ``config = DaytonaConfig(...)``
|
||||
# variable used inside ``_get_client``.
|
||||
from app.config import config as app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -73,7 +76,7 @@ SANDBOX_DOCUMENTS_ROOT = "/home/daytona/documents"
|
|||
|
||||
|
||||
def is_sandbox_enabled() -> bool:
|
||||
return os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE").upper() == "TRUE"
|
||||
return app_config.DAYTONA_SANDBOX_ENABLED
|
||||
|
||||
|
||||
def _get_client() -> Daytona:
|
||||
|
|
@ -81,9 +84,9 @@ def _get_client() -> Daytona:
|
|||
with _client_lock:
|
||||
if _daytona_client is None:
|
||||
config = DaytonaConfig(
|
||||
api_key=os.environ.get("DAYTONA_API_KEY", ""),
|
||||
api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
|
||||
target=os.environ.get("DAYTONA_TARGET", "us"),
|
||||
api_key=app_config.DAYTONA_API_KEY,
|
||||
api_url=app_config.DAYTONA_API_URL,
|
||||
target=app_config.DAYTONA_TARGET,
|
||||
)
|
||||
_daytona_client = Daytona(config)
|
||||
return _daytona_client
|
||||
|
|
@ -92,7 +95,7 @@ def _get_client() -> Daytona:
|
|||
def _sandbox_create_params(
|
||||
labels: dict[str, str],
|
||||
) -> CreateSandboxFromSnapshotParams:
|
||||
snapshot_id = os.environ.get("DAYTONA_SNAPSHOT_ID") or None
|
||||
snapshot_id = app_config.DAYTONA_SNAPSHOT_ID
|
||||
return CreateSandboxFromSnapshotParams(
|
||||
language="python",
|
||||
labels=labels,
|
||||
|
|
@ -302,7 +305,7 @@ async def delete_sandbox(thread_id: int | str) -> None:
|
|||
|
||||
|
||||
def _get_sandbox_files_dir() -> Path:
|
||||
return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))
|
||||
return Path(app_config.SANDBOX_FILES_DIR)
|
||||
|
||||
|
||||
def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ import asyncio
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
|
@ -346,6 +347,7 @@ async def browse_recent_documents(
|
|||
|
||||
from app.db import DocumentType
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
async with shielded_async_session() as session:
|
||||
base_conditions = [
|
||||
Document.search_space_id == search_space_id,
|
||||
|
|
@ -445,6 +447,12 @@ async def browse_recent_documents(
|
|||
),
|
||||
}
|
||||
)
|
||||
_perf_log.info(
|
||||
"[kb_priority.recent] db=%.3fs docs=%d space=%d",
|
||||
time.perf_counter() - _t0,
|
||||
len(results),
|
||||
search_space_id,
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
|
|
@ -462,10 +470,18 @@ async def search_knowledge_base(
|
|||
if not query:
|
||||
return []
|
||||
|
||||
# ``embed_texts`` serializes behind a global embedding lock and, for API
|
||||
# models, makes a network round-trip — so this can stall while another
|
||||
# turn is embedding. Timed separately from the DB search to tell the two
|
||||
# apart when debugging slow time-to-first-token.
|
||||
_t_embed = time.perf_counter()
|
||||
[embedding] = await asyncio.to_thread(embed_texts, [query])
|
||||
_embed_elapsed = time.perf_counter() - _t_embed
|
||||
|
||||
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
||||
retriever_top_k = min(top_k * 3, 30)
|
||||
|
||||
_t_search = time.perf_counter()
|
||||
async with shielded_async_session() as session:
|
||||
retriever = ChucksHybridSearchRetriever(session)
|
||||
results = await retriever.hybrid_search(
|
||||
|
|
@ -477,7 +493,16 @@ async def search_knowledge_base(
|
|||
end_date=end_date,
|
||||
query_embedding=embedding.tolist(),
|
||||
)
|
||||
_search_elapsed = time.perf_counter() - _t_search
|
||||
|
||||
_perf_log.info(
|
||||
"[kb_priority.search] embed=%.3fs hybrid_search=%.3fs results=%d space=%d query=%r",
|
||||
_embed_elapsed,
|
||||
_search_elapsed,
|
||||
len(results),
|
||||
search_space_id,
|
||||
query[:80],
|
||||
)
|
||||
return results[:top_k]
|
||||
|
||||
|
||||
|
|
@ -490,6 +515,7 @@ async def fetch_mentioned_documents(
|
|||
if not document_ids:
|
||||
return []
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
async with shielded_async_session() as session:
|
||||
doc_result = await session.execute(
|
||||
select(Document).where(
|
||||
|
|
@ -546,6 +572,12 @@ async def fetch_mentioned_documents(
|
|||
"_user_mentioned": True,
|
||||
}
|
||||
)
|
||||
_perf_log.info(
|
||||
"[kb_priority.mentioned] db=%.3fs requested=%d resolved=%d",
|
||||
time.perf_counter() - _t0,
|
||||
len(document_ids),
|
||||
len(results),
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
|
|
@ -592,6 +624,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
top_k: int = 10,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
inject_system_message: bool = True, # For backwards compatibility
|
||||
mentions_only: bool = False,
|
||||
) -> None:
|
||||
self.llm = llm
|
||||
# Cheap model for structured internal tasks (query rewrite, date
|
||||
|
|
@ -605,6 +638,10 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
self.top_k = top_k
|
||||
self.mentioned_document_ids = mentioned_document_ids or []
|
||||
self.inject_system_message = inject_system_message
|
||||
# Lazy mode: skip the planner LLM + embedding + hybrid search and only
|
||||
# surface explicit @-mentions. The agent retrieves topical KB content on
|
||||
# demand via the ``search_knowledge_base`` tool instead.
|
||||
self.mentions_only = mentions_only
|
||||
# Compiled lazily and memoized to avoid the per-turn create_agent cost.
|
||||
self._planner: Runnable | None = None
|
||||
self._planner_compile_failed = False
|
||||
|
|
@ -793,15 +830,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
runtime: Runtime[Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
t0 = asyncio.get_event_loop().time()
|
||||
(
|
||||
planned_query,
|
||||
start_date,
|
||||
end_date,
|
||||
is_recency,
|
||||
) = await self._plan_search_inputs(
|
||||
messages=messages,
|
||||
user_text=user_text,
|
||||
)
|
||||
|
||||
# Prefer per-turn mentions from runtime.context (lets a cached graph
|
||||
# serve different turns); fall back to the constructor closure, draining
|
||||
|
|
@ -832,6 +860,52 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
if ctx_folders:
|
||||
folder_mention_ids = list(ctx_folders)
|
||||
|
||||
# Lazy mode: skip the planner LLM + embedding + hybrid search entirely.
|
||||
# With no explicit mentions there is nothing cheap to surface, so we bail
|
||||
# out early and let the agent decide to call ``search_knowledge_base``.
|
||||
if self.mentions_only:
|
||||
if not mention_ids and not folder_mention_ids:
|
||||
return None
|
||||
planned_query = user_text
|
||||
start_date = end_date = None
|
||||
is_recency = False
|
||||
search_results: list[dict[str, Any]] = []
|
||||
_search_phase_elapsed = 0.0
|
||||
else:
|
||||
(
|
||||
planned_query,
|
||||
start_date,
|
||||
end_date,
|
||||
is_recency,
|
||||
) = await self._plan_search_inputs(
|
||||
messages=messages,
|
||||
user_text=user_text,
|
||||
)
|
||||
|
||||
_t_search_phase = time.perf_counter()
|
||||
if is_recency:
|
||||
doc_types = _resolve_search_types(
|
||||
self.available_connectors, self.available_document_types
|
||||
)
|
||||
search_results = await browse_recent_documents(
|
||||
search_space_id=self.search_space_id,
|
||||
document_type=doc_types,
|
||||
top_k=self.top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
else:
|
||||
search_results = await search_knowledge_base(
|
||||
query=planned_query,
|
||||
search_space_id=self.search_space_id,
|
||||
available_connectors=self.available_connectors,
|
||||
available_document_types=self.available_document_types,
|
||||
top_k=self.top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
_search_phase_elapsed = time.perf_counter() - _t_search_phase
|
||||
|
||||
mentioned_results: list[dict[str, Any]] = []
|
||||
if mention_ids:
|
||||
mentioned_results = await fetch_mentioned_documents(
|
||||
|
|
@ -839,28 +913,6 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
search_space_id=self.search_space_id,
|
||||
)
|
||||
|
||||
if is_recency:
|
||||
doc_types = _resolve_search_types(
|
||||
self.available_connectors, self.available_document_types
|
||||
)
|
||||
search_results = await browse_recent_documents(
|
||||
search_space_id=self.search_space_id,
|
||||
document_type=doc_types,
|
||||
top_k=self.top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
else:
|
||||
search_results = await search_knowledge_base(
|
||||
query=planned_query,
|
||||
search_space_id=self.search_space_id,
|
||||
available_connectors=self.available_connectors,
|
||||
available_document_types=self.available_document_types,
|
||||
top_k=self.top_k,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
seen_doc_ids: set[int] = set()
|
||||
merged: list[dict[str, Any]] = []
|
||||
for doc in mentioned_results:
|
||||
|
|
@ -874,15 +926,26 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
continue
|
||||
merged.append(doc)
|
||||
|
||||
_t_materialize = time.perf_counter()
|
||||
priority, matched_chunk_ids = await self._materialize_priority(merged)
|
||||
|
||||
if folder_mention_ids:
|
||||
folder_entries = await self._materialize_folder_priority(folder_mention_ids)
|
||||
priority = folder_entries + priority
|
||||
_materialize_elapsed = time.perf_counter() - _t_materialize
|
||||
|
||||
# ``recency=...`` reflects which retrieval path ran (recency browse vs
|
||||
# hybrid search). The planner phase is logged separately by
|
||||
# ``_plan_search_inputs``; here ``search_phase`` and ``materialize``
|
||||
# break down the remaining DB-bound work so a slow turn can be
|
||||
# attributed to planner / search / materialize at a glance.
|
||||
_perf_log.info(
|
||||
"[kb_priority] completed in %.3fs query=%r priority=%d mentioned=%d folders=%d",
|
||||
"[kb_priority] completed in %.3fs (search_phase=%.3fs materialize=%.3fs "
|
||||
"recency=%s) query=%r priority=%d mentioned=%d folders=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
_search_phase_elapsed,
|
||||
_materialize_elapsed,
|
||||
is_recency,
|
||||
user_text[:80],
|
||||
len(priority),
|
||||
len(mentioned_results),
|
||||
|
|
@ -958,6 +1021,7 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
if not merged:
|
||||
return priority, matched_chunk_ids
|
||||
|
||||
_t0 = time.perf_counter()
|
||||
async with shielded_async_session() as session:
|
||||
index: PathIndex = await build_path_index(session, self.search_space_id)
|
||||
doc_ids = [
|
||||
|
|
@ -1006,6 +1070,11 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
matched_chunk_ids[doc_id] = [
|
||||
int(cid) for cid in chunk_ids if isinstance(cid, int | str)
|
||||
]
|
||||
_perf_log.info(
|
||||
"[kb_priority.materialize] db=%.3fs docs=%d",
|
||||
time.perf_counter() - _t0,
|
||||
len(merged),
|
||||
)
|
||||
return priority, matched_chunk_ids
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -40,44 +40,153 @@ class ToolMetadata:
|
|||
# up in the UI tool picker. This list carries metadata only — wire the actual
|
||||
# implementation in the relevant builder/registry module.
|
||||
TOOL_CATALOG: list[ToolMetadata] = [
|
||||
ToolMetadata(name="generate_podcast", description="Generate an audio podcast from provided content"),
|
||||
ToolMetadata(name="generate_video_presentation", description="Generate a video presentation with slides and narration from provided content"),
|
||||
ToolMetadata(name="generate_report", description="Generate a structured report from provided content and export it"),
|
||||
ToolMetadata(name="generate_resume", description="Generate a professional resume as a Typst document"),
|
||||
ToolMetadata(name="generate_image", description="Generate images from text descriptions using AI image models"),
|
||||
ToolMetadata(name="scrape_webpage", description="Scrape and extract the main content from a webpage"),
|
||||
ToolMetadata(name="web_search", description="Search the web for real-time information using configured search engines"),
|
||||
ToolMetadata(name="create_automation", description="Draft an automation from an NL intent; user approves the card; tool saves"),
|
||||
ToolMetadata(name="update_memory", description="Save important long-term facts, preferences, and instructions to the (personal or team) memory"),
|
||||
ToolMetadata(name="create_notion_page", description="Create a new page in the user's Notion workspace"),
|
||||
ToolMetadata(name="update_notion_page", description="Append new content to an existing Notion page"),
|
||||
ToolMetadata(name="delete_notion_page", description="Delete an existing Notion page"),
|
||||
ToolMetadata(name="create_google_drive_file", description="Create a new Google Doc or Google Sheet in Google Drive"),
|
||||
ToolMetadata(name="delete_google_drive_file", description="Move an indexed Google Drive file to trash"),
|
||||
ToolMetadata(name="create_dropbox_file", description="Create a new file in Dropbox"),
|
||||
ToolMetadata(
|
||||
name="generate_podcast",
|
||||
description="Generate an audio podcast from provided content",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="generate_video_presentation",
|
||||
description="Generate a video presentation with slides and narration from provided content",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="generate_report",
|
||||
description="Generate a structured report from provided content and export it",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="generate_resume",
|
||||
description="Generate a professional resume as a Typst document",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="generate_image",
|
||||
description="Generate images from text descriptions using AI image models",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="search_knowledge_base",
|
||||
description="Search the user's knowledge base with hybrid semantic + keyword retrieval",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="scrape_webpage",
|
||||
description="Scrape and extract the main content from a webpage",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="web_search",
|
||||
description="Search the web for real-time information using configured search engines",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="create_automation",
|
||||
description="Draft an automation from an NL intent; user approves the card; tool saves",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="update_memory",
|
||||
description="Save important long-term facts, preferences, and instructions to the (personal or team) memory",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="create_notion_page",
|
||||
description="Create a new page in the user's Notion workspace",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="update_notion_page",
|
||||
description="Append new content to an existing Notion page",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="delete_notion_page", description="Delete an existing Notion page"
|
||||
),
|
||||
ToolMetadata(
|
||||
name="create_google_drive_file",
|
||||
description="Create a new Google Doc or Google Sheet in Google Drive",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="delete_google_drive_file",
|
||||
description="Move an indexed Google Drive file to trash",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="create_dropbox_file", description="Create a new file in Dropbox"
|
||||
),
|
||||
ToolMetadata(name="delete_dropbox_file", description="Delete a file from Dropbox"),
|
||||
ToolMetadata(name="create_onedrive_file", description="Create a new file in Microsoft OneDrive"),
|
||||
ToolMetadata(name="delete_onedrive_file", description="Move a OneDrive file to the recycle bin"),
|
||||
ToolMetadata(name="search_calendar_events", description="Search Google Calendar events within a date range"),
|
||||
ToolMetadata(name="create_calendar_event", description="Create a new event on Google Calendar"),
|
||||
ToolMetadata(name="update_calendar_event", description="Update an existing indexed Google Calendar event"),
|
||||
ToolMetadata(name="delete_calendar_event", description="Delete an existing indexed Google Calendar event"),
|
||||
ToolMetadata(name="search_gmail", description="Search emails in Gmail using Gmail search syntax"),
|
||||
ToolMetadata(name="read_gmail_email", description="Read the full content of a specific Gmail email"),
|
||||
ToolMetadata(name="create_gmail_draft", description="Create a draft email in Gmail"),
|
||||
ToolMetadata(
|
||||
name="create_onedrive_file",
|
||||
description="Create a new file in Microsoft OneDrive",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="delete_onedrive_file",
|
||||
description="Move a OneDrive file to the recycle bin",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="search_calendar_events",
|
||||
description="Search Google Calendar events within a date range",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="create_calendar_event",
|
||||
description="Create a new event on Google Calendar",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="update_calendar_event",
|
||||
description="Update an existing indexed Google Calendar event",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="delete_calendar_event",
|
||||
description="Delete an existing indexed Google Calendar event",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="search_gmail",
|
||||
description="Search emails in Gmail using Gmail search syntax",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="read_gmail_email",
|
||||
description="Read the full content of a specific Gmail email",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="create_gmail_draft", description="Create a draft email in Gmail"
|
||||
),
|
||||
ToolMetadata(name="send_gmail_email", description="Send an email via Gmail"),
|
||||
ToolMetadata(name="trash_gmail_email", description="Move an indexed email to trash in Gmail"),
|
||||
ToolMetadata(name="update_gmail_draft", description="Update an existing Gmail draft"),
|
||||
ToolMetadata(name="create_confluence_page", description="Create a new page in the user's Confluence space"),
|
||||
ToolMetadata(name="update_confluence_page", description="Update an existing indexed Confluence page"),
|
||||
ToolMetadata(name="delete_confluence_page", description="Delete an existing indexed Confluence page"),
|
||||
ToolMetadata(name="list_discord_channels", description="List text channels in the connected Discord server"),
|
||||
ToolMetadata(name="read_discord_messages", description="Read recent messages from a Discord text channel"),
|
||||
ToolMetadata(name="send_discord_message", description="Send a message to a Discord text channel"),
|
||||
ToolMetadata(name="list_teams_channels", description="List Microsoft Teams and their channels"),
|
||||
ToolMetadata(name="read_teams_messages", description="Read recent messages from a Microsoft Teams channel"),
|
||||
ToolMetadata(name="send_teams_message", description="Send a message to a Microsoft Teams channel"),
|
||||
ToolMetadata(name="list_luma_events", description="List upcoming and recent Luma events"),
|
||||
ToolMetadata(name="read_luma_event", description="Read detailed information about a specific Luma event"),
|
||||
ToolMetadata(
|
||||
name="trash_gmail_email", description="Move an indexed email to trash in Gmail"
|
||||
),
|
||||
ToolMetadata(
|
||||
name="update_gmail_draft", description="Update an existing Gmail draft"
|
||||
),
|
||||
ToolMetadata(
|
||||
name="create_confluence_page",
|
||||
description="Create a new page in the user's Confluence space",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="update_confluence_page",
|
||||
description="Update an existing indexed Confluence page",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="delete_confluence_page",
|
||||
description="Delete an existing indexed Confluence page",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="list_discord_channels",
|
||||
description="List text channels in the connected Discord server",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="read_discord_messages",
|
||||
description="Read recent messages from a Discord text channel",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="send_discord_message",
|
||||
description="Send a message to a Discord text channel",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="list_teams_channels",
|
||||
description="List Microsoft Teams and their channels",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="read_teams_messages",
|
||||
description="Read recent messages from a Microsoft Teams channel",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="send_teams_message",
|
||||
description="Send a message to a Microsoft Teams channel",
|
||||
),
|
||||
ToolMetadata(
|
||||
name="list_luma_events", description="List upcoming and recent Luma events"
|
||||
),
|
||||
ToolMetadata(
|
||||
name="read_luma_event",
|
||||
description="Read detailed information about a specific Luma event",
|
||||
),
|
||||
ToolMetadata(name="create_luma_event", description="Create a new event on Luma"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receip
|
|||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import (
|
||||
wait_for_deliverable,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
|
||||
resolve_root_thread_id,
|
||||
)
|
||||
from app.db import Podcast, PodcastStatus, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -71,7 +74,7 @@ def create_generate_podcast_tool(
|
|||
title=podcast_title,
|
||||
status=PodcastStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
thread_id=resolve_root_thread_id(runtime, thread_id),
|
||||
)
|
||||
session.add(podcast)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ from langgraph.types import Command
|
|||
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
|
||||
resolve_root_thread_id,
|
||||
)
|
||||
from app.db import Report, shielded_async_session
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.llm_service import get_agent_llm
|
||||
|
|
@ -687,7 +690,7 @@ def create_generate_report_tool(
|
|||
},
|
||||
report_style=report_style,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
thread_id=resolve_root_thread_id(runtime, thread_id),
|
||||
report_group_id=report_group_id,
|
||||
)
|
||||
session.add(failed_report)
|
||||
|
|
@ -991,7 +994,7 @@ def create_generate_report_tool(
|
|||
report_metadata=metadata,
|
||||
report_style=report_style,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
thread_id=resolve_root_thread_id(runtime, thread_id),
|
||||
report_group_id=report_group_id,
|
||||
)
|
||||
write_session.add(report)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,9 @@ from langgraph.types import Command
|
|||
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
|
||||
resolve_root_thread_id,
|
||||
)
|
||||
from app.db import Report, shielded_async_session
|
||||
from app.services.llm_service import get_agent_llm
|
||||
|
||||
|
|
@ -529,7 +532,7 @@ def create_generate_resume_tool(
|
|||
},
|
||||
report_style="resume",
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
thread_id=resolve_root_thread_id(runtime, thread_id),
|
||||
report_group_id=report_group_id,
|
||||
)
|
||||
session.add(failed)
|
||||
|
|
@ -817,7 +820,7 @@ def create_generate_resume_tool(
|
|||
report_metadata=metadata,
|
||||
report_style="resume",
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
thread_id=resolve_root_thread_id(runtime, thread_id),
|
||||
report_group_id=report_group_id,
|
||||
)
|
||||
write_session.add(report)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
"""Resolve the root chat ``thread_id`` from a deliverables tool's runtime.
|
||||
|
||||
Deliverables tools run inside the ``deliverables`` subagent, which is invoked
|
||||
with a *namespaced* ``thread_id`` of the form ``{chat_id}::task:{tool_call_id}``
|
||||
(see :func:`subagent_invoke_config`). To attribute a generated deliverable
|
||||
(podcast / report / resume / video) to the correct chat, we parse the leading
|
||||
segment of that namespaced id rather than trusting a ``thread_id`` captured at
|
||||
tool-build time — the latter would be stale once a single compiled agent graph
|
||||
is reused across chats (cross-thread ``agent_cache`` reuse).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
|
||||
def resolve_root_thread_id(runtime: ToolRuntime, fallback: int | None) -> int | None:
    """Return the root chat id from the live runtime config, else ``fallback``.

    The value is read from ``runtime.config["configurable"]["thread_id"]``.
    A namespaced id such as ``"2099::task:call_x"`` resolves to the integer
    before the first ``"::"`` (``2099``); a plain ``int`` is returned as-is.

    Args:
        runtime: Tool runtime whose ``config`` attribute (if present) is a dict
            carrying a ``configurable`` mapping.
        fallback: Value returned whenever no root id can be derived.

    Returns:
        The parsed root chat id, or ``fallback`` when the config is missing,
        is not a dict, has no usable ``thread_id``, or the leading segment is
        not an integer.
    """
    try:
        config = getattr(runtime, "config", None)
        if not isinstance(config, dict):
            return fallback
        value = (config.get("configurable") or {}).get("thread_id")
    except Exception:  # pragma: no cover - defensive
        # Best-effort lookup: a malformed config (e.g. a non-mapping
        # ``configurable``) must never break the calling tool.
        return fallback

    # ``bool`` is a subclass of ``int``; a stray True/False is never a chat id
    # and must not be returned as one.
    if isinstance(value, bool):
        return fallback
    if isinstance(value, int):
        return value
    if isinstance(value, str) and value:
        root = value.split("::", 1)[0]
        try:
            return int(root)
        except ValueError:
            return fallback
    return fallback
|
||||
|
|
@ -22,6 +22,9 @@ from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receip
|
|||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import (
|
||||
wait_for_deliverable,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
|
||||
resolve_root_thread_id,
|
||||
)
|
||||
from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -58,7 +61,7 @@ def create_generate_video_presentation_tool(
|
|||
title=video_title,
|
||||
status=VideoPresentationStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
thread_id=resolve_root_thread_id(runtime, thread_id),
|
||||
)
|
||||
session.add(video_pres)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import BaseTool, ToolRuntime
|
||||
|
|
@ -39,7 +40,28 @@ def _wrap_result(result: dict, tool_call_id: str) -> Command:
|
|||
)
|
||||
|
||||
|
||||
def build_ask_knowledge_base_tool(kb_readonly_runnable: Runnable) -> BaseTool:
|
||||
def build_ask_knowledge_base_tool(
|
||||
kb_readonly: Runnable | Callable[[], Runnable],
|
||||
) -> BaseTool:
|
||||
"""Build the ``ask_knowledge_base`` tool backed by the read-only KB graph.
|
||||
|
||||
``kb_readonly`` may be a pre-compiled ``Runnable`` or a zero-arg factory
|
||||
that compiles it on first use. Passing a factory defers the ~0.3-0.8s
|
||||
``create_agent`` cost of the read-only knowledge_base graph until a subagent
|
||||
actually calls ``ask_knowledge_base``, keeping it off the cold agent-build
|
||||
(time-to-first-token) path. The factory result is memoized.
|
||||
"""
|
||||
_cache: dict[str, Runnable] = {}
|
||||
|
||||
def _resolve() -> Runnable:
|
||||
if not callable(kb_readonly) or isinstance(kb_readonly, Runnable):
|
||||
return kb_readonly # type: ignore[return-value]
|
||||
cached = _cache.get("runnable")
|
||||
if cached is None:
|
||||
cached = kb_readonly()
|
||||
_cache["runnable"] = cached
|
||||
return cached
|
||||
|
||||
def ask_knowledge_base(
|
||||
query: Annotated[
|
||||
str,
|
||||
|
|
@ -52,7 +74,7 @@ def build_ask_knowledge_base_tool(kb_readonly_runnable: Runnable) -> BaseTool:
|
|||
raise ValueError("Tool call ID is required for ask_knowledge_base")
|
||||
sub_state = _forward_state(runtime, query)
|
||||
sub_config = subagent_invoke_config(runtime)
|
||||
result = kb_readonly_runnable.invoke(sub_state, config=sub_config)
|
||||
result = _resolve().invoke(sub_state, config=sub_config)
|
||||
return _wrap_result(result, runtime.tool_call_id)
|
||||
|
||||
async def aask_knowledge_base(
|
||||
|
|
@ -67,7 +89,7 @@ def build_ask_knowledge_base_tool(kb_readonly_runnable: Runnable) -> BaseTool:
|
|||
raise ValueError("Tool call ID is required for ask_knowledge_base")
|
||||
sub_state = _forward_state(runtime, query)
|
||||
sub_config = subagent_invoke_config(runtime)
|
||||
result = await kb_readonly_runnable.ainvoke(sub_state, config=sub_config)
|
||||
result = await _resolve().ainvoke(sub_state, config=sub_config)
|
||||
return _wrap_result(result, runtime.tool_call_id)
|
||||
|
||||
return StructuredTool.from_function(
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ The KB-owned :class:`PermissionMiddleware` slot is what enforces
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import time as _perf_time
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
|
@ -31,6 +32,9 @@ from app.agents.chat.multi_agent_chat.shared.permissions import (
|
|||
Ruleset,
|
||||
build_permission_mw,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
def _kb_user_allowlist(
|
||||
|
|
@ -93,25 +97,62 @@ def build_kb_middleware(
|
|||
user_allowlist = _kb_user_allowlist(dependencies, subagent_name)
|
||||
if user_allowlist is not None:
|
||||
rulesets.append(user_allowlist)
|
||||
_t0 = _perf_time.perf_counter()
|
||||
permission_mw = build_permission_mw(
|
||||
flags=flags,
|
||||
subagent_rulesets=rulesets,
|
||||
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
|
||||
)
|
||||
_t_perm = _perf_time.perf_counter() - _t0
|
||||
else:
|
||||
_t_perm = 0.0
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
kb_ctx_mw = build_kb_context_projection_mw()
|
||||
_t_ctx = _perf_time.perf_counter() - _t0
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
fs_mw = build_filesystem_mw(
|
||||
backend_resolver=dependencies["backend_resolver"],
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=dependencies["search_space_id"],
|
||||
user_id=dependencies.get("user_id"),
|
||||
thread_id=dependencies.get("thread_id"),
|
||||
read_only=read_only,
|
||||
)
|
||||
_t_fs = _perf_time.perf_counter() - _t0
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
compaction_mw = build_compaction_mw(llm)
|
||||
_t_comp = _perf_time.perf_counter() - _t0
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
patch_mw = build_patch_tool_calls_mw()
|
||||
_t_patch = _perf_time.perf_counter() - _t0
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
cache_mw = build_anthropic_cache_mw()
|
||||
_t_cache = _perf_time.perf_counter() - _t0
|
||||
|
||||
_perf_log.info(
|
||||
"[kb_middleware] name=%s ro=%s ctx=%.3fs filesystem=%.3fs "
|
||||
"compaction=%.3fs patch=%.3fs anthropic_cache=%.3fs permission=%.3fs",
|
||||
subagent_name,
|
||||
read_only,
|
||||
_t_ctx,
|
||||
_t_fs,
|
||||
_t_comp,
|
||||
_t_patch,
|
||||
_t_cache,
|
||||
_t_perm,
|
||||
)
|
||||
return [
|
||||
mws["todos"],
|
||||
build_kb_context_projection_mw(),
|
||||
build_filesystem_mw(
|
||||
backend_resolver=dependencies["backend_resolver"],
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=dependencies["search_space_id"],
|
||||
user_id=dependencies.get("user_id"),
|
||||
thread_id=dependencies.get("thread_id"),
|
||||
read_only=read_only,
|
||||
),
|
||||
build_compaction_mw(llm),
|
||||
build_patch_tool_calls_mw(),
|
||||
kb_ctx_mw,
|
||||
fs_mw,
|
||||
compaction_mw,
|
||||
patch_mw,
|
||||
*([permission_mw] if permission_mw is not None else []),
|
||||
*resilience_mws,
|
||||
build_anthropic_cache_mw(),
|
||||
cache_mw,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,18 +2,19 @@
|
|||
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from fake_useragent import UserAgent
|
||||
from langchain_core.tools import tool
|
||||
from requests import Session
|
||||
from scrapling.fetchers import AsyncFetcher
|
||||
from youtube_transcript_api import YouTubeTranscriptApi
|
||||
|
||||
from app.connectors.webcrawler_connector import WebCrawlerConnector
|
||||
from app.tasks.document_processors.youtube_processor import get_youtube_video_id
|
||||
from app.utils.proxy_config import get_requests_proxies
|
||||
from app.utils.proxy import get_proxy_url, get_requests_proxies
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -79,15 +80,20 @@ async def _scrape_youtube_video(
|
|||
oembed_url = "https://www.youtube.com/oembed"
|
||||
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as http_session,
|
||||
http_session.get(
|
||||
oembed_url,
|
||||
params=params,
|
||||
proxy=residential_proxies["http"] if residential_proxies else None,
|
||||
) as response,
|
||||
):
|
||||
video_data = await response.json()
|
||||
oembed_fetch_start = time.perf_counter()
|
||||
oembed_page = await AsyncFetcher.get(
|
||||
oembed_url,
|
||||
params=params,
|
||||
proxy=get_proxy_url(),
|
||||
stealthy_headers=True,
|
||||
)
|
||||
logger.info(
|
||||
"[scrape_webpage][perf] source=oembed video=%s status=%s fetch_ms=%.1f",
|
||||
video_id,
|
||||
getattr(oembed_page, "status", None),
|
||||
(time.perf_counter() - oembed_fetch_start) * 1000,
|
||||
)
|
||||
video_data = oembed_page.json()
|
||||
except Exception:
|
||||
video_data = {}
|
||||
|
||||
|
|
@ -96,6 +102,7 @@ async def _scrape_youtube_video(
|
|||
|
||||
# --- Transcript via YouTubeTranscriptApi ---
|
||||
try:
|
||||
transcript_fetch_start = time.perf_counter()
|
||||
ua = UserAgent()
|
||||
http_client = Session()
|
||||
http_client.headers.update({"User-Agent": ua.random})
|
||||
|
|
@ -109,6 +116,11 @@ async def _scrape_youtube_video(
|
|||
transcript = next(iter(transcript_list))
|
||||
captions = transcript.fetch()
|
||||
|
||||
logger.info(
|
||||
"[scrape_webpage][perf] source=transcript video=%s fetch_ms=%.1f",
|
||||
video_id,
|
||||
(time.perf_counter() - transcript_fetch_start) * 1000,
|
||||
)
|
||||
logger.info(
|
||||
f"[scrape_webpage] Fetched transcript for {video_id} "
|
||||
f"in {transcript.language} ({transcript.language_code})"
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import time as _perf_time
|
||||
from typing import Any, Protocol
|
||||
|
||||
from deepagents import SubAgent
|
||||
|
|
@ -72,6 +73,9 @@ from app.agents.chat.multi_agent_chat.subagents.shared.md_file_reader import (
|
|||
read_md_file,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
class SubagentBuilder(Protocol):
|
||||
|
|
@ -192,19 +196,25 @@ def build_subagents(
|
|||
if exclude:
|
||||
excluded.extend(exclude)
|
||||
disabled_names = frozenset(disabled_tools or ())
|
||||
_timings: list[tuple[str, float]] = []
|
||||
for name in sorted(SUBAGENT_BUILDERS_BY_NAME):
|
||||
if name in excluded:
|
||||
continue
|
||||
builder = SUBAGENT_BUILDERS_BY_NAME[name]
|
||||
_t0 = _perf_time.perf_counter()
|
||||
result = builder(
|
||||
dependencies=dependencies,
|
||||
model=model,
|
||||
middleware_stack=middleware_stack,
|
||||
mcp_tools=mcp.get(name),
|
||||
)
|
||||
_timings.append((name, _perf_time.perf_counter() - _t0))
|
||||
spec = result.spec
|
||||
_filter_disabled_tools_in_place(spec, disabled_names)
|
||||
if ask_kb_tool is not None:
|
||||
_inject_ask_kb_tool_in_place(spec, ask_kb_tool)
|
||||
specs.append(spec)
|
||||
if _timings:
|
||||
_detail = " ".join(f"{n}={dt:.3f}s" for n, dt in _timings)
|
||||
_perf_log.info("[build_subagents.detail] %s", _detail)
|
||||
return specs
|
||||
|
|
|
|||
|
|
@ -26,6 +26,16 @@ ContextHintProvider = Callable[[Mapping[str, Any], str], str | None]
|
|||
# The prefix avoids any collision with future deepagents fields.
|
||||
SURF_CONTEXT_HINT_PROVIDER_KEY = "surf_context_hint_provider"
|
||||
|
||||
# Custom key carrying a zero-arg callable that builds the full deepagents
|
||||
# ``SubAgent`` spec dict on demand. A descriptor dict carrying only
|
||||
# ``name`` / ``description`` / this key lets the checkpointed subagent
|
||||
# middleware register a subagent's catalog entry cheaply while deferring the
|
||||
# expensive spec construction (e.g. the knowledge_base filesystem middleware,
|
||||
# which builds ~13 tool schemas at ~150ms each) until the first
|
||||
# ``task(name)`` call. Most turns never invoke a subagent, so this keeps the
|
||||
# cost off the cold agent-build / time-to-first-token path.
|
||||
SURF_LAZY_SPEC_FACTORY_KEY = "surf_lazy_spec_factory"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SurfSenseSubagentSpec:
|
||||
|
|
@ -54,6 +64,7 @@ class SurfSenseSubagentSpec:
|
|||
|
||||
__all__ = [
|
||||
"SURF_CONTEXT_HINT_PROVIDER_KEY",
|
||||
"SURF_LAZY_SPEC_FACTORY_KEY",
|
||||
"ContextHintProvider",
|
||||
"SurfSenseSubagentSpec",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import re
|
||||
import time as _perf_time
|
||||
from typing import Any, cast
|
||||
|
||||
from deepagents import SubAgent
|
||||
|
|
@ -23,8 +24,10 @@ from app.agents.chat.multi_agent_chat.subagents.shared.spec import (
|
|||
ContextHintProvider,
|
||||
SurfSenseSubagentSpec,
|
||||
)
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
# ``<include snippet="NAME"/>`` directive. Matches an XML-style self-closing
|
||||
# tag whose ``snippet`` attribute names a file in ``shared/snippets/``.
|
||||
|
|
@ -110,19 +113,31 @@ def pack_subagent(
|
|||
msg = f"Subagent {name!r}: system_prompt is empty"
|
||||
raise ValueError(msg)
|
||||
|
||||
_t0 = _perf_time.perf_counter()
|
||||
system_prompt = _resolve_includes(system_prompt, subagent_name=name)
|
||||
_t_resolve = _perf_time.perf_counter() - _t0
|
||||
|
||||
flags = dependencies["flags"]
|
||||
user_allowlist = _user_allowlist_for(dependencies, name)
|
||||
subagent_rulesets: list[Ruleset] = [ruleset]
|
||||
if user_allowlist is not None:
|
||||
subagent_rulesets.append(user_allowlist)
|
||||
_t0 = _perf_time.perf_counter()
|
||||
per_subagent_perm = build_permission_mw(
|
||||
flags=flags,
|
||||
subagent_rulesets=subagent_rulesets,
|
||||
tools=tools,
|
||||
trusted_tool_saver=dependencies.get("trusted_tool_saver"),
|
||||
)
|
||||
_t_perm = _perf_time.perf_counter() - _t0
|
||||
_perf_log.info(
|
||||
"[pack_subagent] name=%s tools=%d resolve_includes=%.3fs "
|
||||
"build_permission_mw=%.3fs",
|
||||
name,
|
||||
len(tools),
|
||||
_t_resolve,
|
||||
_t_perm,
|
||||
)
|
||||
|
||||
prepended: list[Any] = []
|
||||
for slot, mw in (middleware_stack or {}).items():
|
||||
|
|
|
|||
|
|
@ -571,6 +571,41 @@ async def _warm_agent_jit_caches() -> None:
|
|||
)
|
||||
|
||||
|
||||
async def _warm_embedding_model() -> None:
    """Run one throwaway embedding so the first real embed call is warm.

    Imports ``embed_texts`` lazily and executes ``embed_texts(["warmup"])`` in
    a worker thread via ``asyncio.to_thread``, then logs the elapsed time.

    Any ``Exception`` (including a failed import) is logged as a warning with
    its traceback and swallowed, so startup continues and the first real KB
    search simply pays the cold embed cost. This coroutine applies no timeout
    of its own; bounding the wait is left to the caller.
    """
    from time import perf_counter

    log = logging.getLogger(__name__)
    started_at = perf_counter()

    def _elapsed() -> float:
        # Seconds since this warmup began.
        return perf_counter() - started_at

    try:
        from app.utils.document_converters import embed_texts

        await asyncio.to_thread(embed_texts, ["warmup"])
        log.info("[startup] Embedding model warmup completed in %.3fs", _elapsed())
    except Exception:
        log.warning(
            "[startup] Embedding model warmup failed in %.3fs (non-fatal — first "
            "KB search will pay the cold embed cost)",
            _elapsed(),
            exc_info=True,
        )
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Tune GC: lower gen-2 threshold so long-lived garbage is collected
|
||||
|
|
@ -601,6 +636,16 @@ async def lifespan(app: FastAPI):
|
|||
"first real request will pay the full compile cost."
|
||||
)
|
||||
|
||||
# Phase 2 — embedding warmup so the first lazy ``search_knowledge_base``
|
||||
# call doesn't pay the cold embed-model load. Bounded + non-fatal.
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(_warm_embedding_model()), timeout=20)
|
||||
except (TimeoutError, Exception): # pragma: no cover - defensive
|
||||
logging.getLogger(__name__).warning(
|
||||
"[startup] Embedding warmup hit timeout/error — skipping; "
|
||||
"first KB search will pay the cold embed cost."
|
||||
)
|
||||
|
||||
register_session_hooks()
|
||||
log_system_snapshot("startup_complete")
|
||||
await start_gateway_inbox_worker()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
"""Celery application configuration and setup."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
|
||||
from celery import Celery
|
||||
|
|
@ -19,6 +18,8 @@ try:
|
|||
except ImportError: # pragma: no cover - optional OTel dependency
|
||||
trace = None # type: ignore[assignment]
|
||||
|
||||
from app.config import config
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
|
@ -124,16 +125,16 @@ def init_worker(**kwargs):
|
|||
initialize_vision_llm_router()
|
||||
|
||||
|
||||
# Get Celery configuration from environment
|
||||
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")
|
||||
CELERY_TASK_DEFAULT_QUEUE = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
|
||||
# Celery configuration, sourced from the central Config singleton
|
||||
CELERY_BROKER_URL = config.CELERY_BROKER_URL
|
||||
CELERY_RESULT_BACKEND = config.CELERY_RESULT_BACKEND
|
||||
CELERY_TASK_DEFAULT_QUEUE = config.CELERY_TASK_DEFAULT_QUEUE
|
||||
|
||||
# Get schedule checker interval from environment
|
||||
# Schedule checker interval
|
||||
# Format: "<number><unit>" where unit is 'm' (minutes) or 'h' (hours)
|
||||
# Examples: "1m" (every minute), "5m" (every 5 minutes), "1h" (every hour)
|
||||
SCHEDULE_CHECKER_INTERVAL = os.getenv("SCHEDULE_CHECKER_INTERVAL", "2m")
|
||||
STRIPE_RECONCILIATION_INTERVAL = os.getenv("STRIPE_RECONCILIATION_INTERVAL", "10m")
|
||||
SCHEDULE_CHECKER_INTERVAL = config.SCHEDULE_CHECKER_INTERVAL
|
||||
STRIPE_RECONCILIATION_INTERVAL = config.STRIPE_RECONCILIATION_INTERVAL
|
||||
|
||||
|
||||
def parse_schedule_interval(interval: str) -> dict:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import copy
|
||||
import os
|
||||
import shutil
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
|
@ -17,6 +19,37 @@ os.environ.setdefault("OR_APP_NAME", "SurfSense")
|
|||
os.environ.setdefault("OR_SITE_URL", "https://surfsense.com")
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _read_global_config_yaml(path_str: str) -> dict:
|
||||
"""Read and parse ``global_llm_config.yaml`` once per resolved path.
|
||||
|
||||
Cached so the seven ``load_*`` helpers (and their re-invocations during
|
||||
startup) don't re-open and re-parse the same file repeatedly. Keyed on the
|
||||
resolved path string so tests that monkeypatch ``BASE_DIR`` to a unique
|
||||
``tmp_path`` still get a fresh parse. Callers MUST treat the returned dict
|
||||
as read-only and deep-copy any section they intend to mutate.
|
||||
"""
|
||||
f = Path(path_str)
|
||||
if not f.exists():
|
||||
return {}
|
||||
try:
|
||||
with open(f, encoding="utf-8") as fh:
|
||||
return yaml.safe_load(fh) or {}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read global_llm_config.yaml: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def _global_config_data() -> dict:
    """Return the parsed global LLM config YAML for the current ``BASE_DIR``.

    ``BASE_DIR`` is looked up on every call rather than captured at import
    time, so reassigning the module attribute (for example via
    ``monkeypatch.setattr(config, "BASE_DIR", tmp_path)``) takes effect.
    The parsing itself is delegated to ``_read_global_config_yaml``.
    """
    yaml_location = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
    path_key = str(yaml_location)
    return _read_global_config_yaml(path_key)
|
||||
|
||||
|
||||
def is_ffmpeg_installed():
|
||||
"""
|
||||
Check if ffmpeg is installed on the current system.
|
||||
|
|
@ -35,17 +68,15 @@ def load_global_llm_configs():
|
|||
Returns:
|
||||
list: List of global LLM config dictionaries, or empty list if file doesn't exist
|
||||
"""
|
||||
# Try main config file first
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
# No global configs available
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
configs = data.get("global_llm_configs", [])
|
||||
# Deep-copy so the in-place mutations below (setdefault, scoring
|
||||
# stamps) never leak into the cached YAML structure.
|
||||
configs = copy.deepcopy(data.get("global_llm_configs", []))
|
||||
|
||||
# Lazy import keeps the `app.config` -> `app.services` edge one-way
|
||||
# and matches the `provider_api_base` pattern used elsewhere.
|
||||
|
|
@ -145,18 +176,14 @@ def load_router_settings():
|
|||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
# Try main config file first
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("router_settings", {})
|
||||
# Merge with defaults
|
||||
return {**default_settings, **settings}
|
||||
settings = data.get("router_settings", {})
|
||||
# Merge with defaults
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load router settings: {e}")
|
||||
return default_settings
|
||||
|
|
@ -169,38 +196,32 @@ def load_global_image_gen_configs():
|
|||
Returns:
|
||||
list: List of global image generation config dictionaries, or empty list
|
||||
"""
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
configs = data.get("global_image_generation_configs", []) or []
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
configs = copy.deepcopy(data.get("global_image_generation_configs", []) or [])
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global image generation configs: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def load_global_vision_llm_configs():
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
configs = data.get("global_vision_llm_configs", []) or []
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
configs = copy.deepcopy(data.get("global_vision_llm_configs", []) or [])
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg.setdefault("billing_tier", "free")
|
||||
return configs
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load global vision LLM configs: {e}")
|
||||
return []
|
||||
|
|
@ -214,16 +235,13 @@ def load_vision_llm_router_settings():
|
|||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("vision_llm_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
settings = data.get("vision_llm_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load vision LLM router settings: {e}")
|
||||
return default_settings
|
||||
|
|
@ -243,16 +261,13 @@ def load_image_gen_router_settings():
|
|||
"cooldown_time": 60,
|
||||
}
|
||||
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return default_settings
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("image_generation_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
settings = data.get("image_generation_router_settings", {})
|
||||
return {**default_settings, **settings}
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load image generation router settings: {e}")
|
||||
return default_settings
|
||||
|
|
@ -268,49 +283,44 @@ def load_openrouter_integration_settings() -> dict | None:
|
|||
Returns:
|
||||
dict with settings if present and enabled, None otherwise
|
||||
"""
|
||||
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
if not global_config_file.exists():
|
||||
data = _global_config_data()
|
||||
if not data:
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(global_config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
settings = data.get("openrouter_integration")
|
||||
if not settings or not settings.get("enabled"):
|
||||
return None
|
||||
# Deep-copy so the setdefault back-compat seeding below never mutates
|
||||
# the cached YAML structure.
|
||||
settings = copy.deepcopy(data.get("openrouter_integration"))
|
||||
if not settings or not settings.get("enabled"):
|
||||
return None
|
||||
|
||||
if "billing_tier" in settings:
|
||||
print(
|
||||
"Warning: openrouter_integration.billing_tier is deprecated; "
|
||||
"tier is now derived per model from OpenRouter data "
|
||||
"(':free' suffix or zero pricing). Remove this key."
|
||||
)
|
||||
if "billing_tier" in settings:
|
||||
print(
|
||||
"Warning: openrouter_integration.billing_tier is deprecated; "
|
||||
"tier is now derived per model from OpenRouter data "
|
||||
"(':free' suffix or zero pricing). Remove this key."
|
||||
)
|
||||
|
||||
if "anonymous_enabled" in settings:
|
||||
print(
|
||||
"Warning: openrouter_integration.anonymous_enabled is "
|
||||
"deprecated; use anonymous_enabled_paid and/or "
|
||||
"anonymous_enabled_free instead. Both new flags have been "
|
||||
"seeded from the legacy value for back-compat."
|
||||
)
|
||||
settings.setdefault(
|
||||
"anonymous_enabled_paid", settings["anonymous_enabled"]
|
||||
)
|
||||
settings.setdefault(
|
||||
"anonymous_enabled_free", settings["anonymous_enabled"]
|
||||
)
|
||||
if "anonymous_enabled" in settings:
|
||||
print(
|
||||
"Warning: openrouter_integration.anonymous_enabled is "
|
||||
"deprecated; use anonymous_enabled_paid and/or "
|
||||
"anonymous_enabled_free instead. Both new flags have been "
|
||||
"seeded from the legacy value for back-compat."
|
||||
)
|
||||
settings.setdefault("anonymous_enabled_paid", settings["anonymous_enabled"])
|
||||
settings.setdefault("anonymous_enabled_free", settings["anonymous_enabled"])
|
||||
|
||||
# Image generation + vision LLM emission are opt-in (issue L).
|
||||
# OpenRouter's catalogue contains hundreds of image / vision
|
||||
# capable models; auto-injecting all of them into every
|
||||
# deployment would explode the model selector and surprise
|
||||
# operators upgrading from prior versions. Default to False so
|
||||
# admins must explicitly turn them on.
|
||||
settings.setdefault("image_generation_enabled", False)
|
||||
settings.setdefault("vision_enabled", False)
|
||||
# Image generation + vision LLM emission are opt-in (issue L).
|
||||
# OpenRouter's catalogue contains hundreds of image / vision
|
||||
# capable models; auto-injecting all of them into every
|
||||
# deployment would explode the model selector and surprise
|
||||
# operators upgrading from prior versions. Default to False so
|
||||
# admins must explicitly turn them on.
|
||||
settings.setdefault("image_generation_enabled", False)
|
||||
settings.setdefault("vision_enabled", False)
|
||||
|
||||
return settings
|
||||
return settings
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
||||
return None
|
||||
|
|
@ -415,7 +425,9 @@ def initialize_llm_router():
|
|||
static YAML configs and dynamic OpenRouter models.
|
||||
"""
|
||||
all_configs = config.GLOBAL_LLM_CONFIGS
|
||||
router_settings = load_router_settings()
|
||||
# Reuse the router settings already parsed at Config construction instead
|
||||
# of re-reading the YAML here.
|
||||
router_settings = config.ROUTER_SETTINGS
|
||||
|
||||
if not all_configs:
|
||||
print("Info: No global LLM configs found, Auto mode will not be available")
|
||||
|
|
@ -439,7 +451,10 @@ def initialize_image_gen_router():
|
|||
This should be called during application startup.
|
||||
"""
|
||||
image_gen_configs = load_global_image_gen_configs()
|
||||
router_settings = load_image_gen_router_settings()
|
||||
# Reuse the router settings already parsed at Config construction. The
|
||||
# *configs* list is intentionally re-read from YAML (it must exclude the
|
||||
# OpenRouter-injected dynamic models held in config.GLOBAL_IMAGE_GEN_CONFIGS).
|
||||
router_settings = config.IMAGE_GEN_ROUTER_SETTINGS
|
||||
|
||||
if not image_gen_configs:
|
||||
print(
|
||||
|
|
@ -462,7 +477,10 @@ def initialize_image_gen_router():
|
|||
|
||||
def initialize_vision_llm_router():
|
||||
vision_configs = load_global_vision_llm_configs()
|
||||
router_settings = load_vision_llm_router_settings()
|
||||
# Reuse the router settings already parsed at Config construction. The
|
||||
# *configs* list is intentionally re-read from YAML (it must exclude the
|
||||
# OpenRouter-injected dynamic models held in config.GLOBAL_VISION_LLM_CONFIGS).
|
||||
router_settings = config.VISION_LLM_ROUTER_SETTINGS
|
||||
|
||||
if not vision_configs:
|
||||
print(
|
||||
|
|
@ -524,16 +542,51 @@ class Config:
|
|||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
# Celery / Redis
|
||||
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
CELERY_RESULT_BACKEND = os.getenv(
|
||||
"CELERY_RESULT_BACKEND", "redis://localhost:6379/0"
|
||||
)
|
||||
# Redis (single endpoint for Celery broker, result backend, and app cache).
|
||||
# Legacy CELERY_BROKER_URL / CELERY_RESULT_BACKEND / REDIS_APP_URL still
|
||||
# override individually when you need to split Redis across instances.
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", REDIS_URL)
|
||||
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", REDIS_URL)
|
||||
CELERY_TASK_DEFAULT_QUEUE = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
|
||||
REDIS_APP_URL = os.getenv("REDIS_APP_URL", CELERY_BROKER_URL)
|
||||
CONNECTOR_INDEXING_LOCK_TTL_SECONDS = int(
|
||||
os.getenv("CONNECTOR_INDEXING_LOCK_TTL_SECONDS", str(8 * 60 * 60))
|
||||
)
|
||||
|
||||
# Celery beat scheduling intervals (format: "<number><unit>", e.g. "2m", "1h")
|
||||
SCHEDULE_CHECKER_INTERVAL = os.getenv("SCHEDULE_CHECKER_INTERVAL", "2m")
|
||||
STRIPE_RECONCILIATION_INTERVAL = os.getenv("STRIPE_RECONCILIATION_INTERVAL", "10m")
|
||||
|
||||
# File storage (local filesystem by default; Azure Blob optional)
|
||||
FILE_STORAGE_BACKEND = os.getenv("FILE_STORAGE_BACKEND", "local").strip().lower()
|
||||
AZURE_STORAGE_CONNECTION_STRING = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
|
||||
AZURE_STORAGE_CONTAINER = os.getenv("AZURE_STORAGE_CONTAINER")
|
||||
FILE_STORAGE_LOCAL_PATH = os.getenv(
|
||||
"FILE_STORAGE_LOCAL_PATH", str(BASE_DIR / ".local_object_store")
|
||||
)
|
||||
|
||||
# Daytona sandbox (code execution / filesystem sandbox)
|
||||
DAYTONA_SANDBOX_ENABLED = (
|
||||
os.getenv("DAYTONA_SANDBOX_ENABLED", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
DAYTONA_API_KEY = os.getenv("DAYTONA_API_KEY", "")
|
||||
DAYTONA_API_URL = os.getenv("DAYTONA_API_URL", "https://app.daytona.io/api")
|
||||
DAYTONA_TARGET = os.getenv("DAYTONA_TARGET", "us")
|
||||
DAYTONA_SNAPSHOT_ID = os.getenv("DAYTONA_SNAPSHOT_ID") or None
|
||||
SANDBOX_FILES_DIR = os.getenv("SANDBOX_FILES_DIR", "sandbox_files")
|
||||
|
||||
# Agent cache (in-process LRU+TTL cache for built agents)
|
||||
AGENT_CACHE_MAXSIZE = int(os.getenv("SURFSENSE_AGENT_CACHE_MAXSIZE", "256"))
|
||||
AGENT_CACHE_TTL_SECONDS = float(
|
||||
os.getenv("SURFSENSE_AGENT_CACHE_TTL_SECONDS", "1800")
|
||||
)
|
||||
|
||||
# Connector discovery cache TTL
|
||||
CONNECTOR_DISCOVERY_TTL_SECONDS = float(
|
||||
os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30")
|
||||
)
|
||||
|
||||
# Platform web search (SearXNG)
|
||||
SEARXNG_DEFAULT_HOST = os.getenv("SEARXNG_DEFAULT_HOST")
|
||||
|
||||
|
|
@ -542,6 +595,9 @@ class Config:
|
|||
BACKEND_URL = os.getenv("BACKEND_URL")
|
||||
|
||||
# Messaging gateway (Telegram v1)
|
||||
# Global master switch: when FALSE, no gateway supervisors/workers start and all
|
||||
# gateway HTTP routes return 404, regardless of the per-channel flags below.
|
||||
GATEWAY_ENABLED = os.getenv("GATEWAY_ENABLED", "TRUE").upper() == "TRUE"
|
||||
TELEGRAM_SHARED_BOT_TOKEN = os.getenv("TELEGRAM_SHARED_BOT_TOKEN")
|
||||
TELEGRAM_SHARED_BOT_USERNAME = os.getenv("TELEGRAM_SHARED_BOT_USERNAME")
|
||||
TELEGRAM_WEBHOOK_SECRET = os.getenv("TELEGRAM_WEBHOOK_SECRET")
|
||||
|
|
@ -562,7 +618,9 @@ class Config:
|
|||
WHATSAPP_GRAPH_API_VERSION = os.getenv("WHATSAPP_GRAPH_API_VERSION", "v25.0")
|
||||
WHATSAPP_WEBHOOK_VERIFY_TOKEN = os.getenv("WHATSAPP_WEBHOOK_VERIFY_TOKEN")
|
||||
WHATSAPP_WEBHOOK_APP_SECRET = os.getenv("WHATSAPP_WEBHOOK_APP_SECRET")
|
||||
WHATSAPP_BRIDGE_URL = os.getenv("WHATSAPP_BRIDGE_URL", "http://whatsapp-bridge:9929")
|
||||
WHATSAPP_BRIDGE_URL = os.getenv(
|
||||
"WHATSAPP_BRIDGE_URL", "http://whatsapp-bridge:9929"
|
||||
)
|
||||
GATEWAY_WHATSAPP_INTAKE_MODE = os.getenv(
|
||||
"GATEWAY_WHATSAPP_INTAKE_MODE", "disabled"
|
||||
).lower()
|
||||
|
|
@ -572,7 +630,9 @@ class Config:
|
|||
)
|
||||
GATEWAY_SLACK_CLIENT_ID = os.getenv("SLACK_CLIENT_ID")
|
||||
GATEWAY_SLACK_CLIENT_SECRET = os.getenv("SLACK_CLIENT_SECRET")
|
||||
GATEWAY_SLACK_ENABLED = os.getenv("GATEWAY_SLACK_ENABLED", "FALSE").upper() == "TRUE"
|
||||
GATEWAY_SLACK_ENABLED = (
|
||||
os.getenv("GATEWAY_SLACK_ENABLED", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
GATEWAY_SLACK_SIGNING_SECRET = os.getenv("GATEWAY_SLACK_SIGNING_SECRET")
|
||||
GATEWAY_SLACK_REDIRECT_URI = os.getenv("GATEWAY_SLACK_REDIRECT_URI")
|
||||
GATEWAY_DISCORD_ENABLED = (
|
||||
|
|
@ -856,8 +916,13 @@ class Config:
|
|||
AZURE_DI_ENDPOINT = os.getenv("AZURE_DI_ENDPOINT")
|
||||
AZURE_DI_KEY = os.getenv("AZURE_DI_KEY")
|
||||
|
||||
# Proxy provider selection. Maps to a ProxyProvider implementation registered
|
||||
# in app/utils/proxy/registry.py. Add new vendors there and switch via this var.
|
||||
PROXY_PROVIDER = os.getenv("PROXY_PROVIDER", "anonymous_proxies")
|
||||
|
||||
# Residential Proxy Configuration (anonymous-proxies.net)
|
||||
# Used for web crawling and YouTube transcript fetching to avoid IP bans.
|
||||
# Consumed by the "anonymous_proxies" proxy provider.
|
||||
RESIDENTIAL_PROXY_USERNAME = os.getenv("RESIDENTIAL_PROXY_USERNAME")
|
||||
RESIDENTIAL_PROXY_PASSWORD = os.getenv("RESIDENTIAL_PROXY_PASSWORD")
|
||||
RESIDENTIAL_PROXY_HOSTNAME = os.getenv("RESIDENTIAL_PROXY_HOSTNAME")
|
||||
|
|
|
|||
|
|
@ -1,31 +1,34 @@
|
|||
"""
|
||||
WebCrawler Connector Module
|
||||
|
||||
A module for crawling web pages and extracting content using Firecrawl,
|
||||
plain HTTP+Trafilatura, or Playwright. Provides a unified interface for
|
||||
web scraping.
|
||||
A module for crawling web pages and extracting content using Firecrawl or
|
||||
Scrapling's tiered fetchers, with Trafilatura for HTML -> markdown extraction.
|
||||
Provides a unified interface for web scraping.
|
||||
|
||||
Fallback order:
|
||||
1. Firecrawl (if API key is configured)
|
||||
2. HTTP + Trafilatura (lightweight, works on any event loop)
|
||||
3. Playwright / Chromium (runs in a thread to avoid event-loop limitations)
|
||||
1. Firecrawl (if API key is configured)
|
||||
2. Scrapling AsyncFetcher (fast static HTTP, no browser subprocess)
|
||||
3. Scrapling DynamicFetcher (full browser, run in a thread)
|
||||
4. Scrapling StealthyFetcher (anti-bot stealth browser, run in a thread)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import trafilatura
|
||||
import validators
|
||||
from fake_useragent import UserAgent
|
||||
from firecrawl import AsyncFirecrawlApp
|
||||
from playwright.sync_api import sync_playwright
|
||||
from scrapling.fetchers import AsyncFetcher, DynamicFetcher, StealthyFetcher
|
||||
|
||||
from app.utils.proxy_config import get_playwright_proxy, get_residential_proxy_url
|
||||
from app.utils.proxy import get_proxy_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Prefix for performance/timing log lines so they are easy to grep/filter.
|
||||
_PERF = "[webcrawler][perf]"
|
||||
|
||||
|
||||
class WebCrawlerConnector:
|
||||
"""Class for crawling web pages and extracting content."""
|
||||
|
|
@ -36,8 +39,8 @@ class WebCrawlerConnector:
|
|||
|
||||
Args:
|
||||
firecrawl_api_key: Firecrawl API key (optional). If provided, Firecrawl will be tried first
|
||||
and Chromium will be used as fallback if Firecrawl fails. If not provided,
|
||||
Chromium will be used directly.
|
||||
and Scrapling will be used as fallback if Firecrawl fails. If not provided,
|
||||
Scrapling fetchers are used directly.
|
||||
"""
|
||||
self.firecrawl_api_key = firecrawl_api_key
|
||||
self.use_firecrawl = bool(firecrawl_api_key)
|
||||
|
|
@ -60,8 +63,9 @@ class WebCrawlerConnector:
|
|||
|
||||
Fallback order:
|
||||
1. Firecrawl (if API key configured)
|
||||
2. Plain HTTP + Trafilatura (lightweight, no subprocess)
|
||||
3. Playwright / Chromium (needs subprocess-capable event loop)
|
||||
2. Scrapling AsyncFetcher (fast static HTTP, no subprocess)
|
||||
3. Scrapling DynamicFetcher (full browser, run in a thread)
|
||||
4. Scrapling StealthyFetcher (anti-bot stealth browser, run in a thread)
|
||||
|
||||
Args:
|
||||
url: URL to crawl
|
||||
|
|
@ -74,8 +78,8 @@ class WebCrawlerConnector:
|
|||
- metadata: Page metadata (title, description, etc.)
|
||||
- source: Original URL
|
||||
- crawler_type: Type of crawler used
|
||||
# Validate URL
|
||||
"""
|
||||
total_start = time.perf_counter()
|
||||
try:
|
||||
if not validators.url(url):
|
||||
return None, f"Invalid URL: {url}"
|
||||
|
|
@ -84,48 +88,138 @@ class WebCrawlerConnector:
|
|||
|
||||
# --- 1. Firecrawl (premium, if configured) ---
|
||||
if self.use_firecrawl:
|
||||
tier_start = time.perf_counter()
|
||||
try:
|
||||
logger.info(f"[webcrawler] Using Firecrawl for: {url}")
|
||||
return await self._crawl_with_firecrawl(url, formats), None
|
||||
result = await self._crawl_with_firecrawl(url, formats)
|
||||
self._log_tier_outcome("firecrawl", url, tier_start, "success")
|
||||
self._log_total(url, "firecrawl", total_start)
|
||||
return result, None
|
||||
except Exception as exc:
|
||||
errors.append(f"Firecrawl: {exc!s}")
|
||||
logger.warning(f"[webcrawler] Firecrawl failed for {url}: {exc!s}")
|
||||
self._log_tier_outcome("firecrawl", url, tier_start, "error", exc)
|
||||
|
||||
# --- 2. HTTP + Trafilatura (no subprocess required) ---
|
||||
# --- 2. Scrapling AsyncFetcher (fast static HTTP) ---
|
||||
tier_start = time.perf_counter()
|
||||
try:
|
||||
logger.info(f"[webcrawler] Using HTTP+Trafilatura for: {url}")
|
||||
result = await self._crawl_with_http(url)
|
||||
logger.info(f"[webcrawler] Using Scrapling AsyncFetcher for: {url}")
|
||||
result = await self._crawl_with_async_fetcher(url)
|
||||
if result:
|
||||
self._log_tier_outcome(
|
||||
"scrapling-static", url, tier_start, "success"
|
||||
)
|
||||
self._log_total(url, "scrapling-static", total_start)
|
||||
return result, None
|
||||
errors.append("HTTP+Trafilatura: empty extraction")
|
||||
errors.append("Scrapling static: empty extraction")
|
||||
self._log_tier_outcome("scrapling-static", url, tier_start, "empty")
|
||||
except Exception as exc:
|
||||
errors.append(f"HTTP+Trafilatura: {exc!s}")
|
||||
logger.warning(
|
||||
f"[webcrawler] HTTP+Trafilatura failed for {url}: {exc!s}"
|
||||
errors.append(f"Scrapling static: {exc!s}")
|
||||
self._log_tier_outcome(
|
||||
"scrapling-static", url, tier_start, "error", exc
|
||||
)
|
||||
|
||||
# --- 3. Playwright / Chromium (full browser, last resort) ---
|
||||
# --- 3. Scrapling DynamicFetcher (full browser) ---
|
||||
tier_start = time.perf_counter()
|
||||
try:
|
||||
logger.info(f"[webcrawler] Using Chromium+Trafilatura for: {url}")
|
||||
return await self._crawl_with_chromium(url), None
|
||||
logger.info(f"[webcrawler] Using Scrapling DynamicFetcher for: {url}")
|
||||
result = await self._crawl_with_dynamic(url)
|
||||
if result:
|
||||
self._log_tier_outcome(
|
||||
"scrapling-dynamic", url, tier_start, "success"
|
||||
)
|
||||
self._log_total(url, "scrapling-dynamic", total_start)
|
||||
return result, None
|
||||
errors.append("Scrapling dynamic: empty extraction")
|
||||
self._log_tier_outcome("scrapling-dynamic", url, tier_start, "empty")
|
||||
except NotImplementedError:
|
||||
errors.append(
|
||||
"Chromium: event loop does not support subprocesses "
|
||||
"Scrapling dynamic: event loop does not support subprocesses "
|
||||
"(common on Windows with uvicorn --reload)"
|
||||
)
|
||||
logger.warning(
|
||||
f"[webcrawler] Chromium unavailable for {url}: "
|
||||
"current event loop does not support subprocesses"
|
||||
self._log_tier_outcome(
|
||||
"scrapling-dynamic", url, tier_start, "unavailable"
|
||||
)
|
||||
except Exception as exc:
|
||||
errors.append(f"Chromium: {exc!s}")
|
||||
logger.warning(f"[webcrawler] Chromium failed for {url}: {exc!s}")
|
||||
errors.append(f"Scrapling dynamic: {exc!s}")
|
||||
self._log_tier_outcome(
|
||||
"scrapling-dynamic", url, tier_start, "error", exc
|
||||
)
|
||||
|
||||
# --- 4. Scrapling StealthyFetcher (anti-bot, last resort) ---
|
||||
tier_start = time.perf_counter()
|
||||
try:
|
||||
logger.info(f"[webcrawler] Using Scrapling StealthyFetcher for: {url}")
|
||||
result = await self._crawl_with_stealthy(url)
|
||||
if result:
|
||||
self._log_tier_outcome(
|
||||
"scrapling-stealthy", url, tier_start, "success"
|
||||
)
|
||||
self._log_total(url, "scrapling-stealthy", total_start)
|
||||
return result, None
|
||||
errors.append("Scrapling stealthy: empty extraction")
|
||||
self._log_tier_outcome("scrapling-stealthy", url, tier_start, "empty")
|
||||
except NotImplementedError:
|
||||
errors.append(
|
||||
"Scrapling stealthy: event loop does not support subprocesses "
|
||||
"(common on Windows with uvicorn --reload)"
|
||||
)
|
||||
self._log_tier_outcome(
|
||||
"scrapling-stealthy", url, tier_start, "unavailable"
|
||||
)
|
||||
except Exception as exc:
|
||||
errors.append(f"Scrapling stealthy: {exc!s}")
|
||||
self._log_tier_outcome(
|
||||
"scrapling-stealthy", url, tier_start, "error", exc
|
||||
)
|
||||
|
||||
self._log_total(url, "none", total_start)
|
||||
return None, f"All crawl methods failed for {url}. {'; '.join(errors)}"
|
||||
|
||||
except Exception as e:
|
||||
self._log_total(url, "error", total_start)
|
||||
return None, f"Error crawling URL {url}: {e!s}"
|
||||
|
||||
@staticmethod
|
||||
def _log_tier_outcome(
|
||||
tier: str,
|
||||
url: str,
|
||||
tier_start: float,
|
||||
outcome: str,
|
||||
exc: Exception | None = None,
|
||||
) -> None:
|
||||
"""Log how long a single tier took and how it ended."""
|
||||
elapsed_ms = (time.perf_counter() - tier_start) * 1000
|
||||
if outcome == "error":
|
||||
logger.warning(
|
||||
"%s tier=%s url=%s elapsed_ms=%.1f outcome=error error=%s",
|
||||
_PERF,
|
||||
tier,
|
||||
url,
|
||||
elapsed_ms,
|
||||
exc,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"%s tier=%s url=%s elapsed_ms=%.1f outcome=%s",
|
||||
_PERF,
|
||||
tier,
|
||||
url,
|
||||
elapsed_ms,
|
||||
outcome,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _log_total(url: str, selected: str, total_start: float) -> None:
|
||||
"""Log the total time spent across all attempted tiers."""
|
||||
total_ms = (time.perf_counter() - total_start) * 1000
|
||||
logger.info(
|
||||
"%s url=%s selected=%s total_ms=%.1f",
|
||||
_PERF,
|
||||
url,
|
||||
selected,
|
||||
total_ms,
|
||||
)
|
||||
|
||||
async def _crawl_with_firecrawl(
|
||||
self, url: str, formats: list[str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
|
|
@ -177,52 +271,172 @@ class WebCrawlerConnector:
|
|||
"crawler_type": "firecrawl",
|
||||
}
|
||||
|
||||
async def _crawl_with_http(self, url: str) -> dict[str, Any] | None:
|
||||
async def _crawl_with_async_fetcher(self, url: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Crawl URL using a plain HTTP request + Trafilatura content extraction.
|
||||
Crawl URL using Scrapling's AsyncFetcher (static HTTP) + Trafilatura.
|
||||
|
||||
This method avoids launching a browser subprocess, making it safe to
|
||||
call from any asyncio event loop (including Windows SelectorEventLoop
|
||||
which does not support ``create_subprocess_exec``).
|
||||
|
||||
Returns ``None`` when Trafilatura cannot extract meaningful content
|
||||
(e.g. JS-rendered SPAs) so the caller can fall through to Chromium.
|
||||
AsyncFetcher is httpx/curl_cffi-based and does not launch a browser
|
||||
subprocess, making it safe to call from any asyncio event loop. Returns
|
||||
``None`` when Trafilatura cannot extract meaningful content (e.g.
|
||||
JS-rendered SPAs) so the caller can fall through to the browser tiers.
|
||||
"""
|
||||
ua = UserAgent()
|
||||
user_agent = ua.random
|
||||
proxy_url = get_residential_proxy_url()
|
||||
fetch_start = time.perf_counter()
|
||||
page = await AsyncFetcher.get(
|
||||
url,
|
||||
stealthy_headers=True,
|
||||
proxy=get_proxy_url(),
|
||||
timeout=20,
|
||||
)
|
||||
fetch_ms = (time.perf_counter() - fetch_start) * 1000
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=20.0,
|
||||
follow_redirects=True,
|
||||
proxy=proxy_url,
|
||||
headers={
|
||||
"User-Agent": user_agent,
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
},
|
||||
) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
raw_html = response.text
|
||||
|
||||
if not raw_html or len(raw_html.strip()) == 0:
|
||||
status = getattr(page, "status", None)
|
||||
if status is not None and status >= 400:
|
||||
logger.info(
|
||||
"%s tier=scrapling-static url=%s fetch_ms=%.1f status=%s outcome=http_error",
|
||||
_PERF,
|
||||
url,
|
||||
fetch_ms,
|
||||
status,
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_content = trafilatura.extract(
|
||||
raw_html,
|
||||
output_format="markdown",
|
||||
include_comments=False,
|
||||
include_tables=True,
|
||||
include_images=True,
|
||||
include_links=True,
|
||||
return self._build_result(
|
||||
page.html_content,
|
||||
url,
|
||||
"scrapling-static",
|
||||
allow_raw_fallback=False,
|
||||
fetch_ms=fetch_ms,
|
||||
status=status,
|
||||
)
|
||||
|
||||
if not extracted_content or len(extracted_content.strip()) == 0:
|
||||
async def _crawl_with_dynamic(self, url: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Crawl URL using Scrapling's DynamicFetcher (full browser) + Trafilatura.
|
||||
|
||||
Runs the sync fetch in a worker thread so it works on any event loop,
|
||||
including Windows ``SelectorEventLoop`` which cannot spawn subprocesses.
|
||||
"""
|
||||
return await asyncio.to_thread(self._crawl_with_dynamic_sync, url)
|
||||
|
||||
def _crawl_with_dynamic_sync(self, url: str) -> dict[str, Any] | None:
|
||||
"""Synchronous DynamicFetcher crawl executed in a worker thread."""
|
||||
fetch_start = time.perf_counter()
|
||||
page = DynamicFetcher.fetch(
|
||||
url,
|
||||
headless=True,
|
||||
network_idle=True,
|
||||
timeout=30000,
|
||||
proxy=get_proxy_url(),
|
||||
)
|
||||
fetch_ms = (time.perf_counter() - fetch_start) * 1000
|
||||
return self._build_result(
|
||||
page.html_content,
|
||||
url,
|
||||
"scrapling-dynamic",
|
||||
allow_raw_fallback=False,
|
||||
fetch_ms=fetch_ms,
|
||||
status=getattr(page, "status", None),
|
||||
)
|
||||
|
||||
async def _crawl_with_stealthy(self, url: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Crawl URL using Scrapling's StealthyFetcher (Camoufox) + Trafilatura.
|
||||
|
||||
Last-resort tier with anti-bot features. Runs the sync fetch in a worker
|
||||
thread for the same event-loop-safety reasons as DynamicFetcher. Falls
|
||||
back to the raw HTML when Trafilatura extraction is empty.
|
||||
"""
|
||||
return await asyncio.to_thread(self._crawl_with_stealthy_sync, url)
|
||||
|
||||
def _crawl_with_stealthy_sync(self, url: str) -> dict[str, Any] | None:
|
||||
"""Synchronous StealthyFetcher crawl executed in a worker thread."""
|
||||
fetch_start = time.perf_counter()
|
||||
page = StealthyFetcher.fetch(
|
||||
url,
|
||||
headless=True,
|
||||
network_idle=True,
|
||||
block_ads=True,
|
||||
proxy=get_proxy_url(),
|
||||
)
|
||||
fetch_ms = (time.perf_counter() - fetch_start) * 1000
|
||||
return self._build_result(
|
||||
page.html_content,
|
||||
url,
|
||||
"scrapling-stealthy",
|
||||
allow_raw_fallback=True,
|
||||
fetch_ms=fetch_ms,
|
||||
status=getattr(page, "status", None),
|
||||
)
|
||||
|
||||
def _build_result(
|
||||
self,
|
||||
raw_html: str | None,
|
||||
url: str,
|
||||
crawler_type: str,
|
||||
*,
|
||||
allow_raw_fallback: bool,
|
||||
fetch_ms: float | None = None,
|
||||
status: int | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Extract markdown + metadata from raw HTML using Trafilatura.
|
||||
|
||||
Args:
|
||||
raw_html: Raw HTML source from a fetcher.
|
||||
url: Original URL (used as the metadata source/title fallback).
|
||||
crawler_type: Identifier of the tier that produced the HTML.
|
||||
allow_raw_fallback: When True, return the raw HTML as content if
|
||||
Trafilatura cannot extract anything (used by the last-resort
|
||||
stealthy tier). When False, return ``None`` so the caller can
|
||||
fall through to the next tier.
|
||||
fetch_ms: Time spent fetching the page (for perf logging).
|
||||
status: HTTP status code returned by the fetcher (for perf logging).
|
||||
|
||||
Returns:
|
||||
Result dict (content/metadata/crawler_type) or ``None``.
|
||||
"""
|
||||
html_len = len(raw_html) if raw_html else 0
|
||||
|
||||
if not raw_html or len(raw_html.strip()) == 0:
|
||||
self._log_build(
|
||||
crawler_type, url, fetch_ms, 0.0, status, html_len, 0, "empty_html"
|
||||
)
|
||||
return None
|
||||
|
||||
trafilatura_metadata = trafilatura.extract_metadata(raw_html)
|
||||
extract_start = time.perf_counter()
|
||||
extracted_content: str | None = None
|
||||
trafilatura_metadata = None
|
||||
|
||||
try:
|
||||
extracted_content = trafilatura.extract(
|
||||
raw_html,
|
||||
output_format="markdown",
|
||||
include_comments=False,
|
||||
include_tables=True,
|
||||
include_images=True,
|
||||
include_links=True,
|
||||
)
|
||||
trafilatura_metadata = trafilatura.extract_metadata(raw_html)
|
||||
|
||||
if extracted_content and len(extracted_content.strip()) == 0:
|
||||
extracted_content = None
|
||||
except Exception:
|
||||
extracted_content = None
|
||||
|
||||
extract_ms = (time.perf_counter() - extract_start) * 1000
|
||||
|
||||
if not extracted_content and not allow_raw_fallback:
|
||||
self._log_build(
|
||||
crawler_type,
|
||||
url,
|
||||
fetch_ms,
|
||||
extract_ms,
|
||||
status,
|
||||
html_len,
|
||||
0,
|
||||
"no_extraction",
|
||||
)
|
||||
return None
|
||||
|
||||
metadata: dict[str, str] = {"source": url}
|
||||
if trafilatura_metadata:
|
||||
|
|
@ -236,105 +450,51 @@ class WebCrawlerConnector:
|
|||
metadata["date"] = trafilatura_metadata.date
|
||||
metadata.setdefault("title", url)
|
||||
|
||||
return {
|
||||
"content": extracted_content,
|
||||
"metadata": metadata,
|
||||
"crawler_type": "http",
|
||||
}
|
||||
|
||||
async def _crawl_with_chromium(self, url: str) -> dict[str, Any]:
|
||||
"""
|
||||
Crawl URL using Playwright with Trafilatura for content extraction.
|
||||
Falls back to raw HTML if Trafilatura extraction fails.
|
||||
|
||||
Runs the sync Playwright API in a thread so it works on any event
|
||||
loop, including Windows ``SelectorEventLoop`` which cannot spawn
|
||||
subprocesses.
|
||||
|
||||
Args:
|
||||
url: URL to crawl
|
||||
|
||||
Returns:
|
||||
Dict containing crawled content and metadata
|
||||
|
||||
Raises:
|
||||
Exception: If crawling fails
|
||||
"""
|
||||
return await asyncio.to_thread(self._crawl_with_chromium_sync, url)
|
||||
|
||||
def _crawl_with_chromium_sync(self, url: str) -> dict[str, Any]:
|
||||
"""Synchronous Playwright crawl executed in a worker thread."""
|
||||
ua = UserAgent()
|
||||
user_agent = ua.random
|
||||
|
||||
playwright_proxy = get_playwright_proxy()
|
||||
|
||||
with sync_playwright() as p:
|
||||
launch_kwargs: dict = {"headless": True}
|
||||
if playwright_proxy:
|
||||
launch_kwargs["proxy"] = playwright_proxy
|
||||
browser = p.chromium.launch(**launch_kwargs)
|
||||
context = browser.new_context(user_agent=user_agent)
|
||||
page = context.new_page()
|
||||
|
||||
try:
|
||||
page.goto(url, wait_until="domcontentloaded", timeout=30000)
|
||||
raw_html = page.content()
|
||||
page_title = page.title()
|
||||
finally:
|
||||
browser.close()
|
||||
|
||||
if not raw_html:
|
||||
raise ValueError(f"Failed to load content from {url}")
|
||||
|
||||
base_metadata = {"title": page_title} if page_title else {}
|
||||
|
||||
extracted_content = None
|
||||
trafilatura_metadata = None
|
||||
|
||||
try:
|
||||
extracted_content = trafilatura.extract(
|
||||
raw_html,
|
||||
output_format="markdown",
|
||||
include_comments=False,
|
||||
include_tables=True,
|
||||
include_images=True,
|
||||
include_links=True,
|
||||
)
|
||||
|
||||
trafilatura_metadata = trafilatura.extract_metadata(raw_html)
|
||||
|
||||
if not extracted_content or len(extracted_content.strip()) == 0:
|
||||
extracted_content = None
|
||||
|
||||
except Exception:
|
||||
extracted_content = None
|
||||
|
||||
metadata = {
|
||||
"source": url,
|
||||
"title": (
|
||||
trafilatura_metadata.title
|
||||
if trafilatura_metadata and trafilatura_metadata.title
|
||||
else base_metadata.get("title", url)
|
||||
),
|
||||
}
|
||||
|
||||
if trafilatura_metadata:
|
||||
if trafilatura_metadata.description:
|
||||
metadata["description"] = trafilatura_metadata.description
|
||||
if trafilatura_metadata.author:
|
||||
metadata["author"] = trafilatura_metadata.author
|
||||
if trafilatura_metadata.date:
|
||||
metadata["date"] = trafilatura_metadata.date
|
||||
|
||||
metadata.update(base_metadata)
|
||||
content = extracted_content if extracted_content else raw_html
|
||||
self._log_build(
|
||||
crawler_type,
|
||||
url,
|
||||
fetch_ms,
|
||||
extract_ms,
|
||||
status,
|
||||
html_len,
|
||||
len(content),
|
||||
"extracted" if extracted_content else "raw_fallback",
|
||||
)
|
||||
|
||||
return {
|
||||
"content": extracted_content if extracted_content else raw_html,
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
"crawler_type": "chromium",
|
||||
"crawler_type": crawler_type,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _log_build(
|
||||
crawler_type: str,
|
||||
url: str,
|
||||
fetch_ms: float | None,
|
||||
extract_ms: float,
|
||||
status: int | None,
|
||||
html_len: int,
|
||||
content_len: int,
|
||||
outcome: str,
|
||||
) -> None:
|
||||
"""Emit a detailed perf line splitting fetch vs Trafilatura extraction."""
|
||||
fetch_repr = f"{fetch_ms:.1f}" if fetch_ms is not None else "n/a"
|
||||
logger.info(
|
||||
"%s tier=%s url=%s status=%s fetch_ms=%s extract_ms=%.1f "
|
||||
"html_len=%d content_len=%d outcome=%s",
|
||||
_PERF,
|
||||
crawler_type,
|
||||
url,
|
||||
status,
|
||||
fetch_repr,
|
||||
extract_ms,
|
||||
html_len,
|
||||
content_len,
|
||||
outcome,
|
||||
)
|
||||
|
||||
def format_to_structured_document(
|
||||
self, crawl_result: dict[str, Any], exclude_metadata: bool = False
|
||||
) -> str:
|
||||
|
|
|
|||
|
|
@ -714,7 +714,9 @@ class NewChatThread(BaseModel, TimestampMixin):
|
|||
|
||||
# Surface metadata for first-party SurfSense and external chat threads.
|
||||
# Zero publishes all chat-message sources; the UI can decide which surfaces to render.
|
||||
source = Column(Text, nullable=False, default="surfsense", server_default="surfsense")
|
||||
source = Column(
|
||||
Text, nullable=False, default="surfsense", server_default="surfsense"
|
||||
)
|
||||
external_chat_binding_id = Column(
|
||||
BigInteger,
|
||||
ForeignKey("external_chat_bindings.id", ondelete="SET NULL"),
|
||||
|
|
@ -802,7 +804,9 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
|||
|
||||
# Mirrors the parent thread source for publication-level filtering.
|
||||
# This denormalization avoids join-dependent logical replication rules.
|
||||
source = Column(Text, nullable=False, default="surfsense", server_default="surfsense")
|
||||
source = Column(
|
||||
Text, nullable=False, default="surfsense", server_default="surfsense"
|
||||
)
|
||||
platform_metadata = Column(JSONB, nullable=True)
|
||||
|
||||
# Relationships
|
||||
|
|
@ -848,11 +852,15 @@ class ExternalChatAccount(Base, TimestampMixin):
|
|||
owner_search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
is_system_account = Column(Boolean, nullable=False, default=False, server_default="false")
|
||||
is_system_account = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
)
|
||||
encrypted_credentials = Column(Text, nullable=True)
|
||||
bot_username = Column(String(255), nullable=True)
|
||||
webhook_secret = Column(String(64), nullable=True)
|
||||
cursor_state = Column(JSONB, nullable=False, default=dict, server_default=text("'{}'::jsonb"))
|
||||
cursor_state = Column(
|
||||
JSONB, nullable=False, default=dict, server_default=text("'{}'::jsonb")
|
||||
)
|
||||
health_status = Column(
|
||||
SQLAlchemyEnum(
|
||||
ExternalChatHealthStatus,
|
||||
|
|
@ -875,7 +883,9 @@ class ExternalChatAccount(Base, TimestampMixin):
|
|||
)
|
||||
|
||||
owner = relationship("User", foreign_keys=[owner_user_id])
|
||||
owner_search_space = relationship("SearchSpace", foreign_keys=[owner_search_space_id])
|
||||
owner_search_space = relationship(
|
||||
"SearchSpace", foreign_keys=[owner_search_space_id]
|
||||
)
|
||||
bindings = relationship(
|
||||
"ExternalChatBinding",
|
||||
back_populates="account",
|
||||
|
|
@ -980,7 +990,9 @@ class ExternalChatBinding(Base, TimestampMixin):
|
|||
external_thread_id = Column(Text, nullable=True)
|
||||
external_display_name = Column(Text, nullable=True)
|
||||
external_username = Column(Text, nullable=True)
|
||||
external_metadata = Column(JSONB, nullable=False, default=dict, server_default=text("'{}'::jsonb"))
|
||||
external_metadata = Column(
|
||||
JSONB, nullable=False, default=dict, server_default=text("'{}'::jsonb")
|
||||
)
|
||||
new_chat_thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="SET NULL"),
|
||||
|
|
@ -1030,7 +1042,9 @@ class ExternalChatBinding(Base, TimestampMixin):
|
|||
postgresql_where=text("state = 'pending'"),
|
||||
),
|
||||
Index("ix_external_chat_bindings_user_state", "user_id", "state"),
|
||||
Index("ix_external_chat_bindings_search_space_state", "search_space_id", "state"),
|
||||
Index(
|
||||
"ix_external_chat_bindings_search_space_state", "search_space_id", "state"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from app.file_storage.backends.base import StorageBackend
|
||||
|
|
@ -43,10 +44,8 @@ class AzureBlobBackend(StorageBackend):
|
|||
|
||||
async with self._service() as service:
|
||||
blob = service.get_blob_client(self._container, key)
|
||||
try:
|
||||
with contextlib.suppress(ResourceNotFoundError):
|
||||
await blob.delete_blob()
|
||||
except ResourceNotFoundError:
|
||||
pass
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
async with self._service() as service:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -53,10 +54,8 @@ class LocalFileBackend(StorageBackend):
|
|||
path = self._path_for(key)
|
||||
|
||||
def _unlink() -> None:
|
||||
try:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
path.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
await asyncio.to_thread(_unlink)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,12 @@
|
|||
"""Environment-driven configuration for the file-storage module."""
|
||||
"""Configuration for the file-storage module, sourced from the central Config."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
LOCAL_BACKEND = "local"
|
||||
AZURE_BACKEND = "azure"
|
||||
|
||||
# surfsense_backend/ — two levels up from app/file_storage/settings.py
|
||||
_BACKEND_ROOT = Path(__file__).resolve().parents[2]
|
||||
_DEFAULT_LOCAL_ROOT = str(_BACKEND_ROOT / ".local_object_store")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageSettings:
|
||||
|
|
@ -25,13 +19,15 @@ class StorageSettings:
|
|||
|
||||
|
||||
def load_storage_settings() -> StorageSettings:
|
||||
"""Read storage settings from the environment.
|
||||
"""Resolve storage settings from the central ``Config`` singleton.
|
||||
|
||||
Defaults to the ``local`` backend so development needs no cloud creds.
|
||||
"""
|
||||
from app.config import config
|
||||
|
||||
return StorageSettings(
|
||||
backend=os.getenv("FILE_STORAGE_BACKEND", LOCAL_BACKEND).strip().lower(),
|
||||
azure_connection_string=os.getenv("AZURE_STORAGE_CONNECTION_STRING"),
|
||||
azure_container=os.getenv("AZURE_STORAGE_CONTAINER"),
|
||||
local_root=os.getenv("FILE_STORAGE_LOCAL_PATH", _DEFAULT_LOCAL_ROOT),
|
||||
backend=config.FILE_STORAGE_BACKEND,
|
||||
azure_connection_string=config.AZURE_STORAGE_CONNECTION_STRING,
|
||||
azure_container=config.AZURE_STORAGE_CONTAINER,
|
||||
local_root=config.FILE_STORAGE_LOCAL_PATH,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,2 +1,19 @@
|
|||
"""Messaging gateway infrastructure for external chat channels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.config import config
|
||||
|
||||
|
||||
def require_gateway_enabled() -> None:
    """Reject gateway HTTP requests while the gateway is globally switched off.

    Meant to be attached as a FastAPI dependency on every gateway route. A 404
    is used instead of a 503 so that, with ``GATEWAY_ENABLED`` set to FALSE,
    the webhook/OAuth/pairing endpoints look the same to a caller as routes
    that do not exist at all.

    Raises:
        HTTPException: 404 "Not Found" when ``config.GATEWAY_ENABLED`` is falsy.
    """
    if config.GATEWAY_ENABLED:
        # Gateway is on: nothing to block.
        return

    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
|
||||
|
|
|
|||
|
|
@ -31,7 +31,9 @@ def slack_account_credentials(account: ExternalChatAccount) -> dict:
|
|||
"""Decrypt Slack gateway credentials stored as encrypted JSON."""
|
||||
if not account.encrypted_credentials:
|
||||
return {}
|
||||
raw = TokenEncryption(config.SECRET_KEY or "").decrypt_token(account.encrypted_credentials)
|
||||
raw = TokenEncryption(config.SECRET_KEY or "").decrypt_token(
|
||||
account.encrypted_credentials
|
||||
)
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
|
|
@ -44,7 +46,9 @@ def discord_account_credentials(account: ExternalChatAccount) -> dict:
|
|||
"""Decrypt Discord gateway credentials stored as encrypted JSON."""
|
||||
if not account.encrypted_credentials:
|
||||
return {}
|
||||
raw = TokenEncryption(config.SECRET_KEY or "").decrypt_token(account.encrypted_credentials)
|
||||
raw = TokenEncryption(config.SECRET_KEY or "").decrypt_token(
|
||||
account.encrypted_credentials
|
||||
)
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
|
|
@ -135,4 +139,3 @@ async def get_discord_account_by_guild(
|
|||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ from app.tasks.chat.streaming.flows import stream_new_chat
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _events_from_sse(chunks: AsyncIterator[str]) -> AsyncIterator[GatewayStreamEvent]:
|
||||
async def _events_from_sse(
|
||||
chunks: AsyncIterator[str],
|
||||
) -> AsyncIterator[GatewayStreamEvent]:
|
||||
saw_text = False
|
||||
async for chunk in chunks:
|
||||
for raw_line in chunk.splitlines():
|
||||
|
|
@ -98,4 +100,3 @@ async def call_agent_for_gateway(
|
|||
record_gateway_turn_latency(0, platform=platform_label)
|
||||
finally:
|
||||
release_thread_lock(thread.id)
|
||||
|
||||
|
|
|
|||
|
|
@ -52,4 +52,3 @@ async def assert_authorization_invariant(
|
|||
await _fail(session, binding, f"rbac_{exc.status_code}")
|
||||
|
||||
return user
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +1 @@
|
|||
"""Base gateway interfaces."""
|
||||
|
||||
|
|
|
|||
|
|
@ -62,9 +62,10 @@ class BasePlatformAdapter(ABC):
|
|||
async def validate_credentials(self) -> dict[str, Any]:
|
||||
"""Validate configured credentials and return account metadata."""
|
||||
|
||||
async def fetch_updates(self, *, offset: int | None) -> AsyncIterator[dict[str, Any]]:
|
||||
async def fetch_updates(
|
||||
self, *, offset: int | None
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Yield provider updates for long-polling adapters."""
|
||||
if False:
|
||||
yield {} # pragma: no cover
|
||||
raise NotImplementedError("This adapter does not support long-polling")
|
||||
|
||||
|
|
|
|||
|
|
@ -16,4 +16,3 @@ def hash_external_id(value: str | int | None) -> str | None:
|
|||
if not normalized:
|
||||
return None
|
||||
return hashlib.sha256(normalized.encode("utf-8")).hexdigest()
|
||||
|
||||
|
|
|
|||
|
|
@ -25,4 +25,3 @@ class BaseStreamTranslator(ABC):
|
|||
@abstractmethod
|
||||
async def translate(self, events: AsyncIterator[GatewayStreamEvent]) -> None:
|
||||
"""Consume agent stream events and emit platform messages."""
|
||||
|
||||
|
|
|
|||
|
|
@ -64,4 +64,3 @@ def resume_binding(binding: ExternalChatBinding) -> None:
|
|||
binding.state = ExternalChatBindingState.BOUND
|
||||
binding.suspended_at = None
|
||||
binding.suspended_reason = None
|
||||
|
||||
|
|
|
|||
|
|
@ -58,8 +58,10 @@ async def _whatsapp_baileys_supervisor() -> None:
|
|||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ExternalChatAccount).where(
|
||||
ExternalChatAccount.platform == ExternalChatPlatform.WHATSAPP,
|
||||
ExternalChatAccount.mode == ExternalChatAccountMode.SELF_HOST_BYO,
|
||||
ExternalChatAccount.platform
|
||||
== ExternalChatPlatform.WHATSAPP,
|
||||
ExternalChatAccount.mode
|
||||
== ExternalChatAccountMode.SELF_HOST_BYO,
|
||||
ExternalChatAccount.is_system_account.is_(False),
|
||||
ExternalChatAccount.suspended_at.is_(None),
|
||||
)
|
||||
|
|
@ -96,6 +98,8 @@ async def start_byo_long_poll_supervisors() -> None:
|
|||
"""Start one BYO long-poll supervisor per active non-system Telegram account."""
|
||||
|
||||
global _shutdown_event
|
||||
if not config.GATEWAY_ENABLED:
|
||||
return
|
||||
if (
|
||||
config.GATEWAY_TELEGRAM_INTAKE_MODE != "longpoll"
|
||||
and config.GATEWAY_WHATSAPP_INTAKE_MODE != "baileys"
|
||||
|
|
@ -126,7 +130,9 @@ async def start_byo_long_poll_supervisors() -> None:
|
|||
)
|
||||
_tasks.add(task)
|
||||
task.add_done_callback(_tasks.discard)
|
||||
logger.info("Started BYO Telegram long-poll supervisor account_id=%s", account.id)
|
||||
logger.info(
|
||||
"Started BYO Telegram long-poll supervisor account_id=%s", account.id
|
||||
)
|
||||
|
||||
if config.GATEWAY_WHATSAPP_INTAKE_MODE == "baileys":
|
||||
task = asyncio.create_task(
|
||||
|
|
@ -149,9 +155,12 @@ async def stop_byo_long_poll_supervisors() -> None:
|
|||
task.cancel()
|
||||
if tasks:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*tasks, return_exceptions=True), timeout=10
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning("Timed out waiting for BYO Telegram long-poll supervisors to stop")
|
||||
logger.warning(
|
||||
"Timed out waiting for BYO Telegram long-poll supervisors to stop"
|
||||
)
|
||||
_tasks.clear()
|
||||
_shutdown_event = None
|
||||
|
||||
|
|
|
|||
|
|
@ -39,7 +39,9 @@ def _message_reference_payload(message: discord.Message) -> dict[str, Any] | Non
|
|||
}
|
||||
|
||||
|
||||
def _serialize_message(message: discord.Message, *, bot_user_id: str | None) -> dict[str, Any]:
|
||||
def _serialize_message(
|
||||
message: discord.Message, *, bot_user_id: str | None
|
||||
) -> dict[str, Any]:
|
||||
guild = message.guild
|
||||
channel = message.channel
|
||||
thread_id = str(channel.id) if isinstance(channel, discord.Thread) else None
|
||||
|
|
@ -62,8 +64,7 @@ def _serialize_message(message: discord.Message, *, bot_user_id: str | None) ->
|
|||
"bot": message.author.bot,
|
||||
},
|
||||
"mentions": [
|
||||
{"id": str(user.id), "username": user.name}
|
||||
for user in message.mentions
|
||||
{"id": str(user.id), "username": user.name} for user in message.mentions
|
||||
],
|
||||
"message_reference": _message_reference_payload(message),
|
||||
"created_at": message.created_at.isoformat()
|
||||
|
|
@ -73,7 +74,9 @@ def _serialize_message(message: discord.Message, *, bot_user_id: str | None) ->
|
|||
}
|
||||
|
||||
|
||||
async def _persist_message(message: discord.Message, *, bot_user_id: str | None) -> None:
|
||||
async def _persist_message(
|
||||
message: discord.Message, *, bot_user_id: str | None
|
||||
) -> None:
|
||||
if message.guild is None:
|
||||
return
|
||||
guild_id = str(message.guild.id)
|
||||
|
|
@ -82,7 +85,9 @@ async def _persist_message(message: discord.Message, *, bot_user_id: str | None)
|
|||
async with async_session_maker() as session:
|
||||
account = await get_discord_account_by_guild(session, guild_id=guild_id)
|
||||
if account is None:
|
||||
logger.info("Ignoring Discord message for uninstalled guild_id=%s", guild_id)
|
||||
logger.info(
|
||||
"Ignoring Discord message for uninstalled guild_id=%s", guild_id
|
||||
)
|
||||
return
|
||||
|
||||
inbox_id = await persist_inbound_event(
|
||||
|
|
@ -144,7 +149,9 @@ def _build_client() -> discord.Client:
|
|||
try:
|
||||
await _persist_message(message, bot_user_id=bot_user_id)
|
||||
except Exception:
|
||||
logger.exception("Discord gateway failed to persist message_id=%s", message.id)
|
||||
logger.exception(
|
||||
"Discord gateway failed to persist message_id=%s", message.id
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
|
|
@ -177,6 +184,8 @@ async def _run_discord_gateway() -> None:
|
|||
|
||||
async def start_discord_gateway_supervisor() -> None:
|
||||
global _shutdown_event, _task
|
||||
if not config.GATEWAY_ENABLED:
|
||||
return
|
||||
if not config.GATEWAY_DISCORD_ENABLED:
|
||||
return
|
||||
if _task is not None and not _task.done():
|
||||
|
|
|
|||
|
|
@ -41,7 +41,9 @@ class DiscordStreamTranslator(BaseStreamTranslator):
|
|||
async def translate(self, events: AsyncIterator[GatewayStreamEvent]) -> None:
|
||||
async for event in events:
|
||||
if event.type in {"text-delta", "text_delta", "text"}:
|
||||
self._buffer += str(event.data.get("text") or event.data.get("delta") or "")
|
||||
self._buffer += str(
|
||||
event.data.get("text") or event.data.get("delta") or ""
|
||||
)
|
||||
elif event.type in {"data-interrupt-request", "interrupt"}:
|
||||
await self._handle_hitl_interrupt()
|
||||
return
|
||||
|
|
@ -53,7 +55,9 @@ class DiscordStreamTranslator(BaseStreamTranslator):
|
|||
async def _flush_final(self) -> None:
|
||||
if not self._buffer:
|
||||
return
|
||||
for chunk in split_text_message(self._buffer, max_chars=DISCORD_MAX_MESSAGE_CHARS):
|
||||
for chunk in split_text_message(
|
||||
self._buffer, max_chars=DISCORD_MAX_MESSAGE_CHARS
|
||||
):
|
||||
await self._send_text(chunk)
|
||||
|
||||
async def _send_text(self, text: str) -> PlatformSendResult:
|
||||
|
|
|
|||
|
|
@ -32,4 +32,3 @@ def filter_hitl_tools(
|
|||
return None
|
||||
blocked = blocked_names or DEFAULT_HITL_TOOL_NAMES
|
||||
return [tool for tool in toolkit if (_tool_name(tool) or "") not in blocked]
|
||||
|
||||
|
|
|
|||
|
|
@ -51,4 +51,3 @@ async def persist_inbound_event(
|
|||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
|
|
|||
|
|
@ -128,7 +128,9 @@ async def process_inbound_event(
|
|||
event.status = ExternalChatEventStatus.PROCESSED
|
||||
event.processed_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
record_gateway_inbox_processed(platform=event.platform.value, status="processed")
|
||||
record_gateway_inbox_processed(
|
||||
platform=event.platform.value, status="processed"
|
||||
)
|
||||
|
||||
|
||||
async def _mark_failed(
|
||||
|
|
@ -173,7 +175,9 @@ async def _resolve_slack_thread_binding(
|
|||
parsed,
|
||||
) -> ExternalChatBinding | None:
|
||||
user_peer_id = parsed.metadata.get("slack_user_peer_id")
|
||||
thread_peer_id = parsed.metadata.get("slack_thread_peer_id") or parsed.external_peer_id
|
||||
thread_peer_id = (
|
||||
parsed.metadata.get("slack_thread_peer_id") or parsed.external_peer_id
|
||||
)
|
||||
if not user_peer_id or not thread_peer_id:
|
||||
return None
|
||||
|
||||
|
|
@ -233,7 +237,9 @@ async def _resolve_discord_thread_binding(
|
|||
parsed,
|
||||
) -> ExternalChatBinding | None:
|
||||
user_peer_id = parsed.metadata.get("discord_user_peer_id")
|
||||
thread_peer_id = parsed.metadata.get("discord_thread_peer_id") or parsed.external_peer_id
|
||||
thread_peer_id = (
|
||||
parsed.metadata.get("discord_thread_peer_id") or parsed.external_peer_id
|
||||
)
|
||||
if not user_peer_id or not thread_peer_id:
|
||||
return None
|
||||
|
||||
|
|
@ -357,7 +363,11 @@ async def _dispatch_inbound_event(
|
|||
return
|
||||
|
||||
if binding is None:
|
||||
if bundle.auto_bind_owner and account.owner_user_id and account.owner_search_space_id:
|
||||
if (
|
||||
bundle.auto_bind_owner
|
||||
and account.owner_user_id
|
||||
and account.owner_search_space_id
|
||||
):
|
||||
binding = ExternalChatBinding(
|
||||
account_id=account.id,
|
||||
user_id=account.owner_user_id,
|
||||
|
|
@ -385,7 +395,9 @@ async def _dispatch_inbound_event(
|
|||
event.external_chat_binding_id = binding.id
|
||||
|
||||
if cmd == "/help":
|
||||
handled = await bundle.commands.handle_help_command(adapter=adapter, event=parsed)
|
||||
handled = await bundle.commands.handle_help_command(
|
||||
adapter=adapter, event=parsed
|
||||
)
|
||||
if handled:
|
||||
event.status = ExternalChatEventStatus.PROCESSED
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import asyncio
|
|||
import logging
|
||||
from contextlib import suppress
|
||||
|
||||
from app.config import config
|
||||
from app.gateway.inbox_processor import claim_next_inbound_event, process_inbound_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -39,6 +40,8 @@ async def _process_inbox_forever() -> None:
|
|||
|
||||
async def start_gateway_inbox_worker() -> None:
|
||||
global _task
|
||||
if not config.GATEWAY_ENABLED:
|
||||
return
|
||||
if _task is not None and not _task.done():
|
||||
return
|
||||
_task = asyncio.create_task(_process_inbox_forever(), name="gateway-inbox-worker")
|
||||
|
|
@ -52,4 +55,3 @@ async def stop_gateway_inbox_worker() -> None:
|
|||
with suppress(TimeoutError, asyncio.CancelledError):
|
||||
await asyncio.wait_for(_task, timeout=10)
|
||||
_task = None
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from datetime import UTC, datetime, timedelta
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import ExternalChatBindingState, ExternalChatBinding
|
||||
from app.db import ExternalChatBinding, ExternalChatBindingState
|
||||
|
||||
PAIRING_CODE_TTL = timedelta(minutes=10)
|
||||
|
||||
|
|
@ -51,4 +51,3 @@ async def redeem_pairing_code(
|
|||
binding.external_username = external_username
|
||||
binding.external_metadata = external_metadata or {}
|
||||
return binding
|
||||
|
||||
|
|
|
|||
|
|
@ -133,4 +133,3 @@ async def wait_for_token(
|
|||
if wait_ms > 0:
|
||||
await asyncio.sleep(wait_ms / 1000)
|
||||
return wait_ms
|
||||
|
||||
|
|
|
|||
|
|
@ -186,4 +186,6 @@ def resolve_platform_bundle(account: ExternalChatAccount) -> PlatformBundle:
|
|||
auto_bind_owner=False,
|
||||
)
|
||||
|
||||
raise RuntimeError(f"unsupported_gateway_platform:{account.platform.value}:{account.mode.value}")
|
||||
raise RuntimeError(
|
||||
f"unsupported_gateway_platform:{account.platform.value}:{account.mode.value}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,12 @@ import uuid
|
|||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db import ExternalChatPlatform, ExternalChatAccount, async_session_maker, engine
|
||||
from app.db import (
|
||||
ExternalChatAccount,
|
||||
ExternalChatPlatform,
|
||||
async_session_maker,
|
||||
engine,
|
||||
)
|
||||
from app.gateway.inbox import persist_inbound_event, telegram_event_dedupe_key
|
||||
from app.gateway.telegram.adapter import TelegramAdapter
|
||||
from app.observability.metrics import record_gateway_byo_longpoll_running_delta
|
||||
|
|
@ -39,7 +44,9 @@ async def _run_telegram_account(account_id: int, token: str) -> None:
|
|||
account = await session.get(ExternalChatAccount, account_id)
|
||||
offset = None
|
||||
if account is not None:
|
||||
offset = int((account.cursor_state or {}).get("last_update_id", 0)) + 1
|
||||
offset = (
|
||||
int((account.cursor_state or {}).get("last_update_id", 0)) + 1
|
||||
)
|
||||
|
||||
async for update in adapter.fetch_updates(offset=offset):
|
||||
request_id = f"gateway_{uuid.uuid4().hex[:16]}"
|
||||
|
|
@ -58,8 +65,11 @@ async def _run_telegram_account(account_id: int, token: str) -> None:
|
|||
)
|
||||
await session.commit()
|
||||
if inbox_id is not None:
|
||||
logger.debug("Persisted Telegram polling update inbox_id=%s", inbox_id)
|
||||
logger.debug(
|
||||
"Persisted Telegram polling update inbox_id=%s", inbox_id
|
||||
)
|
||||
finally:
|
||||
record_gateway_byo_longpoll_running_delta(-1, account_id=account_id)
|
||||
await conn.execute(text("SELECT pg_advisory_unlock(:key)"), {"key": lock_key})
|
||||
|
||||
await conn.execute(
|
||||
text("SELECT pg_advisory_unlock(:key)"), {"key": lock_key}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,9 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
slack_user_id = str(event.get("user") or "")
|
||||
message_ts = str(event.get("ts") or "")
|
||||
thread_ts = str(event.get("thread_ts") or message_ts)
|
||||
bot_user_id = self.bot_user_id or str(raw_payload.get("authorizations", [{}])[0].get("user_id") or "")
|
||||
bot_user_id = self.bot_user_id or str(
|
||||
raw_payload.get("authorizations", [{}])[0].get("user_id") or ""
|
||||
)
|
||||
|
||||
if not channel_id or not slack_user_id or not message_ts:
|
||||
return ParsedInboundEvent(
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ class SlackGatewayClient:
|
|||
def __init__(self, bot_token: str) -> None:
|
||||
self.bot_token = bot_token
|
||||
|
||||
async def api_call(self, method: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
async def api_call(
|
||||
self, method: str, payload: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
response = await client.post(
|
||||
f"{SLACK_API}/{method}",
|
||||
|
|
@ -55,7 +57,9 @@ class SlackGatewayClient:
|
|||
ts: str,
|
||||
text: str,
|
||||
) -> PlatformSendResult:
|
||||
data = await self.api_call("chat.update", {"channel": channel, "ts": ts, "text": text})
|
||||
data = await self.api_call(
|
||||
"chat.update", {"channel": channel, "ts": ts, "text": text}
|
||||
)
|
||||
return PlatformSendResult(
|
||||
external_message_id=str(data.get("ts") or ts),
|
||||
raw_response=data,
|
||||
|
|
|
|||
|
|
@ -41,7 +41,9 @@ class SlackStreamTranslator(BaseStreamTranslator):
|
|||
async def translate(self, events: AsyncIterator[GatewayStreamEvent]) -> None:
|
||||
async for event in events:
|
||||
if event.type in {"text-delta", "text_delta", "text"}:
|
||||
self._buffer += str(event.data.get("text") or event.data.get("delta") or "")
|
||||
self._buffer += str(
|
||||
event.data.get("text") or event.data.get("delta") or ""
|
||||
)
|
||||
elif event.type in {"data-interrupt-request", "interrupt"}:
|
||||
await self._handle_hitl_interrupt()
|
||||
return
|
||||
|
|
@ -53,7 +55,9 @@ class SlackStreamTranslator(BaseStreamTranslator):
|
|||
async def _flush_final(self) -> None:
|
||||
if not self._buffer:
|
||||
return
|
||||
for chunk in split_text_message(self._buffer, max_chars=SLACK_MAX_MESSAGE_CHARS):
|
||||
for chunk in split_text_message(
|
||||
self._buffer, max_chars=SLACK_MAX_MESSAGE_CHARS
|
||||
):
|
||||
await self._send_text(chunk)
|
||||
|
||||
async def _send_text(self, text: str) -> PlatformSendResult:
|
||||
|
|
|
|||
|
|
@ -1,2 +1 @@
|
|||
"""Telegram gateway adapter."""
|
||||
|
||||
|
|
|
|||
|
|
@ -51,9 +51,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
"channel": "channel",
|
||||
}.get(chat_type, "unknown")
|
||||
display_name = chat.get("title") or " ".join(
|
||||
part
|
||||
for part in (sender.get("first_name"), sender.get("last_name"))
|
||||
if part
|
||||
part for part in (sender.get("first_name"), sender.get("last_name")) if part
|
||||
)
|
||||
|
||||
return ParsedInboundEvent(
|
||||
|
|
@ -62,14 +60,21 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
external_peer_id=str(chat["id"]) if chat.get("id") is not None else None,
|
||||
external_peer_kind=peer_kind,
|
||||
external_message_id=(
|
||||
str(message["message_id"]) if message.get("message_id") is not None else None
|
||||
str(message["message_id"])
|
||||
if message.get("message_id") is not None
|
||||
else None
|
||||
),
|
||||
external_user_id=str(sender["id"]) if sender.get("id") is not None else None,
|
||||
external_user_id=str(sender["id"])
|
||||
if sender.get("id") is not None
|
||||
else None,
|
||||
text=message.get("text") or message.get("caption"),
|
||||
raw_payload=raw_payload,
|
||||
display_name=display_name or None,
|
||||
username=sender.get("username") or chat.get("username"),
|
||||
metadata={"chat_type": chat_type, "update_id": raw_payload.get("update_id")},
|
||||
metadata={
|
||||
"chat_type": chat_type,
|
||||
"update_id": raw_payload.get("update_id"),
|
||||
},
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
|
|
@ -108,7 +113,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
async def leave_chat(self, *, external_peer_id: str) -> None:
|
||||
await self.client.leave_chat(chat_id=external_peer_id)
|
||||
|
||||
async def fetch_updates(self, *, offset: int | None) -> AsyncIterator[dict[str, Any]]:
|
||||
async def fetch_updates(
|
||||
self, *, offset: int | None
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async for update in self.client.get_updates(offset=offset):
|
||||
yield update
|
||||
|
||||
|
|
|
|||
|
|
@ -106,4 +106,3 @@ async def retry_plaintext_on_bad_markdown(call, *args, **kwargs) -> PlatformSend
|
|||
raise
|
||||
kwargs["parse_mode"] = None
|
||||
return await call(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,9 @@ async def handle_start_command(
|
|||
return True
|
||||
|
||||
|
||||
async def handle_help_command(*, adapter: TelegramAdapter, event: ParsedInboundEvent) -> bool:
|
||||
async def handle_help_command(
|
||||
*, adapter: TelegramAdapter, event: ParsedInboundEvent
|
||||
) -> bool:
|
||||
if not event.external_peer_id:
|
||||
return True
|
||||
await adapter.send_message(external_peer_id=event.external_peer_id, text=HELP_TEXT)
|
||||
|
|
@ -114,4 +116,4 @@ class TelegramGatewayCommands(BaseGatewayCommands):
|
|||
adapter=adapter,
|
||||
event=event,
|
||||
dashboard_url=dashboard_url,
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -32,9 +32,13 @@ def _split_at_boundary(text: str, max_units: int) -> tuple[str, str]:
|
|||
end -= 1
|
||||
|
||||
candidate = text[:end]
|
||||
boundary = max(candidate.rfind("\n\n"), candidate.rfind(". "), candidate.rfind("\n"))
|
||||
boundary = max(
|
||||
candidate.rfind("\n\n"), candidate.rfind(". "), candidate.rfind("\n")
|
||||
)
|
||||
if boundary > max(200, end // 2):
|
||||
end = boundary + (2 if candidate[boundary : boundary + 2] in {"\n\n", ". "} else 1)
|
||||
end = boundary + (
|
||||
2 if candidate[boundary : boundary + 2] in {"\n\n", ". "} else 1
|
||||
)
|
||||
|
||||
return text[:end], text[end:]
|
||||
|
||||
|
|
@ -56,4 +60,3 @@ def chunk_message(
|
|||
chunks.append(chunk)
|
||||
return chunks
|
||||
return split_text_message(text, max_chars=max_units)
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,9 @@ class TelegramStreamTranslator(BaseStreamTranslator):
|
|||
async def translate(self, events: AsyncIterator[GatewayStreamEvent]) -> None:
|
||||
async for event in events:
|
||||
if event.type in {"text-delta", "text_delta", "text"}:
|
||||
self._buffer += str(event.data.get("text") or event.data.get("delta") or "")
|
||||
self._buffer += str(
|
||||
event.data.get("text") or event.data.get("delta") or ""
|
||||
)
|
||||
await self._maybe_flush()
|
||||
elif event.type in {"data-interrupt-request", "interrupt"}:
|
||||
await self._handle_hitl_interrupt()
|
||||
|
|
@ -159,7 +161,9 @@ class TelegramStreamTranslator(BaseStreamTranslator):
|
|||
)
|
||||
if chat_wait:
|
||||
record_gateway_rate_limit_hit(bucket="tg:chat")
|
||||
global_wait = await wait_for_token("tg:global", capacity=25, refill_per_sec=25.0)
|
||||
global_wait = await wait_for_token(
|
||||
"tg:global", capacity=25, refill_per_sec=25.0
|
||||
)
|
||||
if global_wait:
|
||||
record_gateway_rate_limit_hit(bucket="tg:global")
|
||||
|
||||
|
|
@ -168,4 +172,3 @@ class TelegramStreamTranslator(BaseStreamTranslator):
|
|||
await self._flush(final=False)
|
||||
await self._send_text(HITL_UNSUPPORTED_MESSAGE)
|
||||
record_gateway_hitl_aborted(platform="telegram")
|
||||
|
||||
|
|
|
|||
|
|
@ -36,5 +36,6 @@ def release_thread_lock(thread_id: int) -> None:
|
|||
try:
|
||||
_redis().delete(_lock_key(thread_id))
|
||||
except redis.RedisError as exc:
|
||||
logger.warning("Failed to release gateway thread lock for %s: %s", thread_id, exc)
|
||||
|
||||
logger.warning(
|
||||
"Failed to release gateway thread lock for %s: %s", thread_id, exc
|
||||
)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,8 @@ class WhatsAppBaileysAdapter(BasePlatformAdapter):
|
|||
external_user_id=sender_id or None,
|
||||
text=str(body) if body is not None else None,
|
||||
raw_payload=raw_payload,
|
||||
display_name=str(raw_payload.get("chatName") or sender_id or chat_id) or None,
|
||||
display_name=str(raw_payload.get("chatName") or sender_id or chat_id)
|
||||
or None,
|
||||
username=None,
|
||||
metadata={
|
||||
"sender_id": sender_id,
|
||||
|
|
@ -92,7 +93,9 @@ class WhatsAppBaileysAdapter(BasePlatformAdapter):
|
|||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def fetch_updates(self, *, offset: int | None) -> AsyncIterator[dict[str, Any]]:
|
||||
async def fetch_updates(
|
||||
self, *, offset: int | None
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
async with httpx.AsyncClient(timeout=35) as client:
|
||||
response = await client.get(f"{self.bridge_url}/messages")
|
||||
response.raise_for_status()
|
||||
|
|
|
|||
|
|
@ -54,7 +54,9 @@ class WhatsAppCloudAdapter(BasePlatformAdapter):
|
|||
username=None,
|
||||
metadata={
|
||||
"phone_number_id": _metadata(raw_payload).get("phone_number_id"),
|
||||
"display_phone_number": _metadata(raw_payload).get("display_phone_number"),
|
||||
"display_phone_number": _metadata(raw_payload).get(
|
||||
"display_phone_number"
|
||||
),
|
||||
"timestamp": message.get("timestamp"),
|
||||
"message_type": message.get("type"),
|
||||
},
|
||||
|
|
@ -96,7 +98,9 @@ def _changes(raw_payload: dict[str, Any]) -> list[dict[str, Any]]:
|
|||
for entry in raw_payload.get("entry") or []:
|
||||
if isinstance(entry, dict):
|
||||
changes.extend(
|
||||
change for change in (entry.get("changes") or []) if isinstance(change, dict)
|
||||
change
|
||||
for change in (entry.get("changes") or [])
|
||||
if isinstance(change, dict)
|
||||
)
|
||||
return changes
|
||||
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ class WhatsAppCredentials(TypedDict, total=False):
|
|||
|
||||
def load_system_whatsapp_credentials() -> WhatsAppCredentials:
|
||||
if not (
|
||||
config.WHATSAPP_SHARED_BUSINESS_TOKEN
|
||||
and config.WHATSAPP_SHARED_PHONE_NUMBER_ID
|
||||
config.WHATSAPP_SHARED_BUSINESS_TOKEN and config.WHATSAPP_SHARED_PHONE_NUMBER_ID
|
||||
):
|
||||
raise RuntimeError("whatsapp_system_credentials_not_configured")
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,9 @@ class WhatsAppCloudStreamTranslator(BaseStreamTranslator):
|
|||
if event.type in {"text-delta", "text_delta", "text"}:
|
||||
if not self._typing_sent:
|
||||
await self._send_typing_indicator()
|
||||
self._buffer += str(event.data.get("text") or event.data.get("delta") or "")
|
||||
self._buffer += str(
|
||||
event.data.get("text") or event.data.get("delta") or ""
|
||||
)
|
||||
elif event.type in {"data-interrupt-request", "interrupt"}:
|
||||
await self._handle_hitl_interrupt()
|
||||
return
|
||||
|
|
|
|||
|
|
@ -42,7 +42,9 @@ class WhatsAppBaileysStreamTranslator(BaseStreamTranslator):
|
|||
await self._send_typing_indicator()
|
||||
async for event in events:
|
||||
if event.type in {"text-delta", "text_delta", "text"}:
|
||||
self._buffer += str(event.data.get("text") or event.data.get("delta") or "")
|
||||
self._buffer += str(
|
||||
event.data.get("text") or event.data.get("delta") or ""
|
||||
)
|
||||
await self._maybe_flush()
|
||||
elif event.type in {"data-interrupt-request", "interrupt"}:
|
||||
await self._handle_hitl_interrupt()
|
||||
|
|
@ -86,7 +88,9 @@ class WhatsAppBaileysStreamTranslator(BaseStreamTranslator):
|
|||
if not isinstance(self.adapter, WhatsAppBaileysAdapter):
|
||||
return
|
||||
try:
|
||||
await self.adapter.send_typing_indicator(external_peer_id=self.external_peer_id)
|
||||
await self.adapter.send_typing_indicator(
|
||||
external_peer_id=self.external_peer_id
|
||||
)
|
||||
record_gateway_outbound(platform="whatsapp", kind="typing", status="sent")
|
||||
except Exception:
|
||||
logger.debug("WhatsApp Baileys typing indicator failed", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import contextlib
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -6,6 +8,8 @@ from sqlalchemy.orm.attributes import set_committed_value
|
|||
|
||||
from app.db import Document, DocumentStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def rollback_and_persist_failure(
|
||||
session: AsyncSession, document: Document, message: str
|
||||
|
|
@ -18,14 +22,28 @@ async def rollback_and_persist_failure(
|
|||
try:
|
||||
await session.rollback()
|
||||
except Exception:
|
||||
return # Session is completely dead; nothing further we can do.
|
||||
# Session is completely dead; surface it but never raise.
|
||||
logger.warning(
|
||||
"Rollback failed; cannot persist failed status for document %s",
|
||||
getattr(document, "id", "unknown"),
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
try:
|
||||
await session.refresh(document)
|
||||
document.updated_at = datetime.now(UTC)
|
||||
document.status = DocumentStatus.failed(message)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass # Best-effort; document will be retried on the next sync.
|
||||
# Best-effort: the document stays non-ready and is retried next sync.
|
||||
# Log it so a permanently-stuck document is at least traceable.
|
||||
logger.warning(
|
||||
"Could not persist failed status for document %s; will retry next sync",
|
||||
getattr(document, "id", "unknown"),
|
||||
exc_info=True,
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
await session.rollback()
|
||||
|
||||
|
||||
def attach_chunks_to_document(document: Document, chunks: list) -> None:
|
||||
|
|
|
|||
|
|
@ -202,7 +202,9 @@ class IndexingPipelineService:
|
|||
|
||||
await self.session.commit()
|
||||
|
||||
async def index_batch(self, connector_docs: list[ConnectorDocument]) -> list[Document]:
|
||||
async def index_batch(
|
||||
self, connector_docs: list[ConnectorDocument]
|
||||
) -> list[Document]:
|
||||
"""Convenience method: prepare_for_indexing then index each document.
|
||||
|
||||
Indexers that need heartbeat callbacks or custom per-document logic
|
||||
|
|
@ -347,7 +349,9 @@ class IndexingPipelineService:
|
|||
await self.session.rollback()
|
||||
return []
|
||||
|
||||
async def index(self, document: Document, connector_doc: ConnectorDocument) -> Document:
|
||||
async def index(
|
||||
self, document: Document, connector_doc: ConnectorDocument
|
||||
) -> Document:
|
||||
"""
|
||||
Run deterministic content storage, embedding, and chunking for a document.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from __future__ import annotations
|
|||
# Initialize app.db first to avoid a partial-init circular import when this
|
||||
# package is the entry point (e.g. Celery loading it before any ORM code).
|
||||
import app.db # noqa: F401
|
||||
|
||||
from app.notifications.persistence import Notification
|
||||
from app.notifications.service import NotificationService
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.automations.api import router as automations_router
|
||||
from app.file_storage.api import router as file_storage_router
|
||||
from app.gateway import require_gateway_enabled
|
||||
from app.notifications.api import router as notifications_router
|
||||
|
||||
from .agent_action_log_route import router as agent_action_log_router
|
||||
from .agent_flags_route import router as agent_flags_router
|
||||
|
|
@ -45,7 +47,6 @@ from .model_list_routes import router as model_list_router
|
|||
from .new_chat_routes import router as new_chat_router
|
||||
from .new_llm_config_routes import router as new_llm_config_router
|
||||
from .notes_routes import router as notes_router
|
||||
from app.notifications.api import router as notifications_router
|
||||
from .notion_add_connector_route import router as notion_add_connector_router
|
||||
from .obsidian_plugin_routes import router as obsidian_plugin_router
|
||||
from .onedrive_add_connector_route import router as onedrive_add_connector_router
|
||||
|
|
@ -73,9 +74,14 @@ router.include_router(editor_router)
|
|||
router.include_router(export_router)
|
||||
router.include_router(documents_router)
|
||||
router.include_router(folders_router)
|
||||
router.include_router(gateway_router)
|
||||
router.include_router(gateway_whatsapp_webhook_router)
|
||||
router.include_router(gateway_whatsapp_baileys_router)
|
||||
_gateway_enabled_dep = [Depends(require_gateway_enabled)]
|
||||
router.include_router(gateway_router, dependencies=_gateway_enabled_dep)
|
||||
router.include_router(
|
||||
gateway_whatsapp_webhook_router, dependencies=_gateway_enabled_dep
|
||||
)
|
||||
router.include_router(
|
||||
gateway_whatsapp_baileys_router, dependencies=_gateway_enabled_dep
|
||||
)
|
||||
router.include_router(notes_router)
|
||||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(agent_revert_router) # POST /threads/{id}/revert/{action_id}
|
||||
|
|
|
|||
|
|
@ -119,21 +119,35 @@ def _discord_redirect_uri() -> str:
|
|||
return f"{base.rstrip('/')}/api/v1/gateway/discord/callback"
|
||||
|
||||
|
||||
def _slack_frontend_redirect(space_id: int, *, success: bool = False, error: str | None = None) -> RedirectResponse:
|
||||
qs = "slack_gateway=connected" if success else f"error={error or 'slack_gateway_failed'}"
|
||||
def _slack_frontend_redirect(
|
||||
space_id: int, *, success: bool = False, error: str | None = None
|
||||
) -> RedirectResponse:
|
||||
qs = (
|
||||
"slack_gateway=connected"
|
||||
if success
|
||||
else f"error={error or 'slack_gateway_failed'}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/user-settings?{qs}"
|
||||
)
|
||||
|
||||
|
||||
def _discord_frontend_redirect(space_id: int, *, success: bool = False, error: str | None = None) -> RedirectResponse:
|
||||
qs = "discord_gateway=connected" if success else f"error={error or 'discord_gateway_failed'}"
|
||||
def _discord_frontend_redirect(
|
||||
space_id: int, *, success: bool = False, error: str | None = None
|
||||
) -> RedirectResponse:
|
||||
qs = (
|
||||
"discord_gateway=connected"
|
||||
if success
|
||||
else f"error={error or 'discord_gateway_failed'}"
|
||||
)
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/user-settings?{qs}"
|
||||
)
|
||||
|
||||
|
||||
def verify_slack_signature(*, signing_secret: str, timestamp: str | None, signature: str | None, body: bytes) -> bool:
|
||||
def verify_slack_signature(
|
||||
*, signing_secret: str, timestamp: str | None, signature: str | None, body: bytes
|
||||
) -> bool:
|
||||
if not signing_secret or not timestamp or not signature:
|
||||
return False
|
||||
try:
|
||||
|
|
@ -239,7 +253,9 @@ async def install_slack_gateway(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, str]:
|
||||
if not _slack_gateway_enabled():
|
||||
raise HTTPException(status_code=500, detail="Slack gateway OAuth is not configured")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Slack gateway OAuth is not configured"
|
||||
)
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
state = _get_state_manager().generate_secure_state(search_space_id, user.id)
|
||||
auth_params = {
|
||||
|
|
@ -269,11 +285,17 @@ async def slack_gateway_callback(
|
|||
state_data = None
|
||||
|
||||
if error:
|
||||
return _slack_frontend_redirect(space_id or 0, error="slack_gateway_oauth_denied")
|
||||
return _slack_frontend_redirect(
|
||||
space_id or 0, error="slack_gateway_oauth_denied"
|
||||
)
|
||||
if not code or state_data is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid Slack gateway OAuth callback")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid Slack gateway OAuth callback"
|
||||
)
|
||||
if not _slack_gateway_enabled():
|
||||
raise HTTPException(status_code=500, detail="Slack gateway OAuth is not configured")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Slack gateway OAuth is not configured"
|
||||
)
|
||||
|
||||
user_id = UUID(state_data["user_id"])
|
||||
token_payload = {
|
||||
|
|
@ -300,7 +322,9 @@ async def slack_gateway_callback(
|
|||
team = token_json.get("team") or {}
|
||||
team_id = team.get("id")
|
||||
if not bot_token or not team_id:
|
||||
raise HTTPException(status_code=400, detail="Slack gateway OAuth returned incomplete data")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Slack gateway OAuth returned incomplete data"
|
||||
)
|
||||
|
||||
bot_user_id = token_json.get("bot_user_id")
|
||||
app_id = token_json.get("app_id")
|
||||
|
|
@ -388,7 +412,9 @@ async def install_discord_gateway(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> dict[str, str]:
|
||||
if not _discord_gateway_enabled():
|
||||
raise HTTPException(status_code=500, detail="Discord gateway OAuth is not configured")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Discord gateway OAuth is not configured"
|
||||
)
|
||||
await check_search_space_access(session, user, search_space_id)
|
||||
state = _get_state_manager().generate_secure_state(search_space_id, user.id)
|
||||
auth_params = {
|
||||
|
|
@ -420,11 +446,17 @@ async def discord_gateway_callback(
|
|||
state_data = None
|
||||
|
||||
if error:
|
||||
return _discord_frontend_redirect(space_id or 0, error="discord_gateway_oauth_denied")
|
||||
return _discord_frontend_redirect(
|
||||
space_id or 0, error="discord_gateway_oauth_denied"
|
||||
)
|
||||
if not code or state_data is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid Discord gateway OAuth callback")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid Discord gateway OAuth callback"
|
||||
)
|
||||
if not _discord_gateway_enabled():
|
||||
raise HTTPException(status_code=500, detail="Discord gateway OAuth is not configured")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Discord gateway OAuth is not configured"
|
||||
)
|
||||
|
||||
user_id = UUID(state_data["user_id"])
|
||||
token_payload = {
|
||||
|
|
@ -535,7 +567,10 @@ async def discord_gateway_callback(
|
|||
elif binding.user_id == user_id:
|
||||
binding.search_space_id = space_id
|
||||
binding.external_username = discord_username or binding.external_username
|
||||
binding.external_metadata = {**(binding.external_metadata or {}), **metadata}
|
||||
binding.external_metadata = {
|
||||
**(binding.external_metadata or {}),
|
||||
**metadata,
|
||||
}
|
||||
|
||||
await session.commit()
|
||||
return _discord_frontend_redirect(space_id, success=True)
|
||||
|
|
@ -614,7 +649,9 @@ async def _resolve_webhook_account(
|
|||
if account is None or account.platform != ExternalChatPlatform.TELEGRAM:
|
||||
raise HTTPException(status_code=404, detail="Gateway account not found")
|
||||
expected_secret = account.webhook_secret or ""
|
||||
if not expected_secret or not hmac.compare_digest(header_secret or "", expected_secret):
|
||||
if not expected_secret or not hmac.compare_digest(
|
||||
header_secret or "", expected_secret
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Invalid Telegram webhook secret")
|
||||
return account
|
||||
|
||||
|
|
@ -654,7 +691,9 @@ async def telegram_webhook(
|
|||
event_dedupe_key=telegram_event_dedupe_key(update_id),
|
||||
external_event_id=str(update_id),
|
||||
external_message_id=(
|
||||
str(message["message_id"]) if message.get("message_id") is not None else None
|
||||
str(message["message_id"])
|
||||
if message.get("message_id") is not None
|
||||
else None
|
||||
),
|
||||
event_kind=_classify_telegram_event(payload),
|
||||
raw_payload=payload,
|
||||
|
|
@ -739,7 +778,10 @@ async def list_bindings(
|
|||
) -> list[dict[str, Any]]:
|
||||
result = await session.execute(
|
||||
select(ExternalChatBinding, ExternalChatAccount)
|
||||
.join(ExternalChatAccount, ExternalChatBinding.account_id == ExternalChatAccount.id)
|
||||
.join(
|
||||
ExternalChatAccount,
|
||||
ExternalChatBinding.account_id == ExternalChatAccount.id,
|
||||
)
|
||||
.where(ExternalChatBinding.user_id == user.id)
|
||||
)
|
||||
return [
|
||||
|
|
@ -777,13 +819,20 @@ async def list_connections(
|
|||
]
|
||||
if platform is not None:
|
||||
filters.append(ExternalChatAccount.platform == platform)
|
||||
if platform == ExternalChatPlatform.WHATSAPP and active_whatsapp_mode is not None:
|
||||
if (
|
||||
platform == ExternalChatPlatform.WHATSAPP
|
||||
and active_whatsapp_mode is not None
|
||||
):
|
||||
filters.append(ExternalChatAccount.mode == active_whatsapp_mode)
|
||||
else:
|
||||
if not _telegram_gateway_enabled():
|
||||
filters.append(ExternalChatAccount.platform != ExternalChatPlatform.TELEGRAM)
|
||||
filters.append(
|
||||
ExternalChatAccount.platform != ExternalChatPlatform.TELEGRAM
|
||||
)
|
||||
if active_whatsapp_mode is None:
|
||||
filters.append(ExternalChatAccount.platform != ExternalChatPlatform.WHATSAPP)
|
||||
filters.append(
|
||||
ExternalChatAccount.platform != ExternalChatPlatform.WHATSAPP
|
||||
)
|
||||
else:
|
||||
filters.append(
|
||||
or_(
|
||||
|
|
@ -794,7 +843,10 @@ async def list_connections(
|
|||
|
||||
result = await session.execute(
|
||||
select(ExternalChatBinding, ExternalChatAccount)
|
||||
.join(ExternalChatAccount, ExternalChatBinding.account_id == ExternalChatAccount.id)
|
||||
.join(
|
||||
ExternalChatAccount,
|
||||
ExternalChatBinding.account_id == ExternalChatAccount.id,
|
||||
)
|
||||
.where(*filters)
|
||||
)
|
||||
|
||||
|
|
@ -828,7 +880,9 @@ async def list_connections(
|
|||
baileys_account_ids.add(int(account.id))
|
||||
route_type = "account"
|
||||
connection_id = account.id
|
||||
search_space_id = account.owner_search_space_id or binding.search_space_id
|
||||
search_space_id = (
|
||||
account.owner_search_space_id or binding.search_space_id
|
||||
)
|
||||
display_name = "WhatsApp Bridge"
|
||||
|
||||
connections.append(
|
||||
|
|
@ -853,9 +907,8 @@ async def list_connections(
|
|||
}
|
||||
)
|
||||
|
||||
if (
|
||||
active_whatsapp_mode == ExternalChatAccountMode.SELF_HOST_BYO
|
||||
and (platform is None or platform == ExternalChatPlatform.WHATSAPP)
|
||||
if active_whatsapp_mode == ExternalChatAccountMode.SELF_HOST_BYO and (
|
||||
platform is None or platform == ExternalChatPlatform.WHATSAPP
|
||||
):
|
||||
account_result = await session.execute(
|
||||
select(ExternalChatAccount).where(
|
||||
|
|
@ -940,7 +993,9 @@ async def update_binding_search_space(
|
|||
ExternalChatBindingState.BOUND,
|
||||
ExternalChatBindingState.SUSPENDED,
|
||||
}:
|
||||
raise HTTPException(status_code=400, detail="Only active bindings can be routed")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Only active bindings can be routed"
|
||||
)
|
||||
account = await session.get(ExternalChatAccount, binding.account_id)
|
||||
if account is None or _is_inactive_whatsapp_account(account):
|
||||
raise HTTPException(status_code=404, detail="Binding not found")
|
||||
|
|
@ -1062,4 +1117,3 @@ async def resume_external_chat_binding(
|
|||
binding.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,9 @@ class BaileysPairRequest(BaseModel):
|
|||
|
||||
def _ensure_baileys_enabled() -> None:
|
||||
if config.GATEWAY_WHATSAPP_INTAKE_MODE != "baileys":
|
||||
raise HTTPException(status_code=404, detail="WhatsApp Baileys gateway is disabled")
|
||||
raise HTTPException(
|
||||
status_code=404, detail="WhatsApp Baileys gateway is disabled"
|
||||
)
|
||||
if config.is_cloud():
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
|
|
|||
|
|
@ -79,7 +79,9 @@ async def whatsapp_webhook(
|
|||
|
||||
def _verify_signature(raw_body: bytes, header_signature: str | None) -> None:
|
||||
if not config.WHATSAPP_WEBHOOK_APP_SECRET:
|
||||
raise HTTPException(status_code=500, detail="WhatsApp app secret is not configured")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="WhatsApp app secret is not configured"
|
||||
)
|
||||
received = (header_signature or "").removeprefix("sha256=")
|
||||
expected = hmac.new(
|
||||
config.WHATSAPP_WEBHOOK_APP_SECRET.encode(),
|
||||
|
|
@ -87,7 +89,9 @@ def _verify_signature(raw_body: bytes, header_signature: str | None) -> None:
|
|||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
if not received or not hmac.compare_digest(received, expected):
|
||||
raise HTTPException(status_code=403, detail="Invalid WhatsApp webhook signature")
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Invalid WhatsApp webhook signature"
|
||||
)
|
||||
|
||||
|
||||
async def _process_payload(session: AsyncSession, payload: dict[str, Any]) -> None:
|
||||
|
|
@ -114,7 +118,9 @@ async def _process_messages_change(
|
|||
change: dict[str, Any],
|
||||
value: dict[str, Any],
|
||||
) -> None:
|
||||
statuses = [status for status in value.get("statuses") or [] if isinstance(status, dict)]
|
||||
statuses = [
|
||||
status for status in value.get("statuses") or [] if isinstance(status, dict)
|
||||
]
|
||||
for status in statuses:
|
||||
record_gateway_outbound(
|
||||
platform="whatsapp",
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from app.db import (
|
|||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.notifications.service import NotificationService
|
||||
from app.schemas.obsidian_plugin import (
|
||||
ALLOWED_ATTACHMENT_EXTENSIONS,
|
||||
ATTACHMENT_MIME_TYPES,
|
||||
|
|
@ -43,7 +44,6 @@ from app.schemas.obsidian_plugin import (
|
|||
SyncAckItem,
|
||||
SyncBatchRequest,
|
||||
)
|
||||
from app.notifications.service import NotificationService
|
||||
from app.services.obsidian_plugin_indexer import (
|
||||
delete_note,
|
||||
get_manifest,
|
||||
|
|
|
|||
|
|
@ -3,14 +3,14 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
from fake_useragent import UserAgent
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from scrapling.fetchers import AsyncFetcher
|
||||
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
from app.utils.proxy_config import get_requests_proxies
|
||||
from app.utils.proxy import get_proxy_url
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -69,26 +69,30 @@ async def _fetch_playlist_via_innertube(playlist_id: str) -> list[str]:
|
|||
"context": {"client": _INNERTUBE_CLIENT},
|
||||
"browseId": f"VL{playlist_id}",
|
||||
}
|
||||
proxies = get_requests_proxies()
|
||||
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(
|
||||
_INNERTUBE_API_URL,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
proxy=proxies["http"] if proxies else None,
|
||||
) as response,
|
||||
):
|
||||
if response.status != 200:
|
||||
logger.warning(
|
||||
"Innertube API returned %d for playlist %s",
|
||||
response.status,
|
||||
playlist_id,
|
||||
)
|
||||
return []
|
||||
data = await response.json()
|
||||
fetch_start = time.perf_counter()
|
||||
page = await AsyncFetcher.post(
|
||||
_INNERTUBE_API_URL,
|
||||
json=payload,
|
||||
proxy=get_proxy_url(),
|
||||
stealthy_headers=True,
|
||||
)
|
||||
fetch_ms = (time.perf_counter() - fetch_start) * 1000
|
||||
logger.info(
|
||||
"[youtube][perf] source=innertube playlist=%s status=%s fetch_ms=%.1f",
|
||||
playlist_id,
|
||||
page.status,
|
||||
fetch_ms,
|
||||
)
|
||||
if page.status != 200:
|
||||
logger.warning(
|
||||
"Innertube API returned %d for playlist %s",
|
||||
page.status,
|
||||
playlist_id,
|
||||
)
|
||||
return []
|
||||
data = page.json()
|
||||
|
||||
return _extract_playlist_video_ids(data)
|
||||
except Exception as e:
|
||||
|
|
@ -98,35 +102,38 @@ async def _fetch_playlist_via_innertube(playlist_id: str) -> list[str]:
|
|||
|
||||
async def _fetch_playlist_via_html(playlist_id: str) -> list[str]:
|
||||
"""Fallback: scrape playlist page HTML with consent cookies set."""
|
||||
ua = UserAgent()
|
||||
headers = {
|
||||
"User-Agent": ua.random,
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
}
|
||||
# Scrapling's stealthy_headers supplies a realistic User-Agent automatically.
|
||||
headers = {"Accept-Language": "en-US,en;q=0.9"}
|
||||
cookies = {
|
||||
"CONSENT": "PENDING+999",
|
||||
"SOCS": "CAISNQgDEitib3FfaWRlbnRpdHlmcm9udGVuZHVpc2VydmVyXzIwMjMwODI5LjA3X3AxGgJlbiADGgYIgOa_pgY",
|
||||
}
|
||||
proxies = get_requests_proxies()
|
||||
playlist_url = f"https://www.youtube.com/playlist?list={playlist_id}"
|
||||
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession(cookies=cookies) as session,
|
||||
session.get(
|
||||
playlist_url,
|
||||
headers=headers,
|
||||
proxy=proxies["http"] if proxies else None,
|
||||
) as response,
|
||||
):
|
||||
if response.status != 200:
|
||||
logger.warning(
|
||||
"HTML fallback returned %d for playlist %s",
|
||||
response.status,
|
||||
playlist_id,
|
||||
)
|
||||
return []
|
||||
html = await response.text()
|
||||
fetch_start = time.perf_counter()
|
||||
page = await AsyncFetcher.get(
|
||||
playlist_url,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
proxy=get_proxy_url(),
|
||||
stealthy_headers=True,
|
||||
)
|
||||
fetch_ms = (time.perf_counter() - fetch_start) * 1000
|
||||
logger.info(
|
||||
"[youtube][perf] source=html-fallback playlist=%s status=%s fetch_ms=%.1f",
|
||||
playlist_id,
|
||||
page.status,
|
||||
fetch_ms,
|
||||
)
|
||||
if page.status != 200:
|
||||
logger.warning(
|
||||
"HTML fallback returned %d for playlist %s",
|
||||
page.status,
|
||||
playlist_id,
|
||||
)
|
||||
return []
|
||||
html = page.html_content
|
||||
|
||||
yt_data = _extract_yt_initial_data(html)
|
||||
if not yt_data:
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from app.db import (
|
|||
User,
|
||||
has_permission,
|
||||
)
|
||||
from app.notifications.service import NotificationService
|
||||
from app.schemas.chat_comments import (
|
||||
AuthorResponse,
|
||||
CommentBatchResponse,
|
||||
|
|
@ -31,7 +32,6 @@ from app.schemas.chat_comments import (
|
|||
MentionListResponse,
|
||||
MentionResponse,
|
||||
)
|
||||
from app.notifications.service import NotificationService
|
||||
from app.utils.chat_comments import parse_mentions, render_mentions
|
||||
from app.utils.rbac import check_permission, get_user_permissions
|
||||
|
||||
|
|
|
|||
|
|
@ -64,9 +64,6 @@ class ConfluenceKBSyncService:
|
|||
if dup:
|
||||
content_hash = unique_hash
|
||||
|
||||
|
||||
|
||||
|
||||
summary_content = f"Confluence Page: {page_title}\n\n{page_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
|
|
@ -166,8 +163,6 @@ class ConfluenceKBSyncService:
|
|||
|
||||
space_id = (document.document_metadata or {}).get("space_id", "")
|
||||
|
||||
|
||||
|
||||
summary_content = f"Confluence Page: {page_title}\n\n{page_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from threading import Lock
|
||||
|
|
@ -12,6 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from tavily import TavilyClient
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
NATIVE_TO_LEGACY_DOCTYPE,
|
||||
Chunk,
|
||||
|
|
@ -2856,9 +2856,7 @@ class ConnectorService:
|
|||
# bounded and the alternative (cross-replica fanout) is not worth the
|
||||
# coupling here.
|
||||
|
||||
_DISCOVERY_TTL_SECONDS: float = float(
|
||||
os.getenv("SURFSENSE_CONNECTOR_DISCOVERY_TTL_SECONDS", "30")
|
||||
)
|
||||
_DISCOVERY_TTL_SECONDS: float = config.CONNECTOR_DISCOVERY_TTL_SECONDS
|
||||
|
||||
# Per-search-space caches. Keyed by ``search_space_id``; value is
|
||||
# ``(expires_at_monotonic, payload)``. Plain dicts protected by a lock —
|
||||
|
|
|
|||
|
|
@ -71,9 +71,6 @@ class DropboxKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
|
||||
|
||||
|
||||
summary_content = f"Dropbox File: {file_name}\n\n{indexable_content}"
|
||||
summary_embedding = embed_text(summary_content)
|
||||
|
||||
|
|
|
|||
|
|
@ -77,9 +77,6 @@ class GmailKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
|
||||
|
||||
|
||||
summary_content = f"Gmail Message: {subject}\n\n{indexable_content}"
|
||||
summary_embedding = await asyncio.to_thread(embed_text, summary_content)
|
||||
|
||||
|
|
|
|||
|
|
@ -89,9 +89,6 @@ class GoogleCalendarKBSyncService:
|
|||
)
|
||||
content_hash = unique_hash
|
||||
|
||||
|
||||
|
||||
|
||||
summary_content = (
|
||||
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
||||
)
|
||||
|
|
@ -252,9 +249,6 @@ class GoogleCalendarKBSyncService:
|
|||
if not indexable_content:
|
||||
return {"status": "error", "message": "Event produced empty content"}
|
||||
|
||||
|
||||
|
||||
|
||||
summary_content = (
|
||||
f"Google Calendar Event: {event_summary}\n\n{indexable_content}"
|
||||
)
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue