Merge remote-tracking branch 'upstream/dev' into feat/ui-revamp

This commit is contained in:
Anish Sarkar 2026-05-02 12:54:49 +05:30
commit 9b1b5a504e
148 changed files with 19460 additions and 2708 deletions

View file

@ -136,6 +136,19 @@ jobs:
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}
AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }}
# macOS Developer ID signing + notarization. Only the macos-latest runner
# consumes these; Windows/Linux runners ignore them. CSC_LINK accepts either
# a file path or a base64-encoded .p12 blob — electron-builder auto-detects.
CSC_LINK: ${{ secrets.MAC_CERT_P12_BASE64 }}
CSC_KEY_PASSWORD: ${{ secrets.MAC_CERT_PASSWORD }}
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
# TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed.
# Surfaces the exact codesign / notarize commands electron-builder spawns,
# so we can see which subprocess hangs.
DEBUG: electron-builder,electron-osx-sign*,@electron/notarize*
ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true"
# Service principal credentials for Azure.Identity EnvironmentCredential used by the
# TrustedSigning PowerShell module. Only populated when signing is enabled.
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing,

60
.github/workflows/notary-status.yml vendored Normal file
View file

@ -0,0 +1,60 @@
name: Notary status check
# One-off diagnostic workflow. Queries Apple's notary service to see if your
# submissions are queued, in progress, accepted, or rejected. Useful when a
# notarization seems "hung" — most often the queue itself, especially on a
# brand-new Apple Developer account.
#
# Run via: Actions tab -> "Notary status check" -> Run workflow.
# Inputs are optional; if you provide a submission ID, it also fetches that
# submission's full Apple log.
#
# Safe to delete after diagnosis.
on:
workflow_dispatch:
inputs:
submission_id:
description: 'Optional: submission UUID to fetch full Apple log for'
required: false
default: ''
jobs:
status:
runs-on: macos-latest
steps:
- name: List recent notarization submissions
env:
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
run: |
set -euo pipefail
echo "::group::Submission history (most recent first)"
xcrun notarytool history \
--apple-id "$APPLE_ID" \
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
--team-id "$APPLE_TEAM_ID"
echo "::endgroup::"
- name: Inspect specific submission (if id provided)
if: ${{ inputs.submission_id != '' }}
env:
APPLE_ID: ${{ secrets.APPLE_ID }}
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
SUBMISSION_ID: ${{ inputs.submission_id }}
run: |
set -euo pipefail
echo "::group::Submission info"
xcrun notarytool info "$SUBMISSION_ID" \
--apple-id "$APPLE_ID" \
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
--team-id "$APPLE_TEAM_ID"
echo "::endgroup::"
echo "::group::Apple's processing log for this submission"
xcrun notarytool log "$SUBMISSION_ID" \
--apple-id "$APPLE_ID" \
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
--team-id "$APPLE_TEAM_ID" || true
echo "::endgroup::"

View file

@ -282,6 +282,14 @@ LANGSMITH_PROJECT=surfsense
# SURFSENSE_ENABLE_ACTION_LOG=false
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
# content (typed reasoning blocks, tool-input deltas) and propagate the
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
# Schema migrations 135/136 ship unconditionally because they are
# forward-compatible.
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
# Plugins
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
# Comma-separated allowlist of plugin entry-point names

View file

@ -0,0 +1,139 @@
"""134_relax_revision_fks
Revision ID: 134
Revises: 133
Create Date: 2026-04-29
Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so
revisions survive the deletes they describe.
Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE
hard-deleting a document via the ``rm`` tool (and likewise a
``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE``,
the snapshot row is wiped at the exact moment we need it most — revert
then has nothing to read and the operation becomes irreversible.
Migration:
* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK
``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``.
* ``folder_revisions.folder_id``: same treatment.
The ``search_space_id`` FK on both tables is left unchanged (still
``ON DELETE CASCADE``). When a search space is deleted, all documents,
folders, AND their revisions go together — that's the correct teardown
story.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy import inspect
from alembic import op
revision: str = "134"
down_revision: str | None = "133"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def _fk_name(bind, table: str, column: str) -> str | None:
    """Look up the name of the foreign key constraining exactly ``table.column``.

    Only foreign keys whose constrained-column list equals ``[column]`` are
    considered, and the first such key wins. Returns ``None`` when no key
    matches (or when the matching key is reflected without a name).
    """
    wanted = [column]
    reflected_fks = inspect(bind).get_foreign_keys(table)
    return next(
        (
            fk.get("name")
            for fk in reflected_fks
            if (fk.get("constrained_columns") or []) == wanted
        ),
        None,
    )
def upgrade() -> None:
    """Make both revision parent columns nullable with ``ON DELETE SET NULL``.

    For each revision table, in order: drop the currently reflected FK on
    the parent column (if one is found), make the column nullable, then
    create a FK under a fixed name that sets the column to NULL on delete.
    """
    bind = op.get_bind()

    # (revision table, parent column, referenced table, name of the new FK)
    targets = (
        (
            "document_revisions",
            "document_id",
            "documents",
            "document_revisions_document_id_fkey",
        ),
        (
            "folder_revisions",
            "folder_id",
            "folders",
            "folder_revisions_folder_id_fkey",
        ),
    )

    for table, column, parent, new_fk in targets:
        existing_fk = _fk_name(bind, table, column)
        if existing_fk:
            op.drop_constraint(existing_fk, table, type_="foreignkey")
        op.alter_column(
            table,
            column,
            existing_type=sa.Integer(),
            nullable=True,
        )
        op.create_foreign_key(
            new_fk,
            table,
            parent,
            [column],
            ["id"],
            ondelete="SET NULL",
        )
def downgrade() -> None:
    """Restore ``NOT NULL`` + ``ON DELETE CASCADE`` on both revision parent FKs.

    Destructive: revision rows whose parent reference is already NULL are
    deleted up front, because they cannot satisfy the reinstated NOT NULL.
    """
    bind = op.get_bind()

    # Drain orphan revisions (parent doc/folder already deleted) before the
    # columns become NOT NULL again.
    for orphan_cleanup in (
        "DELETE FROM document_revisions WHERE document_id IS NULL",
        "DELETE FROM folder_revisions WHERE folder_id IS NULL",
    ):
        op.execute(orphan_cleanup)

    # (revision table, parent column, referenced table, name of the new FK)
    targets = (
        (
            "document_revisions",
            "document_id",
            "documents",
            "document_revisions_document_id_fkey",
        ),
        (
            "folder_revisions",
            "folder_id",
            "folders",
            "folder_revisions_folder_id_fkey",
        ),
    )

    for table, column, parent, new_fk in targets:
        existing_fk = _fk_name(bind, table, column)
        if existing_fk:
            op.drop_constraint(existing_fk, table, type_="foreignkey")
        op.alter_column(
            table,
            column,
            existing_type=sa.Integer(),
            nullable=False,
        )
        op.create_foreign_key(
            new_fk,
            table,
            parent,
            [column],
            ["id"],
            ondelete="CASCADE",
        )

View file

@ -0,0 +1,82 @@
"""135_action_log_correlation_ids
Revision ID: 135
Revises: 134
Create Date: 2026-04-29
Action-log correlation-id cleanup.
Background
----------
``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes
the LangChain ``tool_call.id`` into that column today (see
``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch``
joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from
``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was
never persisted.
This migration introduces two new, correctly-named columns:
* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held)
* ``chat_turn_id`` (the per-turn correlation id from
``configurable.turn_id`` used by the per-turn ``revert-turn`` route).
Backfill copies the current ``turn_id`` values into ``tool_call_id`` so
existing joins keep working. The old ``turn_id`` column is left in place
for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware``
keeps writing it (= ``tool_call_id``) for the same reason.
Indexes
-------
* ``ix_agent_action_log_tool_call_id`` — required by
``_find_action_ids_batch`` (was on ``turn_id``).
* ``ix_agent_action_log_chat_turn_id`` — required by the
``revert-turn/{chat_turn_id}`` query.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "135"
down_revision: str | None = "134"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add ``tool_call_id`` / ``chat_turn_id`` to ``agent_action_log``.

    Steps, in order:

    1. Add both columns as nullable ``VARCHAR(64)``.
    2. Backfill ``tool_call_id`` from the existing ``turn_id`` column for
       every row where ``tool_call_id`` is still NULL.
    3. Create one index per new column.
    """
    op.add_column(
        "agent_action_log",
        sa.Column("tool_call_id", sa.String(length=64), nullable=True),
    )
    op.add_column(
        "agent_action_log",
        sa.Column("chat_turn_id", sa.String(length=64), nullable=True),
    )

    # Backfill BEFORE building the indexes: a bulk UPDATE against an
    # already-indexed column has to maintain the index row by row, whereas
    # building the index once over the populated column is a single pass.
    # The resulting schema and data are the same either way.
    op.execute(
        "UPDATE agent_action_log SET tool_call_id = turn_id WHERE tool_call_id IS NULL"
    )

    op.create_index(
        "ix_agent_action_log_tool_call_id",
        "agent_action_log",
        ["tool_call_id"],
    )
    op.create_index(
        "ix_agent_action_log_chat_turn_id",
        "agent_action_log",
        ["chat_turn_id"],
    )
def downgrade() -> None:
    """Drop the two correlation-id indexes, then the columns they cover."""
    table = "agent_action_log"

    # Both indexes are removed before either column is dropped.
    for index_name in (
        "ix_agent_action_log_chat_turn_id",
        "ix_agent_action_log_tool_call_id",
    ):
        op.drop_index(index_name, table_name=table)

    for column_name in ("chat_turn_id", "tool_call_id"):
        op.drop_column(table, column_name)

View file

@ -0,0 +1,52 @@
"""136_new_chat_message_turn_id
Revision ID: 136
Revises: 135
Create Date: 2026-04-29
Persist the per-turn correlation id on each chat message.
Background
----------
LangGraph's checkpointer stores user-provided ``configurable.turn_id``
in checkpoint metadata (see
``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To
support edit-from-arbitrary-position, the regenerate route needs to map
a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without
this column the mapping doesn't exist anywhere, so regenerate would
have to hardcode the "last 2 messages" rewind heuristic.
This migration adds a nullable ``turn_id`` column to ``new_chat_messages``
plus an index. Legacy rows have NULL — the regenerate route degrades
gracefully to the reload-last-two heuristic for those.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "136"
down_revision: str | None = "135"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add a nullable, indexed ``turn_id`` column to ``new_chat_messages``."""
    table = "new_chat_messages"
    turn_id_column = sa.Column("turn_id", sa.String(length=64), nullable=True)

    op.add_column(table, turn_id_column)
    op.create_index("ix_new_chat_messages_turn_id", table, ["turn_id"])
def downgrade() -> None:
    """Remove the ``turn_id`` index and then the column itself."""
    table = "new_chat_messages"
    op.drop_index("ix_new_chat_messages_turn_id", table_name=table)
    op.drop_column(table, "turn_id")

View file

@ -0,0 +1,74 @@
"""137_unique_reverse_of_in_action_log
Revision ID: 137
Revises: 136
Create Date: 2026-04-29
Protect ``agent_action_log.reverse_of`` against double inserts. Two
concurrent revert calls (single-action route + the per-turn batch
route, or two batch routes racing) both pass the
``_was_already_reverted`` SELECT and both insert their own
``_revert:*`` rows. The application-level idempotency check is racy
because there's no DB constraint backing it.
This migration adds a partial unique index on ``reverse_of`` (PostgreSQL
``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises
``IntegrityError`` and the route can translate it to ``"already_reverted"``
deterministically.
A plain ``UniqueConstraint`` is not used, even though it would be valid:
most existing rows have ``reverse_of = NULL`` (only revert rows fill it),
and Postgres treats NULLs as distinct in unique indexes by default. A
partial index is still the cleanest expression of intent, and it behaves
the same regardless of how a given Postgres release handles NULLs in
unique indexes.
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
revision: str = "137"
down_revision: str | None = "136"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
_INDEX_NAME = "ux_agent_action_log_reverse_of"
def upgrade() -> None:
    """De-duplicate ``reverse_of`` and add the partial unique index on it.

    Pre-existing duplicate revert rows are not deleted. Within each group
    sharing a ``reverse_of`` value, the row with the smallest ``id`` keeps
    the value and every other row has ``reverse_of`` set to NULL, so the
    unique index can be created.
    """
    # The SQL text is intentionally flush-left so the statement sent to the
    # database is unchanged. Rows ranked > 1 inside their ``reverse_of``
    # partition (ordered by id ascending) are the duplicates to neutralise;
    # they stay in the table as audit trail.
    dedupe_sql = """
WITH dups AS (
SELECT id,
reverse_of,
ROW_NUMBER() OVER (
PARTITION BY reverse_of ORDER BY id ASC
) AS rn
FROM agent_action_log
WHERE reverse_of IS NOT NULL
)
UPDATE agent_action_log
SET reverse_of = NULL
WHERE id IN (SELECT id FROM dups WHERE rn > 1)
"""
    op.execute(dedupe_sql)

    # Unique only over rows that actually reference another action.
    op.create_index(
        _INDEX_NAME,
        "agent_action_log",
        ["reverse_of"],
        unique=True,
        postgresql_where="reverse_of IS NOT NULL",
    )
def downgrade() -> None:
    """Drop the partial unique index on ``agent_action_log.reverse_of``.

    Rows whose ``reverse_of`` was NULLed during upgrade are not restored.
    """
    table = "agent_action_log"
    op.drop_index(_INDEX_NAME, table_name=table)

View file

@ -0,0 +1,44 @@
"""138_add_thread_auto_model_pinning_fields
Revision ID: 138
Revises: 137
Create Date: 2026-04-30
Add a single thread-level column to persist the Auto (Fastest) model pin:
- pinned_llm_config_id: concrete resolved global LLM config id used for this
thread. NULL means "no pin; Auto will resolve on next turn".
The column is unindexed: all reads are by new_chat_threads.id (primary key),
so a secondary index would be dead write amplification.
"""
from __future__ import annotations
from collections.abc import Sequence
from alembic import op
revision: str = "138"
down_revision: str | None = "137"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add the nullable ``pinned_llm_config_id`` column to ``new_chat_threads``.

    ``IF NOT EXISTS`` makes the statement a no-op when the column is
    already present.
    """
    add_column_sql = (
        "ALTER TABLE new_chat_threads ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
    )
    op.execute(add_column_sql)
def downgrade() -> None:
    """Remove every pinning-related index and column from ``new_chat_threads``.

    Besides ``pinned_llm_config_id`` this also clears the extra indexes and
    columns that only exist on dev databases that ran an earlier draft of
    138. Every statement uses ``IF EXISTS``, so each is a safe no-op on the
    lean shape.
    """
    # Indexes first, then the columns; order matches the original teardown.
    teardown_statements = (
        "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode",
        "DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id",
    )
    for statement in teardown_statements:
        op.execute(statement)

View file

@ -0,0 +1,160 @@
"""add user table to zero_publication with column list
Adds the "user" table to zero_publication with a column-list publication
so that only the 5 fields driving the live usage meters are replicated
through WAL -> zero-cache -> browser IndexedDB:
id, pages_limit, pages_used,
premium_tokens_limit, premium_tokens_used
Sensitive columns (hashed_password, email, oauth_account, display_name,
avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
included in the publication, so they never enter WAL replication.
Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
(it is already DEFAULT today since "user" was never in the
TABLES_WITH_FULL_IDENTITY list of migration 117).
IMPORTANT - before AND after running this migration:
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
2. Run: alembic upgrade head
3. Delete / reset the zero-cache data volume
4. Restart zero-cache (it will do a fresh initial sync)
Revision ID: 139
Revises: 138
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "139"
down_revision: str | None = "138"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
PUBLICATION_NAME = "zero_publication"
# Document column list as left by migration 117. Must match exactly.
DOCUMENT_COLS = [
"id",
"title",
"document_type",
"search_space_id",
"folder_id",
"created_by_id",
"status",
"created_at",
"updated_at",
]
# Five fields needed by the live usage meters (sidebar Tokens/Pages,
# Buy Tokens content). Keep this list narrow on purpose: anything added
# here flows into WAL and IndexedDB for every connected browser.
USER_COLS = [
"id",
"pages_limit",
"pages_used",
"premium_tokens_limit",
"premium_tokens_used",
]
def _terminate_blocked_pids(conn, table: str) -> None:
    """Terminate other backends holding locks on *table* in this database.

    Calls ``pg_terminate_backend`` for every ``pg_locks`` row that targets a
    relation named *table*, excluding the current backend, so a following
    ``LOCK TABLE ... IN ACCESS EXCLUSIVE MODE`` is not blocked by them.

    Args:
        conn: Connection the query is executed on.
        table: Unquoted relation name, matched against ``pg_class.relname``.
    """
    conn.execute(
        sa.text(
            "SELECT pg_terminate_backend(l.pid) "
            "FROM pg_locks l "
            "JOIN pg_class c ON c.oid = l.relation "
            "WHERE c.relname = :tbl "
            # pg_locks is cluster-wide while pg_class is per-database, so
            # without this filter a lock on an unrelated relation in another
            # database whose OID collides would get its backend terminated.
            "  AND l.database = ("
            "SELECT d.oid FROM pg_database d WHERE d.datname = current_database()"
            ") "
            "  AND l.pid != pg_backend_pid()"
        ),
        {"tbl": table},
    )
def _has_zero_version(conn, table: str) -> bool:
    """Report whether *table* has a ``_0_version`` column.

    The lookup goes through ``information_schema.columns`` and matches on
    the table name only (it is not schema-qualified).
    """
    probe = sa.text(
        "SELECT 1 FROM information_schema.columns "
        "WHERE table_name = :tbl AND column_name = '_0_version'"
    )
    first_row = conn.execute(probe, {"tbl": table}).fetchone()
    return first_row is not None
def _build_publication_ddl(
    documents_has_zero_ver: bool, user_has_zero_ver: bool
) -> str:
    """Build the ``CREATE PUBLICATION`` statement that includes ``"user"``.

    ``documents`` and ``"user"`` are published with explicit column lists
    (``DOCUMENT_COLS`` / ``USER_COLS``); the quoted ``"_0_version"`` column
    is appended to a list only when the corresponding flag is true. All
    other tables are published without a column list.
    """
    zero_version = ['"_0_version"']
    document_columns = ", ".join(
        DOCUMENT_COLS + (zero_version if documents_has_zero_ver else [])
    )
    user_columns = ", ".join(USER_COLS + (zero_version if user_has_zero_ver else []))

    published_tables = [
        "notifications",
        f"documents ({document_columns})",
        "folders",
        "search_source_connectors",
        "new_chat_messages",
        "chat_comments",
        "chat_session_state",
        f'"user" ({user_columns})',
    ]
    return f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + ", ".join(
        published_tables
    )
def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
    """Build the ``CREATE PUBLICATION`` statement that omits ``"user"``.

    Used by the downgrade path. ``documents`` keeps its explicit column list
    (``DOCUMENT_COLS``, plus the quoted ``"_0_version"`` column when the
    flag is true); all other tables are published without a column list.
    """
    extra_columns = ['"_0_version"'] if documents_has_zero_ver else []
    document_columns = ", ".join(DOCUMENT_COLS + extra_columns)

    published_tables = [
        "notifications",
        f"documents ({document_columns})",
        "folders",
        "search_source_connectors",
        "new_chat_messages",
        "chat_comments",
        "chat_session_state",
    ]
    return f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE " + ", ".join(
        published_tables
    )
def upgrade() -> None:
    """Recreate the publication with the ``"user"`` column list included.

    Inside one (nested) transaction: set a 10s lock timeout, terminate
    other backends holding locks on ``"user"``, take an ACCESS EXCLUSIVE
    lock on it, re-assert ``REPLICA IDENTITY DEFAULT``, drop the existing
    publication, and create it again with ``"user"`` added.
    """
    conn = op.get_bind()

    # asyncpg requires LOCK TABLE inside a transaction block. Alembic already
    # opened one via context.begin_transaction(), but the driver still errors
    # unless we use an explicit SAVEPOINT (nested transaction) for this block.
    tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
    with tx:
        # SET LOCAL keeps the timeout scoped to the enclosing transaction.
        # A plain SET is session-level and would keep applying to whatever
        # else runs on this connection after the transaction commits.
        conn.execute(sa.text("SET LOCAL lock_timeout = '10s'"))

        _terminate_blocked_pids(conn, "user")
        conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))

        # Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
        # migration 117, so this is already DEFAULT. Re-assert anyway so
        # the column-list publication stays valid (DEFAULT identity only
        # requires the PK to be in the column list).
        conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))

        conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))

        # Column lists only mention "_0_version" where that column exists.
        documents_has_zero_ver = _has_zero_version(conn, "documents")
        user_has_zero_ver = _has_zero_version(conn, "user")
        conn.execute(
            sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver))
        )
def downgrade() -> None:
    """Recreate the publication without the ``"user"`` table.

    Drops the existing publication (if any) and creates it again from the
    table list that leaves ``"user"`` out.
    """
    conn = op.get_bind()

    drop_stmt = sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}")
    conn.execute(drop_stmt)

    create_ddl = _build_publication_ddl_without_user(
        _has_zero_version(conn, "documents")
    )
    conn.execute(sa.text(create_ddl))

View file

@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
subclass of the default ``FilesystemMiddleware`` while preserving every
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
summarisation, prompt-caching, etc.).
summarisation, etc.). Prompt caching is configured at LLM-build time via
``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather
than as a middleware.
"""
import asyncio
@ -33,7 +35,6 @@ from langchain.agents.middleware import (
TodoListMiddleware,
ToolCallLimitMiddleware,
)
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
@ -74,6 +75,7 @@ from app.agents.new_chat.plugin_loader import (
load_allowed_plugin_names_from_env,
load_plugin_middlewares,
)
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.agents.new_chat.subagents import build_specialized_subagents
from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt,
@ -94,6 +96,39 @@ from app.utils.perf import get_perf_logger
_perf_log = get_perf_logger()
def _resolve_prompt_model_name(
agent_config: AgentConfig | None,
llm: BaseChatModel,
) -> str | None:
"""Resolve the model id to feed to provider-variant detection.
Preference order (matches the established idiom in
``llm_router_service.py`` see ``params.get("base_model") or
params.get("model", "")`` usages there):
1. ``agent_config.litellm_params["base_model"]`` required for Azure
deployments where ``model_name`` is the deployment slug, not the
underlying family. Without this, a deployment named e.g.
``"prod-chat-001"`` would silently miss every provider regex.
2. ``agent_config.model_name`` the user's configured model id.
3. ``getattr(llm, "model", None)`` fallback for direct callers that
don't supply an ``AgentConfig`` (currently a defensive path; all
production callers pass ``agent_config``).
Returns ``None`` when nothing is available; ``compose_system_prompt``
treats that as the ``"default"`` variant (no provider block emitted).
"""
if agent_config is not None:
params = agent_config.litellm_params or {}
base_model = params.get("base_model")
if isinstance(base_model, str) and base_model.strip():
return base_model
if agent_config.model_name:
return agent_config.model_name
return getattr(llm, "model", None)
# =============================================================================
# Connector Type Mapping
# =============================================================================
@ -279,6 +314,14 @@ async def create_surfsense_deep_agent(
)
"""
_t_agent_total = time.perf_counter()
# Layer thread-aware prompt caching onto the LLM. Idempotent with the
# build-time call in ``llm_config.py``; this run merely adds
# ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family
# configs now that ``thread_id`` is known. No-op when ``thread_id`` is
# None or the provider is non-OpenAI-family.
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
filesystem_selection = filesystem_selection or FilesystemSelection()
backend_resolver = build_backend_resolver(
filesystem_selection,
@ -398,6 +441,7 @@ async def create_surfsense_deep_agent(
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
model_name=_resolve_prompt_model_name(agent_config, llm),
)
else:
system_prompt = build_surfsense_system_prompt(
@ -405,6 +449,7 @@ async def create_surfsense_deep_agent(
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
mcp_connector_tools=_mcp_connector_tools,
model_name=_resolve_prompt_model_name(agent_config, llm),
)
_perf_log.info(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
@ -568,7 +613,6 @@ def _build_compiled_agent_blocking(
),
create_surfsense_compaction_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
@ -724,7 +768,8 @@ def _build_compiled_agent_blocking(
repair_mw = None
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
registered_names: set[str] = {t.name for t in tools}
# Tools owned by the standard deepagents middleware stack.
# Tools owned by the standard deepagents middleware stack and the
# SurfSense filesystem extension.
registered_names |= {
"write_todos",
"ls",
@ -735,6 +780,14 @@ def _build_compiled_agent_blocking(
"grep",
"execute",
"task",
"mkdir",
"cd",
"pwd",
"move_file",
"rm",
"rmdir",
"list_tree",
"execute_code",
}
repair_mw = ToolCallNameRepairMiddleware(
registered_tool_names=registered_names,
@ -763,25 +816,51 @@ def _build_compiled_agent_blocking(
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
# ``glob``, ``web_search`` …) and, on resume, replay the previous
# reject decision into innocent calls.
# 2. ``connector_synthesized`` — deny rules for tools whose required
# connector is not connected to this space. Overrides #1.
# 3. (future) user-defined rules from ``agent_permission_rules`` table
# via the Agent Permissions UI. Loaded last so they override both.
# 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when
# the agent is operating against the user's real disk. Cloud mode
# has full revision-based revert via ``revert_service``, but
# desktop mode hits disk immediately with no undo, so an
# accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` /
# ``write_file`` is unrecoverable. This layer is forced on in
# desktop mode regardless of ``enable_permission`` because the
# safety net is non-negotiable.
# 3. ``connector_synthesized`` — deny rules for tools whose required
# connector is not connected to this space. Overrides #1/#2.
# 4. (future) user-defined rules from ``agent_permission_rules`` table
# via the Agent Permissions UI. Loaded last so they override all.
permission_mw: PermissionMiddleware | None = None
if flags.enable_permission and not flags.disable_new_agent_stack:
synthesized = _synthesize_connector_deny_rules(
available_connectors=available_connectors,
enabled_tool_names={t.name for t in tools},
)
permission_mw = PermissionMiddleware(
rulesets=[
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
# Build the middleware whenever it has work to do: either the user
# opted into the rule engine, OR we're in desktop mode and need the
# safety rules unconditionally.
if permission_enabled or is_desktop_fs:
rulesets: list[Ruleset] = [
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
),
]
if is_desktop_fs:
rulesets.append(
Ruleset(
rules=[Rule(permission="*", pattern="*", action="allow")],
origin="surfsense_defaults",
),
Ruleset(rules=synthesized, origin="connector_synthesized"),
],
)
rules=[
Rule(permission="rm", pattern="*", action="ask"),
Rule(permission="rmdir", pattern="*", action="ask"),
Rule(permission="move_file", pattern="*", action="ask"),
Rule(permission="edit_file", pattern="*", action="ask"),
Rule(permission="write_file", pattern="*", action="ask"),
],
origin="desktop_safety",
)
)
if permission_enabled:
synthesized = _synthesize_connector_deny_rules(
available_connectors=available_connectors,
enabled_tool_names={t.name for t in tools},
)
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
permission_mw = PermissionMiddleware(rulesets=rulesets)
# ActionLogMiddleware. Off by default until the ``agent_action_log``
# table is migrated. When enabled, persists one row per tool call
@ -938,6 +1017,7 @@ def _build_compiled_agent_blocking(
search_space_id=search_space_id,
created_by_id=user_id,
filesystem_mode=filesystem_mode,
thread_id=thread_id,
)
if filesystem_mode == FilesystemMode.CLOUD
else None,
@ -970,12 +1050,12 @@ def _build_compiled_agent_blocking(
action_log_mw,
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
# Plugin slot — sits just before AnthropicCache so plugin-side
# transforms see the final tool result and run before any
# caching heuristics. Multiple plugins in declared order; loader
# filtered by the admin allowlist already.
# Plugin slot — sits at the tail so plugin-side transforms see the
# final tool result. Prompt caching is now applied at LLM build time
# via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no
# caching middleware is needed here. Multiple plugins run in declared
# order; loader filtered by the admin allowlist already.
*plugin_middlewares,
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
deepagent_middleware = [m for m in deepagent_middleware if m is not None]

View file

@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
Master kill-switch (overrides everything else):
@ -86,6 +87,15 @@ class AgentFeatureFlags:
False # Backend ships before UI; route returns 503 until this flips
)
# Streaming parity v2 — opt in to LangChain's structured
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
# deltas) and propagate the real ``tool_call_id`` to the SSE layer.
# When OFF the ``stream_new_chat`` task falls back to the str-only
# text path and the synthetic ``call_<run_id>`` tool-call id (no
# ``langchainToolCallId`` propagation). Schema migrations 135/136
# ship unconditionally because they're forward-compatible.
enable_stream_parity_v2: bool = False
# Plugins
enable_plugin_loader: bool = False
@ -139,6 +149,10 @@ class AgentFeatureFlags:
# Snapshot / revert
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
# Streaming parity v2
enable_stream_parity_v2=_env_bool(
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
),
# Plugins
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
# Observability

View file

@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics:
* ``cwd`` — current working directory (per-thread checkpointed).
* ``staged_dirs`` — pending mkdir requests (cloud only).
* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs.
* ``pending_moves`` — pending move_file requests (cloud only).
* ``pending_deletes`` — pending ``rm`` requests (cloud only).
* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only).
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
* ``dirty_paths`` — paths whose state file content differs from DB.
* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for
dirty paths; used to bind the per-path snapshot to an action_id.
* ``kb_priority`` — top-K priority hints rendered into a system message.
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import (
)
class PendingMove(TypedDict):
"""A staged move_file operation pending end-of-turn commit."""
class PendingMove(TypedDict, total=False):
    """A staged move_file operation pending end-of-turn commit.

    ``tool_call_id`` is optional for backward compatibility with checkpoints
    written before the snapshot/revert pipeline was wired up; new entries
    always include it so the persistence body can resolve an action_id.

    Because the class is declared with ``total=False``, every key is
    optional at the type level, not just ``tool_call_id``.
    """

    # Path being moved from (presumably a virtual path — confirm with callers).
    source: str
    # Path being moved to.
    dest: str
    # Presumably whether an existing entry at ``dest`` may be replaced — TODO confirm.
    overwrite: bool
    # Id of the tool call that staged this move; may be absent on old checkpoints.
    tool_call_id: str
class PendingDelete(TypedDict, total=False):
    """One queued ``rm`` or ``rmdir`` request awaiting the end-of-turn commit.

    New entries must carry ``tool_call_id``: it is the binding key that
    :class:`KnowledgeBasePersistenceMiddleware` uses to locate the matching
    :class:`AgentActionLog` row and attach the snapshot to it. The type is
    declared ``total=False`` solely so that older checkpoint payloads,
    which may lack the key, still type-check.
    """

    # Virtual path of the file (``rm``) or directory (``rmdir``) to remove.
    path: str
    # LangChain tool-call id of the call that staged this deletion.
    tool_call_id: str
class KbPriorityEntry(TypedDict, total=False):
@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState):
staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
"""mkdir paths staged for end-of-turn folder creation (cloud only)."""
staged_dir_tool_calls: NotRequired[
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
]
"""``path -> tool_call_id`` sidecar for ``staged_dirs``.
Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the
:class:`FolderRevision` snapshot to the originating ``mkdir`` action.
Kept separate from ``staged_dirs`` (which stays a unique-string list)
to avoid breaking ``_add_unique_reducer`` semantics.
"""
pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
"""move_file ops staged for end-of-turn commit (cloud only)."""
pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]]
"""``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only).
Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path
uniqueness is enforced inside the commit body, not the reducer (we keep
``tool_call_id`` per occurrence so snapshot binding works).
"""
pending_dir_deletes: NotRequired[
Annotated[list[PendingDelete], _list_append_reducer]
]
"""``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only).
Same shape as :data:`pending_deletes`. Commit body re-verifies the
folder is empty (in-DB AND with this turn's pending changes accounted
for) before issuing the DELETE.
"""
doc_id_by_path: NotRequired[
Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
]
@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState):
dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
"""Paths whose ``state["files"]`` content has been modified this turn."""
dirty_path_tool_calls: NotRequired[
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
]
"""``path -> latest tool_call_id`` sidecar for ``dirty_paths``.
The persistence body coalesces multiple writes/edits to the same path
into one snapshot per turn. This map captures the most-recent
``tool_call_id`` so the resulting :class:`DocumentRevision` is bound
to the latest action_id (the one the user is most likely to revert).
"""
kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
"""Top-K priority hints rendered as a system message before the user turn."""
@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState):
__all__ = [
"KbAnonDoc",
"KbPriorityEntry",
"PendingDelete",
"PendingMove",
"SurfSenseFilesystemState",
]

View file

@ -27,6 +27,7 @@ from litellm import get_model_info
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
from app.services.llm_router_service import (
AUTO_MODE_ID,
ChatLiteLLMRouter,
@ -494,6 +495,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
# Configure LiteLLM-native prompt caching (cache_control_injection_points
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
# ``agent_config=None`` here — the YAML path doesn't have provider intent
# in a structured form, so we set only the universal injection points.
apply_litellm_prompt_caching(llm)
return llm
@ -518,7 +524,16 @@ def create_chat_litellm_from_agent_config(
print("Error: Auto mode requested but LLM Router not initialized")
return None
try:
return get_auto_mode_llm()
router_llm = get_auto_mode_llm()
if router_llm is not None:
# Universal cache_control_injection_points only — auto-mode
# fans out across providers, so OpenAI-only kwargs (e.g.
# ``prompt_cache_key``) are left off here. ``drop_params``
# would strip them at the provider boundary anyway, but
# there's no point setting them when we don't know the
# destination.
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
return router_llm
except Exception as e:
print(f"Error creating ChatLiteLLMRouter: {e}")
return None
@ -549,4 +564,9 @@ def create_chat_litellm_from_agent_config(
llm = SanitizedChatLiteLLM(**litellm_kwargs)
_attach_model_profile(llm, model_string)
# Build-time prompt caching: sets ``cache_control_injection_points`` for
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
# Per-thread ``prompt_cache_key`` is layered on later in
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
apply_litellm_prompt_caching(llm, agent_config=agent_config)
return llm

View file

@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
from langchain.agents.middleware import AgentMiddleware
from langchain_core.callbacks import adispatch_custom_event
from langchain_core.messages import ToolMessage
from app.agents.new_chat.feature_flags import get_flags
@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware):
result=result,
)
tool_call_id = _resolve_tool_call_id(request)
chat_turn_id = _resolve_chat_turn_id(request)
row = AgentActionLog(
thread_id=self._thread_id,
user_id=self._user_id,
search_space_id=self._search_space_id,
turn_id=_resolve_turn_id(request),
# ``turn_id`` is the deprecated alias of ``tool_call_id``
# kept for one release for safe rollback. New consumers
# should read ``tool_call_id`` directly.
turn_id=tool_call_id,
tool_call_id=tool_call_id,
chat_turn_id=chat_turn_id,
message_id=_resolve_message_id(request),
tool_name=tool_name,
args=args_payload,
@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware):
async with shielded_async_session() as session:
session.add(row)
await session.commit()
row_id = int(row.id) if row.id is not None else None
row_created_at = row.created_at
except Exception:
logger.warning(
"ActionLogMiddleware failed to persist action log row",
exc_info=True,
)
return
# Surface a side-channel SSE event so the chat tool card can
# render a Revert button immediately after the row is durable.
# ``stream_new_chat`` translates this into a
# ``data-action-log`` SSE event. We DO NOT include the
# ``reverse_descriptor`` payload here; only a presence flag.
try:
await adispatch_custom_event(
"action_log",
{
"id": row_id,
"lc_tool_call_id": tool_call_id,
"chat_turn_id": chat_turn_id,
"tool_name": tool_name,
"reversible": bool(reversible),
"reverse_descriptor_present": reverse_descriptor is not None,
"created_at": row_created_at.isoformat()
if row_created_at
else None,
"error": error_payload is not None,
},
)
except Exception:
logger.debug(
"ActionLogMiddleware failed to dispatch action_log event",
exc_info=True,
)
def _render_reverse(
self,
@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
}
def _resolve_turn_id(request: Any) -> str | None:
def _resolve_tool_call_id(request: Any) -> str | None:
"""Return the LangChain ``tool_call.id`` for this request, if any."""
try:
call = getattr(request, "tool_call", None) or {}
if isinstance(call, dict):
@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None:
return None
# Deprecated alias kept for one release. Old callers and tests treated
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
# lives under ``tool_call_id``. Both resolve to the same value today.
_resolve_turn_id = _resolve_tool_call_id
def _resolve_chat_turn_id(request: Any) -> str | None:
"""Return ``configurable.turn_id`` for this request, if accessible.
``ToolRuntime.config`` is exposed by LangGraph (see
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
lives at ``runtime.config["configurable"]["turn_id"]``.
"""
try:
runtime = getattr(request, "runtime", None)
if runtime is None:
return None
config = getattr(runtime, "config", None)
if not isinstance(config, dict):
return None
configurable = config.get("configurable")
if not isinstance(configurable, dict):
return None
value = configurable.get("turn_id")
if isinstance(value, str) and value:
return value
except Exception: # pragma: no cover - defensive
pass
return None
def _resolve_message_id(request: Any) -> str | None:
    """Tool-call IDs serve as best-available message correlator at this layer.

    Delegates to ``_resolve_tool_call_id`` directly rather than going
    through the deprecated ``_resolve_turn_id`` alias.
    """
    return _resolve_tool_call_id(request)
def _resolve_result_id(result: Any) -> str | None:

View file

@ -33,6 +33,7 @@ from __future__ import annotations
import asyncio
import logging
import time
import weakref
from typing import Any
@ -58,6 +59,11 @@ class _ThreadLockManager:
weakref.WeakValueDictionary()
)
self._cancel_events: dict[str, asyncio.Event] = {}
self._cancel_requested_at_ms: dict[str, int] = {}
self._cancel_attempt_count: dict[str, int] = {}
# Monotonic per-thread epoch used to prevent stale middleware
# teardown from releasing a newer turn's lock.
self._turn_epoch: dict[str, int] = {}
def lock_for(self, thread_id: str) -> asyncio.Lock:
lock = self._locks.get(thread_id)
@ -76,14 +82,57 @@ class _ThreadLockManager:
def request_cancel(self, thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id`` and record the attempt.

    If no event exists for the thread yet, one is created on the spot, so
    a cancel request is recorded even when nothing has registered an
    event for that thread. Each call also stores the wall-clock request
    time in milliseconds and increments the per-thread attempt counter.

    Returns:
        Always ``True``.
    """
    event = self._cancel_events.get(thread_id)
    if event is None:
        # Lazily create the event instead of bailing out with False.
        event = asyncio.Event()
        self._cancel_events[thread_id] = event
    event.set()
    self._cancel_requested_at_ms[thread_id] = int(time.time() * 1000)
    self._cancel_attempt_count[thread_id] = (
        self._cancel_attempt_count.get(thread_id, 0) + 1
    )
    return True
def is_cancel_requested(self, thread_id: str) -> bool:
    """Report whether the thread's cancel event exists and is currently set."""
    pending = self._cancel_events.get(thread_id)
    if not pending:
        return False
    return bool(pending.is_set())
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
if not self.is_cancel_requested(thread_id):
return None
attempts = self._cancel_attempt_count.get(thread_id, 1)
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
return attempts, requested_at_ms
def reset(self, thread_id: str) -> None:
    """Clear the thread's cancel event (if any) and drop its cancel bookkeeping."""
    pending = self._cancel_events.get(thread_id)
    if pending is not None:
        pending.clear()
    # Forget the request timestamp and attempt counter for this thread.
    for bookkeeping in (self._cancel_requested_at_ms, self._cancel_attempt_count):
        bookkeeping.pop(thread_id, None)
def bump_turn_epoch(self, thread_id: str) -> int:
    """Advance the thread's turn epoch by one (starting from 0) and return it."""
    self._turn_epoch[thread_id] = self._turn_epoch.get(thread_id, 0) + 1
    return self._turn_epoch[thread_id]
def current_turn_epoch(self, thread_id: str) -> int:
    """Return the thread's current turn epoch, or 0 if it was never bumped."""
    epochs = self._turn_epoch
    return epochs[thread_id] if thread_id in epochs else 0
def end_turn(self, thread_id: str) -> None:
    """Best-effort terminal cleanup for one turn of ``thread_id``.

    Idempotent by design, so outer stream ``finally`` blocks can call it
    even when middleware teardown was skipped because of an abort or a
    client disconnect.
    """
    # Advance the epoch before anything else: a stale ``aafter_agent`` from
    # an earlier attempt then sees an epoch mismatch and cannot unlock a
    # newer retry that already holds the lock for this thread.
    self.bump_turn_epoch(thread_id)
    lock = self._locks.get(thread_id)
    still_held = lock is not None and lock.locked()
    if still_held:
        lock.release()
    self.reset(thread_id)
# Module-level singleton — process-local but reused across all agent
@ -98,15 +147,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
def request_cancel(thread_id: str) -> bool:
    """Trip the cancel event for ``thread_id``. Always returns True.

    Module-level convenience wrapper that delegates to the shared
    ``manager`` singleton.
    """
    return manager.request_cancel(thread_id)
def is_cancel_requested(thread_id: str) -> bool:
    """Tell whether a cancel signal is currently pending for ``thread_id``.

    Delegates to the shared ``manager`` singleton.
    """
    pending = manager.is_cancel_requested(thread_id)
    return pending
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
    """Expose the manager's pending-cancel details for ``thread_id``.

    Returns whatever ``manager.cancel_state`` yields: an
    ``(attempt_count, requested_at_ms)`` pair, or ``None``.
    """
    state = manager.cancel_state(thread_id)
    return state
def reset_cancel(thread_id: str) -> None:
    """Clear the cancel signal for ``thread_id``; intended for use between turns.

    Delegates to ``manager.reset``.
    """
    manager.reset(thread_id)
def end_turn(thread_id: str) -> None:
    """Force end-of-turn cleanup of the lock and cancel state for ``thread_id``.

    Delegates to ``manager.end_turn``.
    """
    manager.end_turn(thread_id)
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
"""Block concurrent prompts on the same thread.
@ -129,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
super().__init__()
self._require_thread_id = require_thread_id
self.tools = []
# Per-call locks owned by this middleware. We track them as
# an instance attribute so ``aafter_agent`` knows which lock
# to release.
self._held_locks: dict[str, asyncio.Lock] = {}
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
# only releases when its epoch still matches the manager's current
# epoch for the thread, preventing stale unlock races.
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
@staticmethod
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
@ -183,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
if lock.locked():
raise BusyError(request_id=thread_id)
await lock.acquire()
self._held_locks[thread_id] = lock
epoch = manager.bump_turn_epoch(thread_id)
self._held_locks[thread_id] = (lock, epoch)
# Reset the cancel event so this turn starts fresh
reset_cancel(thread_id)
return None
@ -197,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
thread_id = self._thread_id(runtime)
if thread_id is None:
return None
lock = self._held_locks.pop(thread_id, None)
if lock is not None and lock.locked():
held = self._held_locks.pop(thread_id, None)
if held is None:
return None
lock, held_epoch = held
if held_epoch != manager.current_turn_epoch(thread_id):
# Stale teardown from an older attempt (e.g. runtime-recovery path
# already advanced epoch). Do not touch current lock/cancel state.
return None
if lock.locked():
lock.release()
# Always clear cancel event between turns so a stale signal
# doesn't leak into the next request.
@ -229,7 +301,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
__all__ = [
"BusyMutexMiddleware",
"end_turn",
"get_cancel_event",
"get_cancel_state",
"is_cancel_requested",
"manager",
"request_cancel",
"reset_cancel",

View file

@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`).
- cd(path): change the current working directory.
- pwd(): print the current working directory.
- move_file(source, dest): move/rename a file under `/documents/`.
- rm(path): delete a single file under `/documents/` (no `-r`).
- rmdir(path): delete an empty directory under `/documents/`.
- list_tree(path, max_depth, page_size): recursively list files/folders.
## Persistence Rules
@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`).
`/documents/temp_scratch.md`) are **discarded** at end of turn use this
prefix for any scratch/working content you do NOT want saved.
- All other paths (outside `/documents/` and not `temp_*`) are rejected.
- mkdir/move_file are staged this turn and committed at end of turn alongside
any new/edited documents.
- mkdir/move_file/rm/rmdir are staged this turn and committed at end of
turn alongside any new/edited documents. Snapshot/revert is enabled
for every destructive operation when action logging is on.
## Reading Documents Efficiently
@ -176,6 +179,8 @@ directory (`cwd`).
- cd(path): change the current working directory.
- pwd(): print the current working directory.
- move_file(source, dest): move/rename a file.
- rm(path): delete a single file from disk (no `-r`). NOT reversible.
- rmdir(path): delete an empty directory from disk. NOT reversible.
- list_tree(path, max_depth, page_size): recursively list files/folders.
## Workflow Tips
@ -184,6 +189,8 @@ directory (`cwd`).
- For large trees, prefer `list_tree` then `grep` then `read_file` over
brute-force directory traversal.
- Cross-mount moves are not supported.
- Desktop deletes hit disk immediately and cannot be undone via the
agent's revert flow — confirm before calling `rm`/`rmdir`.
"""
)
@ -355,6 +362,42 @@ Notes:
- Parent folders are created as needed.
"""
_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`.
Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion
for end-of-turn commit; the row is removed only after the agent's turn
finishes successfully.
Args:
- path: absolute or relative file path. Cannot point at a directory use
`rmdir` for empty folders. Cannot target the root or `/documents`.
Notes:
- The action is reversible via the per-action revert flow when action
logging is enabled.
- The anonymous uploaded document is read-only and cannot be deleted.
"""
_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`.
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
deletion (`rm -r`) is intentionally NOT supported clear contents with
`rm` first.
Args:
- path: absolute or relative directory path. Cannot target the root,
`/documents`, the current cwd, or any ancestor of cwd (use `cd` to
move out first).
Notes:
- Emptiness is evaluated against the post-staged view, so a same-turn
`rm /a/x.md` followed by `rmdir /a` is fine.
- If the directory was added in this same turn via `mkdir` and never
committed, the staged mkdir is dropped instead of issuing a delete.
- The action is reversible via the per-action revert flow when action
logging is enabled.
"""
# --- desktop-only ----------------------------------------------------------
_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
@ -421,6 +464,28 @@ Notes:
- Parent folders are created as needed.
"""
_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk.
Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits
disk immediately. Desktop deletes are NOT reversible via the agent's
revert flow.
Args:
- path: absolute mount-prefixed file path. Cannot point at a directory
use `rmdir` for empty folders.
"""
# Desktop-mode description for the ``rmdir`` tool; looked up under the
# "rmdir" key of the desktop tool-description table.
_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk.
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
deletion is NOT supported. The deletion hits disk immediately and is
NOT reversible via the agent's revert flow.
Args:
- path: absolute mount-prefixed directory path. Cannot target the mount
root or any directory containing files/subfolders.
"""
def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
"""Pick the active-mode description for every filesystem tool."""
@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
"mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION,
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
"rm": _CLOUD_RM_TOOL_DESCRIPTION,
"rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION,
}
return {
"ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION,
@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
"mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION,
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
"rm": _DESKTOP_RM_TOOL_DESCRIPTION,
"rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION,
}
@ -476,6 +545,21 @@ def _basename(path: str) -> str:
return path.rsplit("/", 1)[-1]
def _is_ancestor_of(candidate: str, target: str) -> bool:
"""True iff ``candidate`` is a strict ancestor directory of ``target``.
``target`` itself is NOT considered an ancestor (use equality for that).
Both paths are assumed to be canonicalised, absolute, and free of
trailing slashes (except the root ``/``).
"""
if not candidate.startswith("/") or not target.startswith("/"):
return False
if candidate == target:
return False
prefix = candidate.rstrip("/") + "/"
return target.startswith(prefix)
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
self.tools.append(self._create_cd_tool())
self.tools.append(self._create_pwd_tool())
self.tools.append(self._create_move_file_tool())
self.tools.append(self._create_rm_tool())
self.tools.append(self._create_rmdir_tool())
self.tools.append(self._create_list_tree_tool())
if self._sandbox_available:
self.tools.append(self._create_execute_code_tool())
@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
}
if self._is_cloud():
update["dirty_paths"] = [path]
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
return Command(update=update)
def sync_write_file(
@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
}
if self._is_cloud():
update["dirty_paths"] = [path]
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
if doc_id_to_attach is not None:
update["doc_id_by_path"] = {path: doc_id_to_attach}
return Command(update=update)
@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
return Command(
update={
"staged_dirs": [validated],
"staged_dir_tool_calls": {
validated: runtime.tool_call_id,
},
"messages": [
ToolMessage(
content=(
@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
files_update: dict[str, Any] = {source: None, dest: source_file_data}
update: dict[str, Any] = {
"files": files_update,
"pending_moves": [{"source": source, "dest": dest, "overwrite": False}],
"pending_moves": [
{
"source": source,
"dest": dest,
"overwrite": False,
"tool_call_id": runtime.tool_call_id,
}
],
"messages": [
ToolMessage(
content=(
@ -1396,6 +1494,323 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
update["dirty_paths"] = new_dirty
return Command(update=update)
# ------------------------------------------------------------------ tool: rm
# Build the ``rm`` tool. The returned StructuredTool exposes ``async_rm`` as
# its coroutine and ``sync_rm`` (a blocking wrapper around it) as its sync
# function. Cloud mode stages the delete in agent state for end-of-turn
# commit; desktop mode deletes through the backend immediately.
def _create_rm_tool(self) -> BaseTool:
# Prefer a configured description; fall back to the cloud text otherwise.
tool_description = (
self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION
)
# Returns a ``Command`` carrying the state update on success, or a plain
# string (error or informational message) when nothing is staged/deleted.
async def async_rm(
path: Annotated[
str,
"Absolute or relative path to the file to delete.",
],
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> Command | str:
if not path or not path.strip():
return "Error: path is required."
# Resolve a relative path via ``_resolve_relative``, then validate it.
target = self._resolve_relative(path, runtime)
try:
validated = validate_path(target)
except ValueError as exc:
return f"Error: {exc}"
if self._is_cloud():
# Guard rails: never delete the roots, and only operate under /documents/.
if validated in ("/", DOCUMENTS_ROOT):
return f"Error: refusing to rm '{validated}'."
if not validated.startswith(DOCUMENTS_ROOT + "/"):
return (
"Error: cloud rm must target a path under /documents/ "
f"(got '{validated}')."
)
# The anonymous uploaded document is read-only.
anon = runtime.state.get("kb_anon_doc") or {}
if isinstance(anon, dict) and str(anon.get("path") or "") == validated:
return "Error: the anonymous uploaded document is read-only."
# Refuse if the path looks like a directory.
staged_dirs = list(runtime.state.get("staged_dirs") or [])
if validated in staged_dirs:
return (
f"Error: '{validated}' is a directory. Use rmdir for "
"empty directories."
)
pending_dir_deletes = list(
runtime.state.get("pending_dir_deletes") or []
)
if any(
isinstance(d, dict) and d.get("path") == validated
for d in pending_dir_deletes
):
return f"Error: '{validated}' is already queued for rmdir."
backend = self._get_backend(runtime)
if isinstance(backend, KBPostgresBackend):
# Detect "is a directory" via `ls`: if the path lists
# children we know it's a folder. Otherwise we still
# need to confirm it's a real file before staging.
children = await backend.als_info(validated)
if children:
return (
f"Error: '{validated}' is a directory. Use rmdir for "
"empty directories."
)
# Already queued for delete this turn?
pending_deletes = list(runtime.state.get("pending_deletes") or [])
if any(
isinstance(d, dict) and d.get("path") == validated
for d in pending_deletes
):
return f"'{validated}' is already queued for deletion."
# Resolve doc_id (best-effort): file in state or DB.
files_state = runtime.state.get("files") or {}
doc_id_by_path = runtime.state.get("doc_id_by_path") or {}
resolved_doc_id: int | None = doc_id_by_path.get(validated)
# Only hit the backend when the file is unknown to both state caches.
if (
validated not in files_state
and resolved_doc_id is None
and isinstance(backend, KBPostgresBackend)
):
loaded = await backend._load_file_data(validated)
if loaded is None:
return f"Error: file '{validated}' not found."
# NOTE(review): ``resolved_doc_id`` is not read again below, so this
# load effectively serves only as an existence check.
_, resolved_doc_id = loaded
# ``None`` values are written for the path in ``files`` and
# ``doc_id_by_path`` — presumably tombstones honoured by the state
# reducers; confirm against the reducer definitions.
files_update: dict[str, Any] = {validated: None}
update: dict[str, Any] = {
"pending_deletes": [
{
"path": validated,
"tool_call_id": runtime.tool_call_id,
}
],
"files": files_update,
"doc_id_by_path": {validated: None},
"messages": [
ToolMessage(
content=(
f"Staged delete of '{validated}' (will commit at "
"end of turn)."
),
tool_call_id=runtime.tool_call_id,
)
],
}
# Drop the path from dirty_paths so a same-turn write+rm
# doesn't recreate the doc at commit time.
dirty_paths = list(runtime.state.get("dirty_paths") or [])
if validated in dirty_paths:
# Rebuild the list as ``_CLEAR`` followed by the surviving entries.
new_dirty: list[Any] = [_CLEAR]
for entry in dirty_paths:
if entry != validated:
new_dirty.append(entry)
update["dirty_paths"] = new_dirty
update["dirty_path_tool_calls"] = {validated: None}
return Command(update=update)
# Desktop mode — hit disk immediately.
backend = self._get_backend(runtime)
# ``adelete_file`` is looked up dynamically; a backend without it gets an error string.
adelete = getattr(backend, "adelete_file", None)
if not callable(adelete):
return "Error: rm is not supported by the active backend."
res: WriteResult = await adelete(validated)
if res.error:
return res.error
update_desktop: dict[str, Any] = {
"files": {validated: None},
"messages": [
ToolMessage(
content=f"Deleted file '{res.path or validated}'",
tool_call_id=runtime.tool_call_id,
)
],
}
return Command(update=update_desktop)
# Synchronous entry point: runs ``async_rm`` via ``_run_async_blocking``.
def sync_rm(
path: Annotated[
str,
"Absolute or relative path to the file to delete.",
],
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> Command | str:
return self._run_async_blocking(async_rm(path, runtime))
return StructuredTool.from_function(
name="rm",
description=tool_description,
func=sync_rm,
coroutine=async_rm,
)
# ------------------------------------------------------------------ tool: rmdir
# Build the ``rmdir`` tool. The returned StructuredTool exposes
# ``async_rmdir`` as its coroutine and ``sync_rmdir`` (a blocking wrapper) as
# its sync function. Cloud mode either un-stages a same-turn ``mkdir`` or
# queues a ``pending_dir_deletes`` entry; desktop mode removes the directory
# through the backend immediately.
def _create_rmdir_tool(self) -> BaseTool:
# Prefer a configured description; fall back to the cloud text otherwise.
tool_description = (
self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION
)
# Returns a ``Command`` carrying the state update on success, or a plain
# string (error or informational message) when nothing is staged/deleted.
async def async_rmdir(
path: Annotated[
str,
"Absolute or relative path of the empty directory to delete.",
],
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> Command | str:
if not path or not path.strip():
return "Error: path is required."
# Resolve a relative path via ``_resolve_relative``, then validate it.
target = self._resolve_relative(path, runtime)
try:
validated = validate_path(target)
except ValueError as exc:
return f"Error: {exc}"
if self._is_cloud():
# Guard rails: never delete the roots, and only operate under /documents/.
if validated in ("/", DOCUMENTS_ROOT):
return f"Error: refusing to rmdir '{validated}'."
if not validated.startswith(DOCUMENTS_ROOT + "/"):
return (
"Error: cloud rmdir must target a path under /documents/ "
f"(got '{validated}')."
)
# Refuse to remove the cwd itself or any ancestor of it.
cwd = self._current_cwd(runtime)
if validated == cwd or _is_ancestor_of(validated, cwd):
return (
f"Error: cannot rmdir '{validated}' because the current "
"cwd is at or under it. cd out first."
)
staged_dirs = list(runtime.state.get("staged_dirs") or [])
pending_dir_deletes = list(
runtime.state.get("pending_dir_deletes") or []
)
# A second rmdir of the same path in this turn is reported, not re-queued.
if any(
isinstance(d, dict) and d.get("path") == validated
for d in pending_dir_deletes
):
return f"'{validated}' is already queued for deletion."
backend = self._get_backend(runtime)
# The path must currently exist either in DB folder paths or
# in staged_dirs. We rely on KBPostgresBackend.als_info (which
# already accounts for pending deletes/moves) to evaluate
# both existence and emptiness against the post-staged view.
exists_in_staged = validated in staged_dirs
children: list[Any] = []
if isinstance(backend, KBPostgresBackend):
children = list(await backend.als_info(validated))
# Detect "is a file" — if als_info returns no children but
# the path is actually a file, we should reject. We use
# _load_file_data to disambiguate file vs missing folder.
if (
isinstance(backend, KBPostgresBackend)
and not children
and not exists_in_staged
):
loaded = await backend._load_file_data(validated)
if loaded is not None:
return (
f"Error: '{validated}' is a file. Use rm to delete files."
)
# Confirm folder exists in DB by checking the parent listing.
parent = posixpath.dirname(validated) or "/"
parent_listing = await backend.als_info(parent)
# NOTE(review): this match assumes listing entries report directory
# paths without a trailing slash — confirm against ``als_info``.
parent_has_dir = any(
info.get("path") == validated and info.get("is_dir")
for info in parent_listing
)
if not parent_has_dir:
return f"Error: directory '{validated}' not found."
if children:
return (
f"Error: directory '{validated}' is not empty. "
"Remove contents first."
)
# Same-turn mkdir un-stage: drop the staged_dirs entry
# entirely and skip queuing a DB delete (nothing was ever
# committed).
if exists_in_staged:
rest = [d for d in staged_dirs if d != validated]
return Command(
update={
"staged_dirs": [_CLEAR, *rest],
"staged_dir_tool_calls": {validated: None},
"messages": [
ToolMessage(
content=(f"Un-staged directory '{validated}'."),
tool_call_id=runtime.tool_call_id,
)
],
}
)
# Otherwise queue the folder for deletion, tagged with this tool call's id.
return Command(
update={
"pending_dir_deletes": [
{
"path": validated,
"tool_call_id": runtime.tool_call_id,
}
],
"messages": [
ToolMessage(
content=(
f"Staged rmdir of '{validated}' (will commit "
"at end of turn)."
),
tool_call_id=runtime.tool_call_id,
)
],
}
)
# Desktop mode — hit disk immediately.
backend = self._get_backend(runtime)
# ``armdir`` is looked up dynamically; a backend without it gets an error string.
armdir = getattr(backend, "armdir", None)
if not callable(armdir):
return "Error: rmdir is not supported by the active backend."
res: WriteResult = await armdir(validated)
if res.error:
return res.error
return Command(
update={
"messages": [
ToolMessage(
content=f"Deleted directory '{res.path or validated}'",
tool_call_id=runtime.tool_call_id,
)
],
}
)
# Synchronous entry point: runs ``async_rmdir`` via ``_run_async_blocking``.
def sync_rmdir(
path: Annotated[
str,
"Absolute or relative path of the empty directory to delete.",
],
runtime: ToolRuntime[None, SurfSenseFilesystemState],
) -> Command | str:
return self._run_async_blocking(async_rmdir(path, runtime))
return StructuredTool.from_function(
name="rmdir",
description=tool_description,
func=sync_rmdir,
coroutine=async_rmdir,
)
# ------------------------------------------------------------------ tool: list_tree
def _create_list_tree_tool(self) -> BaseTool:

View file

@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol):
def _pending_moves(self) -> list[dict[str, Any]]:
return list(self.state.get("pending_moves") or [])
def _pending_deletes(self) -> list[dict[str, Any]]:
return list(self.state.get("pending_deletes") or [])
def _pending_dir_deletes(self) -> list[dict[str, Any]]:
return list(self.state.get("pending_dir_deletes") or [])
def _kb_anon_doc(self) -> dict[str, Any] | None:
anon = self.state.get("kb_anon_doc")
return anon if isinstance(anon, dict) else None
@ -140,18 +146,28 @@ class KBPostgresBackend(BackendProtocol):
return path
return path.rstrip("/") if path != "/" else path
def _moved_view_paths(
def _pending_filesystem_view(
self,
existing: dict[str, dict[str, Any]],
) -> tuple[set[str], dict[str, str]]:
"""Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
) -> tuple[set[str], dict[str, str], set[str]]:
"""Compute removed/aliased/dir-suppressed paths from staged ops.
Removed paths should disappear from listings; ``alias[source] = dest``
means a virtual entry should appear at ``dest`` even if no DB row is
yet there.
Returns ``(removed, alias, deleted_dirs)`` where:
* ``removed`` — paths to drop from listings (sources of pending moves
  AND paths queued for ``rm``).
* ``alias`` — ``{source: dest}`` for pending moves; the dest should
  appear as a virtual entry even when no DB row is at that path yet.
* ``deleted_dirs`` — folder paths queued for ``rmdir``; their entire
  subtree (descendants) is suppressed from listings/glob/grep.
Entries in ``existing`` (the ``files`` state cache) keyed by a
removed path are popped so a same-turn delete-after-write doesn't
leave a stale virtual file in listings.
"""
removed: set[str] = set()
alias: dict[str, str] = {}
deleted_dirs: set[str] = set()
for move in self._pending_moves():
src = move.get("source")
dst = move.get("dest")
@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol):
removed.add(src)
alias[src] = dst
existing.pop(src, None)
return removed, alias
for entry in self._pending_deletes():
path = entry.get("path") if isinstance(entry, dict) else None
if not path:
continue
removed.add(path)
existing.pop(path, None)
for entry in self._pending_dir_deletes():
path = entry.get("path") if isinstance(entry, dict) else None
if not path:
continue
deleted_dirs.add(path)
return removed, alias, deleted_dirs
@staticmethod
def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool:
    """Report whether ``path`` equals, or lies under, any entry of ``deleted_dirs``.

    An empty ``deleted_dirs`` always yields ``False``.
    """
    for directory in deleted_dirs:
        if path == directory or _is_under(path, directory):
            return True
    return False
# ------------------------------------------------------------------ ls/read
@ -189,7 +221,7 @@ class KBPostgresBackend(BackendProtocol):
seen.add(anon_path)
files = self._state_files()
moved_removed, moved_alias = self._moved_view_paths(files)
moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files)
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
try:
@ -203,7 +235,12 @@ class KBPostgresBackend(BackendProtocol):
for info in db_infos:
p = info.get("path", "")
if not p or p in seen or p in moved_removed:
if (
not p
or p in seen
or p in moved_removed
or self._is_dir_suppressed(p, deleted_dirs)
):
continue
infos.append(info)
seen.add(p)
@ -212,6 +249,8 @@ class KBPostgresBackend(BackendProtocol):
if src not in seen:
if not _is_under(dst, normalized):
continue
if self._is_dir_suppressed(dst, deleted_dirs):
continue
rel = (
dst[len(normalized) :].lstrip("/")
if normalized != "/"
@ -247,6 +286,8 @@ class KBPostgresBackend(BackendProtocol):
continue
if not _is_under(staged, normalized):
continue
if self._is_dir_suppressed(staged, deleted_dirs):
continue
rel = (
staged[len(normalized) :].lstrip("/")
if normalized != "/"
@ -265,14 +306,26 @@ class KBPostgresBackend(BackendProtocol):
for sub in sorted(subdir_paths):
if sub in seen:
continue
if self._is_dir_suppressed(sub, deleted_dirs):
continue
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
seen.add(sub)
for path_key, fd in files.items():
if not isinstance(path_key, str) or path_key in seen:
continue
# Tombstones (None values) are deletion markers from `rm`. The
# deepagents reducer normally pops them, but a stale tombstone
# surviving a checkpoint must NOT be reported as a child here —
# otherwise rmdir mistakenly sees the deleted file as content.
if fd is None:
continue
if not _is_under(path_key, normalized) or path_key == normalized:
continue
if path_key in moved_removed or self._is_dir_suppressed(
path_key, deleted_dirs
):
continue
if normalized == "/":
rel = path_key.lstrip("/")
else:
@ -550,10 +603,12 @@ class KBPostgresBackend(BackendProtocol):
seen: set[str] = set()
files = self._state_files()
moved_removed, _ = self._moved_view_paths(files)
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
regex = re.compile(fnmatch.translate(pattern))
for path_key, fd in files.items():
if path_key in moved_removed:
if path_key in moved_removed or self._is_dir_suppressed(
path_key, deleted_dirs
):
continue
if not _is_under(path_key, normalized):
continue
@ -595,7 +650,11 @@ class KBPostgresBackend(BackendProtocol):
folder_id=row.folder_id,
index=index,
)
if candidate in seen or candidate in moved_removed:
if (
candidate in seen
or candidate in moved_removed
or self._is_dir_suppressed(candidate, deleted_dirs)
):
continue
if not _is_under(candidate, normalized):
continue
@ -634,10 +693,12 @@ class KBPostgresBackend(BackendProtocol):
matches: list[GrepMatch] = []
files = self._state_files()
moved_removed, _ = self._moved_view_paths(files)
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
glob_re = re.compile(fnmatch.translate(glob)) if glob else None
for path_key, fd in files.items():
if path_key in moved_removed:
if path_key in moved_removed or self._is_dir_suppressed(
path_key, deleted_dirs
):
continue
if not _is_under(path_key, normalized):
continue
@ -695,7 +756,11 @@ class KBPostgresBackend(BackendProtocol):
)
for doc_id, chunk_id, content in chunk_buffer:
candidate = doc_id_to_path.get(doc_id)
if not candidate or candidate in moved_removed:
if (
not candidate
or candidate in moved_removed
or self._is_dir_suppressed(candidate, deleted_dirs)
):
continue
if not _is_under(candidate, normalized):
continue
@ -769,7 +834,7 @@ class KBPostgresBackend(BackendProtocol):
return {"entries": [], "truncated": False}
files = self._state_files()
moved_removed, _ = self._moved_view_paths(files)
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
anon = self._kb_anon_doc()
anon_path = str(anon.get("path") or "") if anon else ""
@ -795,6 +860,8 @@ class KBPostgresBackend(BackendProtocol):
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
if not _is_under(fpath, normalized):
continue
if self._is_dir_suppressed(fpath, deleted_dirs):
continue
depth = _depth_of(fpath)
if max_depth is not None and depth > max_depth:
continue
@ -811,6 +878,8 @@ class KBPostgresBackend(BackendProtocol):
for staged in self._staged_dirs():
if not _is_under(staged, normalized):
continue
if self._is_dir_suppressed(staged, deleted_dirs):
continue
depth = _depth_of(staged)
if max_depth is not None and depth > max_depth:
continue
@ -835,7 +904,9 @@ class KBPostgresBackend(BackendProtocol):
folder_id=row.folder_id,
index=index,
)
if candidate in moved_removed:
if candidate in moved_removed or self._is_dir_suppressed(
candidate, deleted_dirs
):
continue
if not _is_under(candidate, normalized):
continue
@ -875,6 +946,10 @@ class KBPostgresBackend(BackendProtocol):
continue
if not _is_under(path_key, normalized):
continue
if path_key in moved_removed or self._is_dir_suppressed(
path_key, deleted_dirs
):
continue
if any(e["path"] == path_key for e in entries):
continue
if not (

View file

@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
)
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
# Pre-compute which folders have at least one descendant (folder or doc).
# A folder is "empty" iff no path in `all_paths` is strictly under it.
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
# infer emptiness from indentation alone.
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
lines: list[str] = []
for path in all_paths:
depth = (
@ -214,7 +220,10 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
)
if is_dir:
lines.append(f"{indent}{display}/")
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
lines.append(f"{indent}{display}/ (empty)")
else:
lines.append(f"{indent}{display}/")
else:
lines.append(f"{indent}{display}")
if len(lines) >= self.max_entries:
@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
return self._format_root_summary(folder_paths, doc_paths)
@staticmethod
def _compute_non_empty_folders(
    folder_paths: list[str], doc_paths: list[str]
) -> set[str]:
    """Return the set of folder paths that contain at least one descendant.

    A folder is "non-empty" if any document path or any other folder path
    lies strictly under it. Both documents and sub-folders mark *every*
    ancestor that is a known folder, so in a chain of nested folders only
    the innermost, childless folder is left out of the result (and is the
    only one that reads ``(empty)``).

    Args:
        folder_paths: Paths of all known folders.
        doc_paths: Paths of all known documents.

    Returns:
        The subset of ``folder_paths`` that have at least one descendant.
        ``DOCUMENTS_ROOT`` itself is never included.
    """
    folder_set = set(folder_paths)
    non_empty: set[str] = set()
    for path in (*doc_paths, *folder_paths):
        # ``rpartition`` yields "" for a path without "/", which ends the
        # walk; ``rsplit("/", 1)[0]`` would hand back the same string and
        # spin forever on such input.
        parent = path.rpartition("/")[0]
        while parent and parent != DOCUMENTS_ROOT:
            # Keep climbing even when an intermediate ancestor is not a
            # known folder, so the result matches the "any path strictly
            # under it" definition for documents and folders alike.
            if parent in folder_set:
                non_empty.add(parent)
            parent = parent.rpartition("/")[0]
    return non_empty
def _format_root_summary(
self, folder_paths: list[str], doc_paths: list[str]
) -> str:

View file

@ -360,6 +360,74 @@ class LocalFolderBackend:
self.move, source_path, destination_path, overwrite
)
def delete_file(self, file_path: str) -> WriteResult:
    """Hard-delete a single file under root.

    Mirrors a plain POSIX ``rm path``: no ``-r`` recursion and no glob
    expansion. Missing paths and directories are rejected with an error
    result rather than an exception.

    Args:
        file_path: Virtual path of the file to remove.

    Returns:
        ``WriteResult`` with ``path`` set on success, or with ``error`` set
        when the path is invalid, missing, a directory, or ``os.unlink``
        fails.
    """
    try:
        target = self._resolve_virtual(file_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{file_path}'")

    with self._lock_for(file_path):
        if not target.exists():
            return WriteResult(error=f"Error: File '{file_path}' not found")
        if target.is_dir():
            message = (
                f"Error: '{file_path}' is a directory. "
                "Use rmdir for empty directories."
            )
            return WriteResult(error=message)
        try:
            os.unlink(target)
        except OSError as exc:
            failure = f"Error: failed to delete '{file_path}': {exc}"
            return WriteResult(error=failure)
    return WriteResult(path=file_path, files_update=None)
async def adelete_file(self, file_path: str) -> WriteResult:
    """Async counterpart of ``delete_file``; runs it via ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.delete_file, file_path)
    return outcome
def rmdir(self, dir_path: str) -> WriteResult:
    """Hard-delete an empty directory under root.

    Refuses the root itself, files, missing paths, and non-empty
    directories. ``os.rmdir`` is naturally empty-only; the emptiness
    pre-check exists so the agent gets a clearer error message.

    Args:
        dir_path: Virtual path of the directory to remove.

    Returns:
        ``WriteResult`` with ``path`` set on success, or with ``error`` set
        when the request is refused or ``os.rmdir`` fails.
    """
    # The contract above promises that root is refused, but no check in this
    # method enforced it: an empty root passed every test below and reached
    # ``os.rmdir``. This is a lexical guard only.
    # NOTE(review): paths that reach root through ".." segments are not
    # caught here — confirm whether ``_resolve_virtual`` rejects those.
    if not dir_path.strip("/"):
        return WriteResult(error=f"Error: cannot rmdir root '{dir_path}'")

    try:
        path = self._resolve_virtual(dir_path)
    except ValueError:
        return WriteResult(error=f"Error: Invalid path '{dir_path}'")

    with self._lock_for(dir_path):
        if not path.exists():
            return WriteResult(error=f"Error: Directory '{dir_path}' not found")
        if not path.is_dir():
            return WriteResult(error=f"Error: '{dir_path}' is not a directory")
        # Peek at the first child only; no need to list the whole directory.
        if next(path.iterdir(), None) is not None:
            return WriteResult(
                error=(
                    f"Error: directory '{dir_path}' is not empty. "
                    "Remove its contents first."
                )
            )
        try:
            os.rmdir(path)
        except OSError as exc:
            return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
    return WriteResult(path=dir_path, files_update=None)
async def armdir(self, dir_path: str) -> WriteResult:
    """Async counterpart of ``rmdir``; runs it via ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.rmdir, dir_path)
    return outcome
def edit(
self,
file_path: str,

View file

@ -285,6 +285,34 @@ class MultiRootLocalFolderBackend:
overwrite,
)
def delete_file(self, file_path: str) -> WriteResult:
    """Delete a file by routing the call to the backend that owns its mount.

    The path is split into ``(mount, local_path)``; the mount's backend does
    the deletion, and any path it reports back is re-prefixed with the mount
    so callers see the same namespace they passed in.

    Returns:
        The backend's ``WriteResult`` (with ``path`` rewritten when set), or
        an error result when the path cannot be split.
    """
    try:
        mount, local_path = self._split_mount_path(file_path)
    except ValueError as exc:
        return WriteResult(error=f"Error: {exc}")

    backend = self._mount_to_backend[mount]
    outcome = backend.delete_file(local_path)
    if outcome.path:
        outcome.path = self._prefix_mount_path(mount, outcome.path)
    return outcome
async def adelete_file(self, file_path: str) -> WriteResult:
    """Async counterpart of ``delete_file``; runs it via ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.delete_file, file_path)
    return outcome
def rmdir(self, dir_path: str) -> WriteResult:
    """Remove an empty directory by routing the call to its mount's backend.

    A mount root (local path ``"/"``) is never removable. Otherwise the
    mount's backend performs the removal, and any path it reports back is
    re-prefixed with the mount.

    Returns:
        The backend's ``WriteResult`` (with ``path`` rewritten when set), or
        an error result when the path cannot be split or targets a mount root.
    """
    try:
        mount, local_path = self._split_mount_path(dir_path)
    except ValueError as exc:
        return WriteResult(error=f"Error: {exc}")

    if local_path == "/":
        return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'")

    backend = self._mount_to_backend[mount]
    outcome = backend.rmdir(local_path)
    if outcome.path:
        outcome.path = self._prefix_mount_path(mount, outcome.path)
    return outcome
async def armdir(self, dir_path: str) -> WriteResult:
    """Async counterpart of ``rmdir``; runs it via ``asyncio.to_thread``."""
    outcome = await asyncio.to_thread(self.rmdir, dir_path)
    return outcome
def edit(
self,
file_path: str,

View file

@ -0,0 +1,166 @@
"""LiteLLM-native prompt caching configuration for SurfSense agents.
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
gate always failed) with LiteLLM's universal caching mechanism.
Coverage:
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
performs automatically when ``cache_control_injection_points`` is set):
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
``deepseek/``, ``xai/`` — these cache automatically for prompts ≥ 1024
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
We inject **two** breakpoints per request:
- ``role: system`` — pins the SurfSense system prompt (provider variant,
citation rules, tool catalog, KB tree, skills metadata) into the cache.
- ``index: -1`` — pins the latest message so multi-turn savings compound:
Anthropic-family providers use longest-matching-prefix lookup, so turn
N+1 still reads turn N's cache up to the shared prefix.
For OpenAI-family configs we additionally pass:
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
raises hit rate by sending requests with a shared prefix to the same
backend.
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
5-10 min in-memory cache.
Safety net: ``litellm.drop_params=True`` is set globally in
``app.services.llm_service`` at module-load time. Any kwarg the destination
provider doesn't recognise is auto-stripped at the provider transformer
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
``prompt_cache_key`` etc.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from langchain_core.language_models import BaseChatModel
if TYPE_CHECKING:
from app.agents.new_chat.llm_config import AgentConfig
logger = logging.getLogger(__name__)
# Two-breakpoint policy: system + latest message. See module docstring for
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
# use 2 here, leaving headroom for Phase-2 tool caching.
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
{"location": "message", "role": "system"},
{"location": "message", "index": -1},
)
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
# MINIMAX), so we can't infer family from the litellm prefix alone.
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
def _is_router_llm(llm: BaseChatModel) -> bool:
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
Importing ``app.services.llm_router_service`` at module-load time would
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
Class-name comparison is sufficient since the class is defined in a
single place.
"""
return type(llm).__name__ == "ChatLiteLLMRouter"
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
    """Decide whether ``agent_config`` targets an OpenAI-style prompt-cache surface.

    Strict: the answer is ``True`` only when the config names a provider
    whose upper-cased value is listed in ``_OPENAI_FAMILY_PROVIDERS``.
    A missing config, an empty provider, auto-mode, and custom providers all
    yield ``False`` because the destination cannot be known statically.
    """
    if agent_config is None:
        return False
    provider = agent_config.provider
    if not provider or agent_config.is_auto_mode or agent_config.custom_provider:
        return False
    return provider.upper() in _OPENAI_FAMILY_PROVIDERS
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
"""Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
Initialises the field to ``{}`` when present-but-None on a Pydantic v2
model. Returns ``None`` if the LLM type doesn't expose a writable
``model_kwargs`` attribute (caller should treat as no-op).
"""
model_kwargs = getattr(llm, "model_kwargs", None)
if isinstance(model_kwargs, dict):
return model_kwargs
try:
llm.model_kwargs = {} # type: ignore[attr-defined]
except Exception:
return None
refreshed = getattr(llm, "model_kwargs", None)
return refreshed if isinstance(refreshed, dict) else None
def apply_litellm_prompt_caching(
    llm: BaseChatModel,
    *,
    agent_config: AgentConfig | None = None,
    thread_id: int | None = None,
) -> None:
    """Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.

    Idempotent: every key is written only when it is not already present in
    ``llm.model_kwargs``, so pre-existing values (e.g. overrides coming from
    ``agent_config.litellm_params``) are preserved. The dict is mutated in
    place.

    Args:
        llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
        agent_config: Optional ``AgentConfig`` driving provider-specific
            behaviour. When omitted (or auto-mode), only the universal
            ``cache_control_injection_points`` are set.
        thread_id: Optional thread id used to build a per-thread
            ``prompt_cache_key`` for OpenAI-family providers. When ``None``
            the key is simply not set.
    """
    kwargs = _get_or_init_model_kwargs(llm)
    if kwargs is None:
        logger.debug(
            "apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
            type(llm).__name__,
        )
        return

    # Each point is copied so the shared module-level defaults are never
    # aliased into (and mutated through) a model's kwargs.
    kwargs.setdefault(
        "cache_control_injection_points",
        [dict(point) for point in _DEFAULT_INJECTION_POINTS],
    )

    # OpenAI-only extras are added solely when the destination is statically
    # known to be OpenAI-family; the auto-mode router fans out across
    # providers, so they are skipped there.
    if _is_router_llm(llm) or not _is_openai_family_config(agent_config):
        return

    if thread_id is not None:
        kwargs.setdefault("prompt_cache_key", f"surfsense-thread-{thread_id}")
    kwargs.setdefault("prompt_cache_retention", "24h")

View file

@ -181,9 +181,13 @@ def _initial_filesystem_state() -> dict[str, Any]:
return {
"cwd": "/documents",
"staged_dirs": [],
"staged_dir_tool_calls": {},
"pending_moves": [],
"pending_deletes": [],
"pending_dir_deletes": [],
"doc_id_by_path": {},
"dirty_paths": [],
"dirty_path_tool_calls": {},
"kb_priority": [],
"kb_matched_chunk_ids": {},
"kb_anon_doc": None,

View file

@ -84,6 +84,8 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
"write_file",
"move_file",
"mkdir",
"rm",
"rmdir",
"update_memory",
"update_memory_team",
"update_memory_private",

View file

@ -30,6 +30,35 @@ from langgraph.types import interrupt
logger = logging.getLogger(__name__)
# Tools that mirror the safety profile of ``write_file`` against the
# SurfSense KB: each call creates ONE artifact in the user's own workspace
# with no external visibility (drafts aren't sent; new files aren't shared
# unless the user shares them later). These are auto-approved by default
# so the agent can compose drafts and seed scratch files without a popup
# on every call.
#
# Members of this set still call ``request_approval`` exactly as before;
# the function returns immediately with ``decision_type="auto_approved"``
# and the original params untouched. This preserves the call-site shape
# (logging, metadata fetching, account fallbacks) so the only behavior
# change is "no interrupt fires".
#
# To re-enable prompting, the future per-search-space rules table
# (``agent_permission_rules``) takes precedence — see the ``# (future)``
# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`.
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
{
"create_gmail_draft",
"update_gmail_draft",
"create_notion_page",
"create_confluence_page",
"create_google_drive_file",
"create_dropbox_file",
"create_onedrive_file",
}
)
@dataclass(frozen=True, slots=True)
class HITLResult:
"""Outcome of a human-in-the-loop approval request."""
@ -119,6 +148,19 @@ def request_approval(
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
# Default policy: low-stakes creation tools (drafts + new-file
# creates) skip HITL because they're as recoverable as a local
# ``write_file`` against the SurfSense KB. The user can still
# delete the artifact in <30s if it's wrong.
logger.info(
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
tool_name,
)
return HITLResult(
rejected=False, decision_type="auto_approved", params=dict(params)
)
approval = interrupt(
{
"type": action_type,

View file

@ -63,6 +63,27 @@ def load_global_llm_configs():
else:
seen_slugs[slug] = cfg.get("id", 0)
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
# Tier A — operator-curated, locked first when premium-eligible.
# The OpenRouter refresh tick later re-stamps health for any cfg
# whose provider == "OPENROUTER" via _enrich_health.
try:
from app.services.quality_score import static_score_yaml
for cfg in configs:
cfg["auto_pin_tier"] = "A"
static_q = static_score_yaml(cfg)
cfg["quality_score_static"] = static_q
cfg["quality_score"] = static_q
cfg["quality_score_health"] = None
# YAML cfgs whose provider is OPENROUTER are also subject
# to health gating against their own /endpoints data — a
# hand-picked dead OR model is still dead. _enrich_health
# re-stamps health_gated for them on the next refresh tick.
cfg["health_gated"] = False
except Exception as e:
print(f"Warning: Failed to score global LLM configs: {e}")
return configs
except Exception as e:
print(f"Warning: Failed to load global LLM configs: {e}")
@ -194,6 +215,9 @@ def load_openrouter_integration_settings() -> dict | None:
"""
Load OpenRouter integration settings from the YAML config.
Emits startup warnings for deprecated keys (``billing_tier``,
``anonymous_enabled``) and seeds their replacements for back-compat.
Returns:
dict with settings if present and enabled, None otherwise
"""
@ -206,9 +230,31 @@ def load_openrouter_integration_settings() -> dict | None:
with open(global_config_file, encoding="utf-8") as f:
data = yaml.safe_load(f)
settings = data.get("openrouter_integration")
if settings and settings.get("enabled"):
return settings
return None
if not settings or not settings.get("enabled"):
return None
if "billing_tier" in settings:
print(
"Warning: openrouter_integration.billing_tier is deprecated; "
"tier is now derived per model from OpenRouter data "
"(':free' suffix or zero pricing). Remove this key."
)
if "anonymous_enabled" in settings:
print(
"Warning: openrouter_integration.anonymous_enabled is "
"deprecated; use anonymous_enabled_paid and/or "
"anonymous_enabled_free instead. Both new flags have been "
"seeded from the legacy value for back-compat."
)
settings.setdefault(
"anonymous_enabled_paid", settings["anonymous_enabled"]
)
settings.setdefault(
"anonymous_enabled_free", settings["anonymous_enabled"]
)
return settings
except Exception as e:
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
return None
@ -217,9 +263,14 @@ def load_openrouter_integration_settings() -> dict | None:
def initialize_openrouter_integration():
"""
If enabled, fetch all OpenRouter models and append them to
config.GLOBAL_LLM_CONFIGS as dynamic premium entries.
Should be called BEFORE initialize_llm_router() so the router
correctly excludes premium models from Auto mode.
config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier``
is derived per-model from OpenRouter's API signals (``:free`` suffix or
zero pricing), so free OpenRouter models correctly skip premium quota.
Should be called BEFORE initialize_llm_router(). Dynamic entries are
tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used
by title-gen / sub-agent flows) remains scoped to curated YAML configs,
while user-facing Auto-mode thread pinning still considers them.
"""
settings = load_openrouter_integration_settings()
if not settings:
@ -235,9 +286,13 @@ def initialize_openrouter_integration():
if new_configs:
config.GLOBAL_LLM_CONFIGS.extend(new_configs)
free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free")
premium_count = sum(
1 for c in new_configs if c.get("billing_tier") == "premium"
)
print(
f"Info: OpenRouter integration added {len(new_configs)} models "
f"(billing_tier={settings.get('billing_tier', 'premium')})"
f"(free={free_count}, premium={premium_count})"
)
else:
print("Info: OpenRouter integration enabled but no models fetched")

View file

@ -245,31 +245,53 @@ global_llm_configs:
# =============================================================================
# When enabled, dynamically fetches ALL available models from the OpenRouter API
# and injects them as global configs. This gives premium users access to any model
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota.
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
# while free-tier OpenRouter models show up with a green Free badge and do NOT
# consume premium quota.
# Models are fetched at startup and refreshed periodically in the background.
# All calls go through LiteLLM with the openrouter/ prefix.
openrouter_integration:
enabled: false
api_key: "sk-or-your-openrouter-api-key"
# billing_tier: "premium" or "free". Controls whether users need premium tokens.
billing_tier: "premium"
# anonymous_enabled: set true to also show OpenRouter models to no-login users
anonymous_enabled: false
# Tier is derived PER MODEL from OpenRouter's own API signals:
# - id ends with ":free" -> billing_tier=free
# - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
# - otherwise -> billing_tier=premium
# No global billing_tier knob is honored; any legacy value emits a startup warning.
# Anonymous access is split by tier so operators can expose only free
# models to no-login users without leaking paid inference.
anonymous_enabled_paid: false
anonymous_enabled_free: false
seo_enabled: false
# quota_reserve_tokens: tokens reserved per call for quota enforcement
quota_reserve_tokens: 4000
# id_offset: starting negative ID for dynamically generated configs.
# Must not overlap with your static global_llm_configs IDs above.
# id_offset: base negative ID for dynamically generated configs.
# Model IDs are derived deterministically via BLAKE2b so they survive
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
id_offset: -10000
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
refresh_interval_hours: 24
# rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
# OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled
# upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits).
# These values only matter if you set billing_tier to "free" (adding them to Auto mode).
# For premium-only models they are cosmetic. Set conservatively or match your account tier.
# Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
# for per-deployment accounting when OR premium models participate in the
# shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
# real account limits live at https://openrouter.ai/settings/limits.
rpm: 200
tpm: 1000000
# Rate limits for FREE OpenRouter models. Informational only: free OR
# models are intentionally kept OUT of the LiteLLM Router pool, because
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
# 50-1000 daily requests across every ":free" model combined) —
# per-deployment router accounting can't represent a shared bucket
# correctly. Free OR models stay fully available in the model selector
# and for user-facing Auto thread pinning.
free_rpm: 20
free_tpm: 100000
litellm_params:
max_tokens: 16384
system_instructions: ""

View file

@ -638,6 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin):
default=False,
server_default="false",
)
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
# config id. NULL means no pin; Auto will resolve on the next turn.
# Single-writer invariant: only app.services.auto_model_pin_service sets
# or clears this column (plus bulk clears when a search space's
# agent_llm_id changes). Unindexed: all reads are by primary key.
pinned_llm_config_id = Column(Integer, nullable=True)
# Relationships
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
@ -689,6 +695,12 @@ class NewChatMessage(BaseModel, TimestampMixin):
index=True,
)
# Per-turn correlation id sourced from ``configurable.turn_id`` at
# streaming time (``f"{chat_id}:{ms}"``). Nullable because legacy rows
# predate the column. Used by C1's edit-from-arbitrary-position to map
# a message back to the LangGraph checkpoint that produced its turn.
turn_id = Column(String(64), nullable=True, index=True)
# Relationships
thread = relationship("NewChatThread", back_populates="messages")
author = relationship("User")
@ -2292,7 +2304,13 @@ class AgentActionLog(BaseModel):
nullable=False,
index=True,
)
# ``turn_id`` historically held the LangChain ``tool_call.id``. It has
# been renamed to ``tool_call_id`` (with a parallel column kept for one
# release for back-compat). The real chat-turn id lives in
# ``chat_turn_id`` and is sourced from ``configurable.turn_id``.
turn_id = Column(String(64), nullable=True, index=True)
tool_call_id = Column(String(64), nullable=True, index=True)
chat_turn_id = Column(String(64), nullable=True, index=True)
message_id = Column(String(128), nullable=True, index=True)
tool_name = Column(String(255), nullable=False, index=True)
args = Column(JSONB, nullable=True)
@ -2318,6 +2336,16 @@ class AgentActionLog(BaseModel):
__table_args__ = (
Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
# Partial unique index enforces "at most one revert per
# original action". Created in migration 137 with
# ``WHERE reverse_of IS NOT NULL`` so non-revert rows
# (the vast majority) are unaffected and NULLs don't collide.
Index(
"ux_agent_action_log_reverse_of",
"reverse_of",
unique=True,
postgresql_where=text("reverse_of IS NOT NULL"),
),
)
@ -2332,10 +2360,13 @@ class DocumentRevision(BaseModel):
__tablename__ = "document_revisions"
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
# hard-delete it describes — without that, ``rm`` would wipe the row
# we'd need to undo it. See migration ``134_relax_revision_fks``.
document_id = Column(
Integer,
ForeignKey("documents.id", ondelete="CASCADE"),
nullable=False,
ForeignKey("documents.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
search_space_id = Column(
@ -2370,10 +2401,13 @@ class FolderRevision(BaseModel):
__tablename__ = "folder_revisions"
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
# hard-delete it describes — without that, ``rmdir`` would wipe the
# row we'd need to undo it. See migration ``134_relax_revision_fks``.
folder_id = Column(
Integer,
ForeignKey("folders.id", ondelete="CASCADE"),
nullable=False,
ForeignKey("folders.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
search_space_id = Column(

View file

@ -65,6 +65,13 @@ class AgentActionRead(BaseModel):
reverse_of: int | None
reverted_by_action_id: int | None
is_revert_action: bool
# Correlation ids added in migration 135. ``tool_call_id`` is the
# LangChain tool-call id (joinable to ``data-action-log`` SSE events
# via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id
# from ``configurable.turn_id`` (used by the
# ``revert-turn/{chat_turn_id}`` endpoint).
tool_call_id: str | None = None
chat_turn_id: str | None = None
created_at: datetime
@ -172,6 +179,8 @@ async def list_thread_actions(
reverse_of=row.reverse_of,
reverted_by_action_id=revert_map.get(row.id),
is_revert_action=row.reverse_of is not None,
tool_call_id=row.tool_call_id,
chat_turn_id=row.chat_turn_id,
created_at=row.created_at,
)
for row in rows

View file

@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs:
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
5. Idempotent on retries: if the same action is reverted twice the second
call returns 409 ``"already reverted"``.
This module also hosts the per-turn batch endpoint
``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It
walks every reversible action emitted during a chat turn in reverse
``created_at`` order and reverts each independently. Partial success is the
common case — the response always contains a per-action result list and a
``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a
whole-batch 4xx.
"""
from __future__ import annotations
import logging
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.feature_flags import get_flags
@ -97,6 +108,16 @@ async def revert_agent_action(
action=action,
requester_user_id=str(user.id) if user is not None else None,
)
except IntegrityError:
# Partial unique index ``ux_agent_action_log_reverse_of`` caught
# a concurrent revert. Translate to the existing 409 "already
# reverted" contract so racing clients see consistent
# behaviour with the pre-flight TOCTOU check above.
await session.rollback()
raise HTTPException(
status_code=409,
detail="This action has already been reverted.",
) from None
except Exception as err:
logger.exception("Revert dispatch raised for action_id=%s", action_id)
await session.rollback()
@ -105,7 +126,16 @@ async def revert_agent_action(
) from err
if outcome.status == "ok":
await session.commit()
try:
await session.commit()
except IntegrityError:
# Race lost on commit (constraint enforced at flush in some
# configs but at commit in others — defensive).
await session.rollback()
raise HTTPException(
status_code=409,
detail="This action has already been reverted.",
) from None
return {
"status": "ok",
"message": outcome.message,
@ -122,3 +152,357 @@ async def revert_agent_action(
raise HTTPException(status_code=501, detail=outcome.message)
# not_reversible
raise HTTPException(status_code=409, detail=outcome.message)
# ---------------------------------------------------------------------------
# Per-turn revert batch endpoint
# ---------------------------------------------------------------------------
# Closed set of statuses a single action can end up with inside a
# ``revert-turn`` batch. Every value is also a key of the ``counts`` dict
# maintained by ``revert_agent_turn``.
PerActionStatus = Literal[
    "reverted",
    "already_reverted",
    "not_reversible",
    "permission_denied",
    "failed",
    "skipped",
]
class RevertTurnActionResult(BaseModel):
    """Per-action outcome inside a ``revert-turn`` batch response."""

    # Id and tool name copied from the ``AgentActionLog`` row this entry describes.
    action_id: int
    tool_name: str
    # One of the ``PerActionStatus`` literals.
    status: PerActionStatus
    # Optional human-readable explanation (revert outcome message or skip reason).
    message: str | None = None
    # Id of the revert row: the freshly written one for ``reverted``, the
    # pre-existing one for ``already_reverted``.
    new_action_id: int | None = None
    # Exception text; ``revert_agent_turn`` only sets it on ``failed`` entries.
    error: str | None = None
class RevertTurnResponse(BaseModel):
    """Top-level response for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.

    ``status`` is ``"ok"`` only when every reversible row succeeded. Any
    ``failed`` / ``not_reversible`` / ``permission_denied`` entry downgrades
    it to ``"partial"``. Empty turns (no rows) return ``"ok"`` with an empty
    ``results`` list — callers should treat that as a no-op.

    Counter invariant:

        ``total == reverted + already_reverted + not_reversible
                  + permission_denied + failed + skipped``

    Frontend toasts and the ``RevertTurnButton`` summary rely on this
    invariant to display "X of Y reverted, Z could not be undone" without
    silently dropping ``permission_denied`` or ``skipped`` rows.
    """

    status: Literal["ok", "partial"]
    # Echo of the turn id taken from the request path.
    chat_turn_id: str
    # Number of action-log rows examined for the turn.
    total: int
    # Per-status counters; together they sum to ``total``.
    reverted: int
    already_reverted: int
    not_reversible: int
    permission_denied: int = 0
    failed: int = 0
    skipped: int = 0
    # One entry per examined row, in processing order.
    results: list[RevertTurnActionResult]
def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus:
if outcome.status == "ok":
return "reverted"
if outcome.status == "permission_denied":
return "permission_denied"
# ``not_found`` / ``tool_unavailable`` / ``reverse_not_implemented`` /
# ``not_reversible`` are all surfaced to the caller as "not_reversible"
# — they share the same UX (this row cannot be undone) and only the
# ``message`` differs.
return "not_reversible"
async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None:
    """Look up the revert row that points at ``action_id``.

    Single-id counterpart of ``_was_already_reverted_batch``, used after an
    ``IntegrityError`` when exactly one action is known to have lost a race.

    Returns:
        The ``id`` of the first ``AgentActionLog`` row whose ``reverse_of``
        equals ``action_id``, or ``None`` when there is none.
    """
    lookup = select(AgentActionLog.id).where(AgentActionLog.reverse_of == action_id)
    return (await session.execute(lookup)).scalars().first()
async def _was_already_reverted_batch(
session: AsyncSession, *, action_ids: list[int]
) -> dict[int, int]:
"""Batch idempotency probe for the revert-turn loop.
Replaces N individual ``SELECT id WHERE reverse_of = :id`` queries
(one per row in the turn) with a single ``SELECT id, reverse_of
WHERE reverse_of IN (:ids)``. The route still iterates rows in
reverse-chronological order, but the membership check is O(1) per
iteration after this query. For a turn with 30 actions that's 30
fewer round-trips through asyncpg + a smaller transaction footprint.
Returns a ``{original_action_id -> revert_action_id}`` map. Missing
keys mean "not yet reverted" callers should treat them as
eligible for revert.
"""
if not action_ids:
return {}
stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where(
AgentActionLog.reverse_of.in_(action_ids)
)
result = await session.execute(stmt)
return {
original_id: revert_id
for revert_id, original_id in result.all()
if original_id is not None
}
@router.post(
    "/threads/{thread_id}/revert-turn/{chat_turn_id}",
    response_model=RevertTurnResponse,
)
async def revert_agent_turn(
    thread_id: int,
    chat_turn_id: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
) -> RevertTurnResponse:
    """Revert every reversible action emitted during ``chat_turn_id``.

    Walks ``AgentActionLog`` rows for the turn in reverse ``created_at``
    order so dependencies (e.g. ``mkdir`` -> ``write_file`` inside the new
    folder) unwind in the right sequence. Each action is reverted in its
    own SAVEPOINT so a single failure does not poison the batch.

    Partial success is intentional and returned with HTTP 200. Callers
    must inspect ``results[*].status`` to find rows that need attention.

    Raises:
        HTTPException: 503 when the revert route is disabled by feature
            flags, 404 when the thread does not exist, 500 when the final
            commit of the batch fails.
    """
    flags = get_flags()
    if flags.disable_new_agent_stack or not flags.enable_revert_route:
        raise HTTPException(
            status_code=503,
            detail=(
                "Revert is not available on this deployment yet. The route "
                "ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to "
                "enable it."
            ),
        )

    thread = await load_thread(session, thread_id=thread_id)
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found.")

    # Reverse-chronological so the latest mutation in the turn unwinds
    # first. ``id.desc()`` is the deterministic tiebreaker for actions
    # written in the same millisecond.
    rows_stmt = (
        select(AgentActionLog)
        .where(
            AgentActionLog.thread_id == thread_id,
            AgentActionLog.chat_turn_id == chat_turn_id,
        )
        .order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc())
    )
    rows = (await session.execute(rows_stmt)).scalars().all()

    requester_user_id = str(user.id) if user is not None else None
    results: list[RevertTurnActionResult] = []
    # Counters MUST be exhaustive so the response invariant
    # ``total == sum(counters)`` always holds. Frontend toasts and
    # ``RevertTurnButton`` rely on this for "X of Y reverted" math.
    counts: dict[str, int] = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }

    # Single batched idempotency probe replaces the previous per-row
    # SELECT. ``rows`` are filtered in the loop so we pre-collect only
    # the original-action ids (skip rows that are themselves
    # reverts).
    eligible_ids = [r.id for r in rows if r.reverse_of is None]
    already_reverted_map = await _was_already_reverted_batch(
        session, action_ids=eligible_ids
    )

    for action in rows:
        # Skip rows that ARE reverts of an earlier action — reverting a
        # revert is meaningless inside a batch (the user wants to wipe
        # the original effects, not chase tail).
        if action.reverse_of is not None:
            counts["skipped"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="skipped",
                    message="Row is itself a revert action; skipped.",
                )
            )
            continue

        # Idempotency: surface "already_reverted" instead of failing.
        existing_revert_id = already_reverted_map.get(action.id)
        if existing_revert_id is not None:
            counts["already_reverted"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="already_reverted",
                    new_action_id=existing_revert_id,
                )
            )
            continue

        # Authorisation is checked per row; this route always passes
        # ``is_admin=False``.
        if not can_revert(
            requester_user_id=requester_user_id,
            action=action,
            is_admin=False,
        ):
            counts["permission_denied"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="permission_denied",
                    message="You are not allowed to revert this action.",
                )
            )
            continue

        # Per-row SAVEPOINT so one failed revert never poisons later
        # successful ones.
        try:
            async with session.begin_nested():
                outcome = await revert_action(
                    session,
                    action=action,
                    requester_user_id=requester_user_id,
                )
                # Raising inside the ``begin_nested`` block rolls the
                # savepoint back, discarding any writes made for a
                # non-OK outcome.
                if outcome.status != "ok":
                    raise _OutcomeRollbackError(outcome)
        except _OutcomeRollbackError as rollback:
            outcome = rollback.outcome
            classified = _classify_outcome(outcome)
            if classified == "permission_denied":
                counts["permission_denied"] += 1
            else:
                counts["not_reversible"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status=classified,
                    message=outcome.message,
                )
            )
            continue
        except IntegrityError:
            # Partial unique index caught a concurrent revert that won
            # the race against our pre-flight
            # ``_was_already_reverted_batch`` SELECT. Look up the winner so
            # we can surface its ``new_action_id`` to the client.
            existing_revert_id = await _was_already_reverted(
                session, action_id=action.id
            )
            counts["already_reverted"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="already_reverted",
                    new_action_id=existing_revert_id,
                )
            )
            continue
        except Exception as err:  # pragma: no cover — defensive, logged
            logger.exception(
                "Unexpected revert failure inside batch for action_id=%s",
                action.id,
            )
            counts["failed"] += 1
            results.append(
                RevertTurnActionResult(
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="failed",
                    error=str(err) or err.__class__.__name__,
                )
            )
            continue

        counts["reverted"] += 1
        results.append(
            RevertTurnActionResult(
                action_id=action.id,
                tool_name=action.tool_name,
                status="reverted",
                message=outcome.message,
                new_action_id=outcome.new_action_id,
            )
        )

    # Single commit at the end — successful SAVEPOINTs above already
    # released; failed ones rolled back to their savepoint. No row leaks
    # across the boundary.
    try:
        await session.commit()
    except Exception as err:  # pragma: no cover — defensive
        logger.exception(
            "Final commit for revert-turn failed (thread=%s turn=%s)",
            thread_id,
            chat_turn_id,
        )
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="Internal error while finalising revert-turn batch.",
        ) from err

    # ``already_reverted`` and ``skipped`` rows do not downgrade the batch.
    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )
    overall_status: Literal["ok", "partial"] = "partial" if has_partial else "ok"

    return RevertTurnResponse(
        status=overall_status,
        chat_turn_id=chat_turn_id,
        total=len(rows),
        reverted=counts["reverted"],
        already_reverted=counts["already_reverted"],
        not_reversible=counts["not_reversible"],
        permission_denied=counts["permission_denied"],
        failed=counts["failed"],
        skipped=counts["skipped"],
        results=results,
    )
class _OutcomeRollbackError(Exception):
"""Sentinel raised inside the SAVEPOINT to roll back a non-OK outcome.
``revert_action`` writes a new ``agent_action_log`` row only on the
happy path, but on the failure paths it sometimes mutates the
``DocumentRevision``/``Document`` tables before deciding the action
is not reversible. Wrapping each call in ``begin_nested`` and raising
this from the failure branch ensures we always discard partial
writes for failed rows.
"""
def __init__(self, outcome: RevertOutcome) -> None:
self.outcome = outcome
super().__init__(outcome.message)
__all__ = ["router"]

View file

@ -745,6 +745,51 @@ async def search_document_titles(
) from e
@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead)
async def get_document_by_virtual_path(
    search_space_id: int,
    virtual_path: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Resolve a knowledge-base document id by exact virtual path.

    Requires ``DOCUMENTS_READ`` on ``search_space_id``. The match is an exact
    string comparison against the ``virtual_path`` key of the document's
    metadata; the first matching row is returned.

    Raises:
        HTTPException: 404 when no document matches; 500 for any unexpected
            error. ``HTTPException``s raised inside the ``try`` (including
            the 404) are re-raised unchanged.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.DOCUMENTS_READ.value,
            "You don't have permission to read documents in this search space",
        )

        stored_path = Document.document_metadata["virtual_path"].as_string()
        lookup = select(
            Document.id,
            Document.title,
            Document.document_type,
        ).filter(
            Document.search_space_id == search_space_id,
            stored_path == virtual_path,
        )
        match = (await session.execute(lookup)).first()

        if match is not None:
            return DocumentTitleRead(
                id=match.id,
                title=match.title,
                document_type=match.document_type,
            )
        raise HTTPException(status_code=404, detail="Document not found")
    except HTTPException:
        # Keep deliberate HTTP errors from being rewrapped as 500s below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to resolve document by virtual path: {e!s}",
        ) from e
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
async def get_documents_status(
search_space_id: int,

View file

@ -11,10 +11,11 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
"""
import asyncio
import json
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy.exc import IntegrityError, OperationalError
@ -28,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
FilesystemSelection,
LocalFilesystemMount,
)
from app.agents.new_chat.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
manager,
request_cancel,
)
from app.config import config
from app.db import (
ChatComment,
@ -43,6 +50,7 @@ from app.db import (
)
from app.schemas.new_chat import (
AgentToolInfo,
CancelActiveTurnResponse,
LocalFilesystemMountPayload,
NewChatMessageRead,
NewChatRequest,
@ -59,6 +67,7 @@ from app.schemas.new_chat import (
ThreadListItem,
ThreadListResponse,
TokenUsageSummary,
TurnStatusResponse,
)
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
@ -71,6 +80,9 @@ from app.utils.user_message_multimodal import (
_logger = logging.getLogger(__name__)
_background_tasks: set[asyncio.Task] = set()
# Backoff parameters for the TURN_CANCELLING retry hint, consumed by
# ``_compute_turn_cancelling_retry_delay``: delay for attempt 1 (ms), the
# per-attempt multiplier, and the upper bound on the hinted delay (ms).
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
router = APIRouter()
@ -136,6 +148,326 @@ def _resolve_filesystem_selection(
)
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Bounded exponential delay for TURN_CANCELLING retry hints.

    Attempts below 1 are treated as attempt 1. The result is
    ``INITIAL * FACTOR ** (attempt - 1)`` capped at
    ``TURN_CANCELLING_MAX_DELAY_MS``, in milliseconds.
    """
    exponent = max(attempt, 1) - 1
    uncapped = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    return min(uncapped, TURN_CANCELLING_MAX_DELAY_MS)
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Describe the turn state of ``thread_id`` as a small status dict.

    Returns:
        ``{"status": "idle"}`` when the thread's lock is not held,
        ``{"status": "busy"}`` when it is held and no cancel was requested,
        otherwise ``{"status": "cancelling", "retry_after_ms": ...,
        "retry_after_at": ...}`` where ``retry_after_at`` is a UTC epoch
        timestamp in milliseconds.
    """
    thread_key = str(thread_id)
    if not manager.lock_for(thread_key).locked():
        return {"status": "idle"}
    if not is_cancel_requested(thread_key):
        return {"status": "busy"}

    # The first element of the cancel state is used as the attempt number;
    # fall back to 1 when no state is recorded.
    state = get_cancel_state(thread_key)
    attempt = state[0] if state else 1
    delay_ms = _compute_turn_cancelling_retry_delay(attempt)
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    return {
        "status": "cancelling",
        "retry_after_ms": delay_ms,
        "retry_after_at": now_ms + delay_ms,
    }
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
response.headers["retry-after-ms"] = str(retry_after_ms)
response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Refuse to start a new turn while another one holds the thread.

    Returns silently when the thread is idle.

    Raises:
        HTTPException: 409 with ``errorCode`` ``"THREAD_BUSY"`` when a turn is
            running, or ``"TURN_CANCELLING"`` (plus retry hints in the body
            and, when the delay is positive, in the headers) when a turn is
            being cancelled.
    """
    payload = _build_turn_status_payload(thread_id)
    state = payload["status"]
    if state == "idle":
        return

    if state != "cancelling":
        raise HTTPException(
            status_code=409,
            detail={
                "errorCode": "THREAD_BUSY",
                "message": "Another response is still finishing for this thread. Please try again in a moment.",
            },
        )

    wait_ms = int(payload.get("retry_after_ms") or 0)
    hint_headers = None
    if wait_ms > 0:
        hint_headers = {
            "retry-after-ms": str(wait_ms),
            "Retry-After": str(max(1, (wait_ms + 999) // 1000)),
        }
    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "TURN_CANCELLING",
            "message": "A previous response is still stopping. Please try again in a moment.",
            "retry_after_ms": wait_ms if wait_ms > 0 else None,
            "retry_after_at": payload.get("retry_after_at"),
        },
        headers=hint_headers,
    )
def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list,
*,
turn_id: str,
) -> str | None:
"""Locate the LangGraph checkpoint immediately before ``turn_id`` started.
``checkpoint_tuples`` arrives newest-first from
``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``)
and remember the most recent checkpoint that does NOT belong to the
edited turn. As soon as we cross into the edited turn (a checkpoint
whose ``turn_id`` matches), we return the previously-tracked
checkpoint that's the state immediately before ``turn_id`` began.
The naive "newest-first, return first non-matching" approach is
INCORRECT when later turns exist after ``turn_id``: their
checkpoints also satisfy ``cp_turn_id != turn_id`` and would be
returned before the real pre-turn boundary is reached.
Reads from ``cp_tuple.metadata`` (the durable surface promoted from
``configurable`` at write time) rather than ``config["configurable"]``
so the lookup is portable across checkpointer implementations.
Returns ``None`` when no eligible pre-turn checkpoint exists (e.g.
the edited turn is the very first turn of the thread). Callers fall
back to the oldest available checkpoint in that case.
"""
last_pre_turn_target: str | None = None
for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest
metadata = getattr(cp_tuple, "metadata", None) or {}
cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None
if cp_turn_id == turn_id:
# Crossed into the edited turn; the previous tracked
# checkpoint is the rewind target. May be ``None`` if we hit
# the edited turn on the very first iteration.
return last_pre_turn_target
try:
last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"]
except (KeyError, TypeError):
continue
return last_pre_turn_target
async def _revert_turns_for_regenerate(
    *,
    thread_id: int,
    chat_turn_ids: list[str],
    requester_user_id: str,
) -> dict:
    """Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``.

    Runs BEFORE the regenerate stream so the frontend can surface
    partial-rollback feedback alongside the new assistant turn. Each
    action is reverted inside its own SAVEPOINT so a single failure never
    poisons the batch.

    Sequencing inside the request: revert THEN regenerate. The operation
    is NOT atomic and partial state IS surfaced to the caller.

    If the final commit fails, the whole batch is rolled back. Entries
    recorded as ``"reverted"`` are then downgraded to ``"failed"`` (and
    their ``new_action_id`` cleared) so the summary never claims reverts
    that were not persisted.

    Args:
        thread_id: Thread whose action log is walked.
        chat_turn_ids: Turns to revert, processed in the given order; rows
            inside each turn are processed newest-first.
        requester_user_id: User id used for the per-row ``can_revert`` check
            and passed to ``revert_action``.

    Returns:
        A dict with ``status`` (``"ok"``/``"partial"``), ``chat_turn_ids``,
        ``total``, one counter per status, and ``results`` (a list of
        ``RevertTurnActionResult.model_dump()`` dicts). The counters always
        sum to ``total``.
    """
    from app.routes.agent_revert_route import (
        RevertTurnActionResult,
        _classify_outcome,
        _OutcomeRollbackError,
        _was_already_reverted,
        _was_already_reverted_batch,
    )
    from app.services.revert_service import (
        can_revert,
        revert_action,
    )

    aggregated_results: list[dict] = []
    # Exhaustive counters keep the response invariant
    # ``total == sum(counters)`` true for ``data-revert-results``.
    counts = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }

    # Imported locally, as in the original code, rather than at module level.
    from app.db import AgentActionLog as _AgentActionLog

    def _record(action, status: str, *, count_as: str | None = None, **fields) -> None:
        """Bump one counter and append the matching per-action result dict."""
        counts[count_as or status] += 1
        aggregated_results.append(
            RevertTurnActionResult(
                action_id=action.id,
                tool_name=action.tool_name,
                status=status,
                **fields,
            ).model_dump()
        )

    async with shielded_async_session() as session:
        for chat_turn_id in chat_turn_ids:
            # Newest-first within the turn; ``id`` breaks timestamp ties.
            rows_stmt = (
                select(_AgentActionLog)
                .where(
                    _AgentActionLog.thread_id == thread_id,
                    _AgentActionLog.chat_turn_id == chat_turn_id,
                )
                .order_by(
                    _AgentActionLog.created_at.desc(),
                    _AgentActionLog.id.desc(),
                )
            )
            rows = (await session.execute(rows_stmt)).scalars().all()

            # Batch idempotency probe across the turn (single SELECT
            # instead of one per row).
            eligible_ids = [r.id for r in rows if r.reverse_of is None]
            already_reverted_map = await _was_already_reverted_batch(
                session, action_ids=eligible_ids
            )

            for action in rows:
                if action.reverse_of is not None:
                    _record(
                        action,
                        "skipped",
                        message="Row is itself a revert action; skipped.",
                    )
                    continue

                existing_revert_id = already_reverted_map.get(action.id)
                if existing_revert_id is not None:
                    _record(
                        action,
                        "already_reverted",
                        new_action_id=existing_revert_id,
                    )
                    continue

                if not can_revert(
                    requester_user_id=requester_user_id,
                    action=action,
                    is_admin=False,
                ):
                    _record(
                        action,
                        "permission_denied",
                        message="You are not allowed to revert this action.",
                    )
                    continue

                try:
                    async with session.begin_nested():
                        outcome = await revert_action(
                            session,
                            action=action,
                            requester_user_id=requester_user_id,
                        )
                        # Raising rolls the SAVEPOINT back so writes made
                        # for a non-OK outcome are discarded.
                        if outcome.status != "ok":
                            raise _OutcomeRollbackError(outcome)
                except _OutcomeRollbackError as rollback:
                    outcome = rollback.outcome
                    classified = _classify_outcome(outcome)
                    _record(
                        action,
                        classified,
                        count_as=(
                            "permission_denied"
                            if classified == "permission_denied"
                            else "not_reversible"
                        ),
                        message=outcome.message,
                    )
                    continue
                except IntegrityError:
                    # Concurrent revert won the race against the
                    # pre-flight ``_was_already_reverted_batch`` SELECT.
                    # Surface the winning revert id so the client can
                    # treat this as a successful idempotent op.
                    existing_revert_id = await _was_already_reverted(
                        session, action_id=action.id
                    )
                    _record(
                        action,
                        "already_reverted",
                        new_action_id=existing_revert_id,
                    )
                    continue
                except Exception as err:  # pragma: no cover — defensive
                    _logger.exception(
                        "Unexpected revert failure during regenerate batch "
                        "for action_id=%s",
                        action.id,
                    )
                    _record(
                        action,
                        "failed",
                        error=str(err) or err.__class__.__name__,
                    )
                    continue

                _record(
                    action,
                    "reverted",
                    message=outcome.message,
                    new_action_id=outcome.new_action_id,
                )

        try:
            await session.commit()
        except Exception:
            _logger.exception(
                "[regenerate-revert] Final commit failed; rolling back batch."
            )
            await session.rollback()
            # Nothing written by this batch survived the rollback, so do not
            # report those rows as reverted. ``already_reverted`` entries
            # refer to rows written by other transactions and stay valid.
            for entry in aggregated_results:
                if entry["status"] == "reverted":
                    entry["status"] = "failed"
                    entry["new_action_id"] = None
                    entry["error"] = (
                        "Revert was rolled back because the final commit failed."
                    )
            counts["failed"] += counts["reverted"]
            counts["reverted"] = 0

    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )
    return {
        "status": "partial" if has_partial else "ok",
        "chat_turn_ids": chat_turn_ids,
        "total": len(aggregated_results),
        "reverted": counts["reverted"],
        "already_reverted": counts["already_reverted"],
        "not_reversible": counts["not_reversible"],
        "permission_denied": counts["permission_denied"],
        "failed": counts["failed"],
        "skipped": counts["skipped"],
        "results": aggregated_results,
    }
def _try_delete_sandbox(thread_id: int) -> None:
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
from app.agents.new_chat.sandbox import (
@ -574,6 +906,7 @@ async def get_thread_messages(
token_usage=TokenUsageSummary.model_validate(msg.token_usage)
if msg.token_usage
else None,
turn_id=msg.turn_id,
)
for msg in db_messages
]
@ -1006,12 +1339,24 @@ async def append_message(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
# Create message
# Create message. ``turn_id`` is the per-turn correlation id from
# ``configurable.turn_id`` (added in migration 136) — when the
# client streams it back to ``appendMessage``, we persist it so
# C1's edit-from-arbitrary-position can later map this message
# back to the LangGraph checkpoint that produced its turn.
raw_turn_id = raw_body.get("turn_id")
turn_id_value = (
str(raw_turn_id).strip()
if isinstance(raw_turn_id, str) and raw_turn_id.strip()
else None
)
db_message = NewChatMessage(
thread_id=thread_id,
role=message_role,
content=content,
author_id=user.id,
turn_id=turn_id_value,
)
session.add(db_message)
@ -1050,6 +1395,7 @@ async def append_message(
created_at=db_message.created_at,
author_id=db_message.author_id,
token_usage=None,
turn_id=db_message.turn_id,
)
except HTTPException:
@ -1207,6 +1553,7 @@ async def handle_new_chat(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(request.chat_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@ -1281,6 +1628,93 @@ async def handle_new_chat(
) from None
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``.

    Requires ``CHATS_UPDATE`` on the thread's search space plus thread-level
    access. When no turn is active the call is a no-op and reports
    ``"idle"`` / ``NO_ACTIVE_TURN``. Otherwise a cancel is requested, the
    status code is set to 202, and retry hints are returned in the body and,
    when the hinted delay is positive, in the response headers.

    Raises:
        HTTPException: 404 when the thread does not exist.
    """
    lookup = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = lookup.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    if _build_turn_status_payload(thread_id)["status"] == "idle":
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )

    request_cancel(str(thread_id))
    response.status_code = 202

    # Re-read the status after requesting the cancel to pick up retry hints.
    snapshot = _build_turn_status_payload(thread_id)
    wait_ms = int(snapshot.get("retry_after_ms") or 0)
    deadline_ms = None
    if "retry_after_at" in snapshot:
        deadline_ms = int(snapshot["retry_after_at"])

    hinted_ms = None
    if wait_ms > 0:
        _set_retry_after_headers(response, wait_ms)
        hinted_ms = wait_ms

    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=hinted_ms,
        retry_after_at=deadline_ms,
    )
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Report whether ``thread_id`` is idle, busy, or cancelling a turn.

    Requires ``CHATS_READ`` on the thread's search space plus thread-level
    access. ``active_turn_id`` is always returned as ``None``; the retry
    fields are present only while a turn is cancelling.

    Raises:
        HTTPException: 404 when the thread does not exist.
    """
    lookup = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = lookup.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    snapshot = _build_turn_status_payload(thread_id)
    current_state = snapshot["status"]
    hinted_ms = snapshot.get("retry_after_ms")
    hinted_at = snapshot.get("retry_after_at")
    return TurnStatusResponse(
        status=current_state,  # type: ignore[arg-type]
        active_turn_id=None,
        retry_after_ms=hinted_ms,  # type: ignore[arg-type]
        retry_after_at=hinted_at,  # type: ignore[arg-type]
    )
# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================
@ -1336,6 +1770,7 @@ async def regenerate_response(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
@ -1373,43 +1808,123 @@ async def regenerate_response(
user_query_to_use = request.user_query
regenerate_image_urls: list[str] = []
# Look through checkpoints to find the right one
# We want to find the checkpoint just before the last HumanMessage
for i, cp_tuple in enumerate(checkpoint_tuples):
# Access the checkpoint's channel_values which contains "messages"
checkpoint_data = cp_tuple.checkpoint
channel_values = checkpoint_data.get("channel_values", {})
state_messages = channel_values.get("messages", [])
# ---------------------------------------------------------------
# Edit-from-arbitrary-position. When the client passes
# ``from_message_id`` we look up its persisted ``turn_id`` (added
# in migration 136) and pick the checkpoint immediately before
# that turn started.
#
# Legacy graceful-degradation contract:
# * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``.
# Returning 400 in that case is the wrong UX — the user is
# editing an old message in an existing thread and just wants
# it to work. We instead skip the checkpoint rewind (the
# stream falls back to the latest state) and skip the revert
# pass (no chat_turn_id available to walk). Deletion still
# uses ``created_at``, so the messages-after-cursor slice is
# correct on both legacy and post-136 rows.
# ---------------------------------------------------------------
from_message_turn_id: str | None = None
from_message_created_at: datetime | None = None
legacy_from_message: bool = False
if request.from_message_id is not None:
from_msg_row = await session.execute(
select(NewChatMessage).filter(
NewChatMessage.id == request.from_message_id,
NewChatMessage.thread_id == thread_id,
)
)
from_msg = from_msg_row.scalars().first()
if from_msg is None:
raise HTTPException(
status_code=404,
detail="from_message_id not found in this thread.",
)
from_message_created_at = from_msg.created_at
if not from_msg.turn_id:
# Legacy row — surface the degradation in logs but let
# the request proceed with the slice-based delete and a
# cold-start checkpoint.
legacy_from_message = True
_logger.warning(
"[regenerate] from_message_id=%s on thread=%s has no "
"turn_id (legacy row pre-migration-136). Falling back "
"to slice-based delete without checkpoint rewind. "
"revert_actions=%s will be ignored.",
request.from_message_id,
thread_id,
request.revert_actions,
)
else:
from_message_turn_id = from_msg.turn_id
if state_messages:
last_msg = state_messages[-1]
# Find a checkpoint where the last message is NOT a HumanMessage
# This means we're at a state before the user's last message
if not isinstance(last_msg, HumanMessage):
# If no new user_query provided (reload), extract from a later checkpoint
if user_query_to_use is None and i > 0:
# Get the user query from a more recent checkpoint
for prev_cp_tuple in checkpoint_tuples[:i]:
prev_checkpoint_data = prev_cp_tuple.checkpoint
prev_channel_values = prev_checkpoint_data.get(
"channel_values", {}
)
prev_messages = prev_channel_values.get("messages", [])
for msg in reversed(prev_messages):
if isinstance(msg, HumanMessage):
q, imgs = split_langchain_human_content(msg.content)
user_query_to_use = q
regenerate_image_urls = imgs
break
if user_query_to_use is not None and (
str(user_query_to_use).strip() or regenerate_image_urls
):
break
target_checkpoint_id = cp_tuple.config["configurable"][
# Walk oldest-to-newest and pick the LAST checkpoint whose
# ``turn_id`` differs from the edited turn — that's the state
# immediately before this turn started running. We read from
# ``metadata`` (the durable surface) rather than
# ``config["configurable"]`` so the lookup works across
# checkpointer implementations.
target_checkpoint_id = _find_pre_turn_checkpoint_id(
checkpoint_tuples,
turn_id=from_message_turn_id,
)
if target_checkpoint_id is None and len(checkpoint_tuples) > 0:
# Fall back to the oldest checkpoint — better than
# 400ing when the agent didn't checkpoint pre-turn
# (e.g. very first turn of the thread).
target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
"checkpoint_id"
]
break
# Look through checkpoints to find the right one
# We want to find the checkpoint just before the last HumanMessage.
# We enter this branch when:
# * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR
# * the client pinned ``from_message_id`` but the row is a
# legacy pre-migration-136 row with no ``turn_id`` (we
# downgraded to the same heuristic as a regular reload).
# We DO skip it when a real turn_id pinned ``target_checkpoint_id``
# — that's the C1 happy path and the heuristic below would just
# re-derive a worse target.
if request.from_message_id is None or legacy_from_message:
for i, cp_tuple in enumerate(checkpoint_tuples):
# Access the checkpoint's channel_values which contains "messages"
checkpoint_data = cp_tuple.checkpoint
channel_values = checkpoint_data.get("channel_values", {})
state_messages = channel_values.get("messages", [])
if state_messages:
last_msg = state_messages[-1]
# Find a checkpoint where the last message is NOT a HumanMessage
# This means we're at a state before the user's last message
if not isinstance(last_msg, HumanMessage):
# If no new user_query provided (reload), extract from a later checkpoint
if user_query_to_use is None and i > 0:
# Get the user query from a more recent checkpoint
for prev_cp_tuple in checkpoint_tuples[:i]:
prev_checkpoint_data = prev_cp_tuple.checkpoint
prev_channel_values = prev_checkpoint_data.get(
"channel_values", {}
)
prev_messages = prev_channel_values.get("messages", [])
for msg in reversed(prev_messages):
if isinstance(msg, HumanMessage):
q, imgs = split_langchain_human_content(
msg.content
)
user_query_to_use = q
regenerate_image_urls = imgs
break
if user_query_to_use is not None and (
str(user_query_to_use).strip()
or regenerate_image_urls
):
break
target_checkpoint_id = cp_tuple.config["configurable"][
"checkpoint_id"
]
break
# If we couldn't find a good checkpoint, try alternative approaches
if target_checkpoint_id is None and checkpoint_tuples:
@ -1472,18 +1987,51 @@ async def regenerate_response(
detail="Could not determine user query for regeneration. Please provide a user_query.",
)
# Get the last two messages to delete AFTER streaming succeeds
# This prevents data loss if streaming fails
last_messages_result = await session.execute(
select(NewChatMessage)
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at.desc())
.limit(2)
)
# Get the messages to delete AFTER streaming succeeds.
# This prevents data loss if streaming fails.
#
# When ``from_message_id`` is set we slice from that message
# forward (using ``created_at`` so we also catch any tool/system
# messages persisted into the same turn). Otherwise
# we keep the legacy "last 2 messages" rewind.
if request.from_message_id is not None and from_message_created_at is not None:
last_messages_result = await session.execute(
select(NewChatMessage)
.filter(
NewChatMessage.thread_id == thread_id,
NewChatMessage.created_at >= from_message_created_at,
)
.order_by(NewChatMessage.created_at.desc())
)
else:
last_messages_result = await session.execute(
select(NewChatMessage)
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at.desc())
.limit(2)
)
messages_to_delete = list(last_messages_result.scalars().all())
message_ids_to_delete = [msg.id for msg in messages_to_delete]
# When revert_actions is requested, collect the set of
# ``chat_turn_id``s present in the slice we're about to delete.
# Each one will be reverted (best-effort) BEFORE the regenerate
# stream begins. Legacy rows have ``turn_id=None`` and silently
# contribute nothing — we already logged the degradation above.
revert_turn_ids: list[str] = []
if (
request.revert_actions
and request.from_message_id is not None
and not legacy_from_message
):
seen_turns: set[str] = set()
for msg in messages_to_delete:
tid = msg.turn_id
if tid and tid not in seen_turns:
seen_turns.add(tid)
revert_turn_ids.append(tid)
# Get search space for LLM config
search_space_result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
@ -1507,6 +2055,24 @@ async def regenerate_response(
# This prevents data loss if streaming fails (network error, LLM error, etc.)
async def stream_with_cleanup():
streaming_completed = False
# Best-effort revert pass BEFORE the regenerate stream begins.
# Each turn is reverted independently (per-row SAVEPOINTs
# inside the route helper) and the per-action results are surfaced
# on a single ``data-revert-results`` SSE event so the frontend
# can render any failed rows alongside the new turn. Failures here
# do NOT abort the regeneration — partial rollback is documented
# behaviour.
if revert_turn_ids:
revert_results = await _revert_turns_for_regenerate(
thread_id=thread_id,
chat_turn_ids=revert_turn_ids,
requester_user_id=str(user.id),
)
envelope = {
"type": "data-revert-results",
"data": revert_results,
}
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
try:
async for chunk in stream_new_chat(
user_query=str(user_query_to_use),
@ -1524,6 +2090,7 @@ async def regenerate_response(
filesystem_selection=filesystem_selection,
request_id=getattr(http_request.state, "request_id", "unknown"),
user_image_data_urls=regenerate_image_urls or None,
flow="regenerate",
):
yield chunk
streaming_completed = True
@ -1611,6 +2178,7 @@ async def resume_chat(
)
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,

View file

@ -3,7 +3,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage
from pydantic import BaseModel as PydanticBaseModel
from sqlalchemy import func
from sqlalchemy import func, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
from app.config import config
from app.db import (
ImageGenerationConfig,
NewChatThread,
NewLLMConfig,
Permission,
SearchSpace,
@ -790,9 +791,27 @@ async def update_llm_preferences(
# Update preferences
update_data = preferences.model_dump(exclude_unset=True)
previous_agent_llm_id = search_space.agent_llm_id
for key, value in update_data.items():
setattr(search_space, key, value)
agent_llm_changed = (
"agent_llm_id" in update_data
and update_data["agent_llm_id"] != previous_agent_llm_id
)
if agent_llm_changed:
await session.execute(
update(NewChatThread)
.where(NewChatThread.search_space_id == search_space_id)
.values(pinned_llm_config_id=None)
)
logger.info(
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
search_space_id,
previous_agent_llm_id,
update_data["agent_llm_id"],
)
await session.commit()
await session.refresh(search_space)

View file

@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
author_display_name: str | None = None
author_avatar_url: str | None = None
token_usage: TokenUsageSummary | None = None
# Per-turn correlation id (``f"{chat_id}:{ms}"``) from
# ``configurable.turn_id`` at streaming time. Nullable because
# legacy rows predate the column; clients should treat NULL as
# "edit-from-this-message is unavailable".
turn_id: str | None = None
model_config = ConfigDict(from_attributes=True)
@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel):
For edit, optional user_images (when not None) replaces image URLs resolved from
checkpoint/DB so the client can send the full user turn (text and/or images).
Edit-from-arbitrary-position. When ``from_message_id`` is provided
the route slices conversation history starting at that message (instead of
the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by
matching ``configurable.turn_id`` stored on the message (added in migration 136), and
optionally reverts every reversible action emitted in turns at or after
``from_message_id``. The revert step is best-effort and runs BEFORE the
regenerate stream; partial failures are surfaced via SSE
``data-revert-results`` and do not abort the regeneration.
"""
search_space_id: int
@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel):
default=None,
description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB",
)
from_message_id: int | None = Field(
default=None,
description=(
"Message id to rewind to. When set, history is sliced "
"from this message forward and the LangGraph checkpoint is "
"rewound to the state immediately preceding this turn. Legacy "
"rows that predate migration 136 have ``turn_id=None`` and "
"still process — the route logs a warning, skips the "
"checkpoint rewind, and ignores ``revert_actions`` (no "
"chat_turn_id available to walk)."
),
)
revert_actions: bool = Field(
default=False,
description=(
"When true, every reversible action emitted at or "
"after ``from_message_id`` is reverted before the regenerate "
"stream begins. Per-action results are surfaced via the "
"``data-revert-results`` SSE event. Partial failures DO NOT "
"abort the regeneration."
),
)
@model_validator(mode="after")
def _validate_regenerate_user_images(self) -> Self:
@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel):
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
return self
@model_validator(mode="after")
def _validate_revert_actions_requires_from_message(self) -> Self:
    """Reject ``revert_actions=True`` when no ``from_message_id`` was supplied.

    Returns:
        The validated model instance, unchanged.

    Raises:
        ValueError: if a revert was requested without a rewind target.
    """
    has_rewind_target = self.from_message_id is not None
    if has_rewind_target or not self.revert_actions:
        return self
    raise ValueError(
        "revert_actions requires from_message_id; specify which message to rewind to"
    )
# =============================================================================
# Agent Tools Schemas
@ -291,6 +335,24 @@ class ResumeRequest(BaseModel):
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
class CancelActiveTurnResponse(BaseModel):
"""Response for canceling an active turn on a chat thread."""
# Outcome of the cancel request; restricted to "cancelling" or "idle".
status: Literal["cancelling", "idle"]
# Machine-readable code; restricted to "TURN_CANCELLING" or "NO_ACTIVE_TURN".
error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
# Optional retry hints; both default to None.
# NOTE(review): presumably a delay in milliseconds and an absolute
# timestamp — confirm units against the route that builds this response.
retry_after_ms: int | None = None
retry_after_at: int | None = None
class TurnStatusResponse(BaseModel):
"""Current turn execution status for a thread."""
# Execution state; restricted to "idle", "busy" or "cancelling".
status: Literal["idle", "busy", "cancelling"]
# Identifier of the turn currently running, if any; defaults to None.
active_turn_id: str | None = None
# Optional retry hints; both default to None.
# NOTE(review): presumably a delay in milliseconds and an absolute
# timestamp — confirm units against the route that builds this response.
retry_after_ms: int | None = None
retry_after_at: int | None = None
# =============================================================================
# Public Chat Snapshot Schemas
# =============================================================================

View file

@ -0,0 +1,385 @@
"""Resolve and persist Auto (Fastest) model pins per chat thread.
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
resolve that virtual mode to one concrete global LLM config exactly once and
persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
subsequent turns are stable.
Single-writer invariant: this module is the only writer of
``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
Therefore a non-NULL value unambiguously means "this thread has an
Auto-resolved pin"; no separate source/policy column is needed.
"""
from __future__ import annotations
import hashlib
import logging
import threading
import time
from dataclasses import dataclass
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import NewChatThread
from app.services.quality_score import _QUALITY_TOP_K
from app.services.token_quota_service import TokenQuotaService
logger = logging.getLogger(__name__)
AUTO_FASTEST_ID = 0
AUTO_FASTEST_MODE = "auto_fastest"
_RUNTIME_COOLDOWN_SECONDS = 600
_HEALTHY_TTL_SECONDS = 45
# In-memory runtime cooldown map for configs that recently hard-failed at
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
# the same unhealthy config from being reselected immediately during repair.
_runtime_cooldown_until: dict[int, float] = {}
_runtime_cooldown_lock = threading.Lock()
# Short-TTL "recently healthy" cache for configs that just passed a runtime
# preflight ping. Lets back-to-back turns on the same model skip the probe
# without eroding correctness — entries auto-expire and are wiped any time
# the same config is cooled down or the OR catalogue is refreshed.
_healthy_until: dict[int, float] = {}
_healthy_lock = threading.Lock()
@dataclass
class AutoPinResolution:
"""Result returned by ``resolve_or_get_pinned_llm_config_id``."""
# Concrete LLM config id the caller should use.
resolved_llm_config_id: int
# Lower-cased ``billing_tier`` of the resolved config, or "explicit" when the
# caller selected a concrete (non-Auto) model.
resolved_tier: str
# True when an already-persisted pin was reused instead of selecting anew.
from_existing_pin: bool
def _is_usable_global_config(cfg: dict) -> bool:
return bool(
cfg.get("id") is not None
and cfg.get("model_name")
and cfg.get("provider")
and cfg.get("api_key")
)
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
    """Drop cooldown entries whose expiry is at or before ``now_ts``.

    Args:
        now_ts: Reference timestamp in ``time.time()`` seconds; the current
            time is used when omitted.

    This helper does not take ``_runtime_cooldown_lock`` itself; the callers
    in this module invoke it while already holding that lock.
    """
    cutoff = now_ts if now_ts is not None else time.time()
    expired_ids = []
    for config_id, expires_at in _runtime_cooldown_until.items():
        if expires_at <= cutoff:
            expired_ids.append(config_id)
    for config_id in expired_ids:
        _runtime_cooldown_until.pop(config_id, None)
def _is_runtime_cooled_down(config_id: int) -> bool:
    """Return True while ``config_id`` is inside its runtime cooldown window.

    Expired entries are pruned first, so membership in the map means the
    cooldown is still active. The id is normalised with ``int()`` because
    ``mark_runtime_cooldown`` stores its keys that way; without the coercion
    an id passed as e.g. ``"5"`` would never match. This mirrors
    ``is_recently_healthy``.
    """
    with _runtime_cooldown_lock:
        _prune_runtime_cooldowns()
        return int(config_id) in _runtime_cooldown_until
def mark_runtime_cooldown(
    config_id: int,
    *,
    reason: str = "rate_limited",
    cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
) -> None:
    """Keep ``config_id`` out of Auto selection for a limited time.

    Intended for runtime error handlers (for example an OpenRouter 429): a
    pinned config that is failing right now should not be picked again
    straight away while the thread's pin is being repaired.

    Args:
        config_id: Config to suppress; stored under ``int(config_id)``.
        reason: Free-form label that is only written to the log line.
        cooldown_seconds: Suppression length in seconds. Zero or negative
            values fall back to ``_RUNTIME_COOLDOWN_SECONDS``.
    """
    effective_seconds = (
        cooldown_seconds if cooldown_seconds > 0 else _RUNTIME_COOLDOWN_SECONDS
    )
    expires_at = time.time() + int(effective_seconds)
    with _runtime_cooldown_lock:
        _runtime_cooldown_until[int(config_id)] = expires_at
        _prune_runtime_cooldowns()
    # A config in cooldown must not keep "recently healthy" credit, otherwise
    # the first turn that resolves to it after the cooldown would skip its
    # preflight probe.
    clear_healthy(int(config_id))
    logger.info(
        "auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
        config_id,
        reason,
        effective_seconds,
    )
def clear_runtime_cooldown(config_id: int | None = None) -> None:
    """Remove runtime cooldown state; meant for tests and operational tooling.

    Args:
        config_id: Entry to remove. When ``None`` every cooldown is dropped.
    """
    with _runtime_cooldown_lock:
        if config_id is not None:
            _runtime_cooldown_until.pop(int(config_id), None)
        else:
            _runtime_cooldown_until.clear()
def _prune_healthy(now_ts: float | None = None) -> None:
    """Evict healthy-cache entries that expired at or before ``now_ts``.

    Args:
        now_ts: Reference timestamp in ``time.time()`` seconds; the current
            time is used when omitted.

    The callers in this module invoke it while holding ``_healthy_lock``.
    """
    if now_ts is None:
        now_ts = time.time()
    # Snapshot the expired keys first so the dict is not mutated mid-iteration.
    expired_ids = [
        config_id
        for config_id, expires_at in _healthy_until.items()
        if expires_at <= now_ts
    ]
    for config_id in expired_ids:
        _healthy_until.pop(config_id, None)
def is_recently_healthy(config_id: int) -> bool:
    """Tell whether ``config_id`` still holds unexpired "healthy" credit.

    Expired entries are pruned before the lookup, so a hit means the config
    was marked healthy within its TTL window.
    """
    with _healthy_lock:
        _prune_healthy()
        cfg_key = int(config_id)
        still_fresh = cfg_key in _healthy_until
    return still_fresh
def mark_healthy(
    config_id: int,
    *,
    ttl_seconds: int = _HEALTHY_TTL_SECONDS,
) -> None:
    """Remember that ``config_id`` has just passed a preflight probe.

    While the entry is live, callers may skip the preflight ping. The cache
    is deliberately process-local: it is a latency hint rather than a
    correctness primitive, so drift between workers is acceptable.

    Args:
        config_id: Config that passed the probe; stored under
            ``int(config_id)``.
        ttl_seconds: Lifetime of the entry in seconds. Zero or negative
            values fall back to ``_HEALTHY_TTL_SECONDS``.
    """
    effective_ttl = ttl_seconds if ttl_seconds > 0 else _HEALTHY_TTL_SECONDS
    expires_at = time.time() + int(effective_ttl)
    with _healthy_lock:
        _healthy_until[int(config_id)] = expires_at
        _prune_healthy()
def clear_healthy(config_id: int | None = None) -> None:
    """Forget healthy-cache credit for one config, or for all of them.

    ``mark_runtime_cooldown`` calls this so a freshly cooled config never
    keeps stale "healthy" credit; the module's notes also name the OpenRouter
    catalogue refresh as a caller (not visible in this file).

    Args:
        config_id: Entry to drop. When ``None`` the whole cache is cleared.
    """
    with _healthy_lock:
        if config_id is not None:
            _healthy_until.pop(int(config_id), None)
        else:
            _healthy_until.clear()
def _global_candidates() -> list[dict]:
    """Collect the global configs that Auto mode may currently pick from.

    A config is kept only when it is usable (see
    ``_is_usable_global_config``), is not flagged ``health_gated``, and is
    not in runtime cooldown (e.g. after a temporary 429 burst).

    Returns:
        The surviving configs, ordered by ascending integer ``id``.
    """
    eligible: list[dict] = []
    for cfg in config.GLOBAL_LLM_CONFIGS:
        if not _is_usable_global_config(cfg):
            continue
        if cfg.get("health_gated"):
            continue
        if _is_runtime_cooled_down(int(cfg.get("id", 0))):
            continue
        eligible.append(cfg)
    eligible.sort(key=lambda c: int(c.get("id", 0)))
    return eligible
def _tier_of(cfg: dict) -> str:
return str(cfg.get("billing_tier", "free")).lower()
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
"""Pick a config with quality-first ranking + deterministic spread.
Tier policy is lock-first: prefer Tier A (operator-curated YAML)
cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no
Tier A cfg is eligible after upstream filters. Within the locked
pool, sort by ``quality_score`` and pick from the top-K via
``SHA256(thread_id)`` so different new threads spread across the
best models without ever picking a low-ranked one.
Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for
structured logging in the caller.
"""
tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")]
pool = tier_a if tier_a else eligible
pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
top_k = pool[:_QUALITY_TOP_K]
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
idx = int.from_bytes(digest[:8], "big") % len(top_k)
return top_k[idx], len(top_k)
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
if user_id is None:
return None
if isinstance(user_id, UUID):
return user_id
try:
return UUID(str(user_id))
except Exception:
return None
async def _is_premium_eligible(
    session: AsyncSession, user_id: str | UUID | None
) -> bool:
    """Report whether the user's premium quota currently allows usage.

    Returns ``False`` without querying when ``user_id`` cannot be coerced to
    a ``UUID``; otherwise returns the truthiness of the ``allowed`` attribute
    on the result of ``TokenQuotaService.premium_get_usage``.
    """
    user_uuid = _to_uuid(user_id)
    if user_uuid is not None:
        usage = await TokenQuotaService.premium_get_usage(session, user_uuid)
        return bool(usage.allowed)
    return False
async def resolve_or_get_pinned_llm_config_id(
session: AsyncSession,
*,
thread_id: int,
search_space_id: int,
user_id: str | UUID | None,
selected_llm_config_id: int,
force_repin_free: bool = False,
exclude_config_ids: set[int] | None = None,
) -> AutoPinResolution:
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
For non-auto selections, this function clears any existing pin and returns
the selected id as-is.
Args:
session: Async DB session used to load, lock and update the thread row.
thread_id: Thread whose pin is resolved.
search_space_id: Search space the thread must belong to.
user_id: User whose premium quota decides whether premium-tier configs
are eligible for a new pin.
selected_llm_config_id: The caller's model selection; the value
``AUTO_FASTEST_ID`` means Auto mode.
force_repin_free: When true, ignore any existing pin, skip the quota
lookup and select among non-premium configs only.
exclude_config_ids: Config ids removed from the candidate pool before
any reuse or selection.
Returns:
An ``AutoPinResolution`` describing the config id to use, its tier
("explicit" for non-auto selections) and whether an existing pin was
reused.
Raises:
ValueError: if the thread does not exist, does not belong to
``search_space_id``, or no candidate / eligible config is available.
"""
# Load the thread row under a row lock (SELECT ... FOR UPDATE).
thread = (
(
await session.execute(
select(NewChatThread)
.where(NewChatThread.id == thread_id)
.with_for_update(of=NewChatThread)
)
)
.unique()
.scalar_one_or_none()
)
if thread is None:
raise ValueError(f"Thread {thread_id} not found")
if thread.search_space_id != search_space_id:
raise ValueError(
f"Thread {thread_id} does not belong to search space {search_space_id}"
)
# Explicit model selected: clear any stale pin.
# A commit is issued only when a pin actually had to be cleared.
if selected_llm_config_id != AUTO_FASTEST_ID:
if thread.pinned_llm_config_id is not None:
thread.pinned_llm_config_id = None
await session.commit()
return AutoPinResolution(
resolved_llm_config_id=selected_llm_config_id,
resolved_tier="explicit",
from_existing_pin=False,
)
# Normalise caller-supplied exclusions to ints and drop them from the pool.
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
candidates = [
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
]
if not candidates:
raise ValueError("No usable global LLM configs are available for Auto mode")
# Index by id for the existing-pin lookup below.
candidate_by_id = {int(c["id"]): c for c in candidates}
# Reuse an existing valid pin without re-checking current quota (no silent
# tier switch), unless the caller explicitly requests a forced repin to free.
# Note: this path returns without issuing a commit.
pinned_id = thread.pinned_llm_config_id
if (
not force_repin_free
and pinned_id is not None
and int(pinned_id) in candidate_by_id
):
pinned_cfg = candidate_by_id[int(pinned_id)]
logger.info(
"auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s",
thread_id,
search_space_id,
pinned_id,
_tier_of(pinned_cfg),
)
logger.info(
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
"auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True",
thread_id,
pinned_id,
_tier_of(pinned_cfg),
pinned_cfg.get("auto_pin_tier", "?"),
int(pinned_cfg.get("quality_score") or 0),
)
return AutoPinResolution(
resolved_llm_config_id=int(pinned_id),
resolved_tier=_tier_of(pinned_cfg),
from_existing_pin=True,
)
# Reaching this point with a pin means it was not reused: either a forced
# repin was requested or the pinned config is no longer a candidate.
if pinned_id is not None:
logger.info(
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
thread_id,
search_space_id,
pinned_id,
)
# A forced free repin skips the quota lookup and treats the user as not
# premium-eligible, so premium-tier configs are filtered out below.
premium_eligible = (
False if force_repin_free else await _is_premium_eligible(session, user_id)
)
if premium_eligible:
eligible = candidates
else:
eligible = [c for c in candidates if _tier_of(c) != "premium"]
if not eligible:
raise ValueError(
"Auto mode could not find an eligible LLM config for this user and quota state"
)
selected_cfg, top_k_size = _select_pin(eligible, thread_id)
selected_id = int(selected_cfg["id"])
selected_tier = _tier_of(selected_cfg)
# Persist the newly selected pin on the thread row.
thread.pinned_llm_config_id = selected_id
await session.commit()
if force_repin_free:
logger.info(
"auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s",
thread_id,
search_space_id,
pinned_id,
selected_id,
)
# Distinguish a first-time pin from the replacement of a previous one.
if pinned_id is None:
logger.info(
"auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
thread_id,
search_space_id,
selected_id,
selected_tier,
premium_eligible,
)
else:
logger.info(
"auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
thread_id,
search_space_id,
pinned_id,
selected_id,
selected_tier,
premium_eligible,
)
logger.info(
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
"auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False",
thread_id,
selected_id,
selected_tier,
selected_cfg.get("auto_pin_tier", "?"),
int(selected_cfg.get("quality_score") or 0),
top_k_size,
)
return AutoPinResolution(
resolved_llm_config_id=selected_id,
resolved_tier=selected_tier,
from_existing_pin=False,
)

View file

@ -28,6 +28,7 @@ from litellm.exceptions import (
BadRequestError as LiteLLMBadRequestError,
ContextWindowExceededError,
)
from pydantic import Field
from app.utils.perf import get_perf_logger
@ -207,6 +208,12 @@ class LLMRouterService:
"""
Initialize the router with global LLM configurations.
Configs with ``router_pool_eligible=False`` are skipped so that
dynamic OpenRouter entries stay out of the shared router pool used
by title-gen / sub-agent ``model="auto"`` flows. Those dynamic
entries are still available for user-facing Auto-mode thread pinning
via ``auto_model_pin_service``.
Args:
global_configs: List of global LLM config dictionaries from YAML
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
@ -220,6 +227,8 @@ class LLMRouterService:
model_list = []
premium_models: set[str] = set()
for config in global_configs:
if config.get("router_pool_eligible") is False:
continue
deployment = cls._config_to_deployment(config)
if deployment:
model_list.append(deployment)
@ -308,10 +317,45 @@ class LLMRouterService:
logger.error(f"Failed to initialize LLM Router: {e}")
instance._router = None
@classmethod
def rebuild(
    cls,
    global_configs: list[dict],
    router_settings: dict | None = None,
) -> None:
    """Discard the current router state and run ``initialize`` again.

    ``initialize`` short-circuits once it has run, so that the LiteLLM Router
    is not re-created on every request. This method clears ``_initialized``
    and the cached router/pool fields first, letting a caller (e.g. the
    background OpenRouter refresh) force a rebuild after catalogue changes.

    Args:
        global_configs: Fresh list of global LLM config dictionaries.
        router_settings: Optional router settings forwarded to ``initialize``.
    """
    service = cls.get_instance()
    # Drop the guard flag first, then the cached state it protects.
    service._initialized = False
    service._premium_model_strings = set()
    service._model_list = []
    service._router = None
    cls.initialize(global_configs, router_settings)
@classmethod
def is_premium_model(cls, model_string: str) -> bool:
"""Return True if *model_string* (as reported by LiteLLM) belongs to a
premium-tier deployment in the router pool."""
"""Return True if *model_string* belongs to a premium-tier deployment
in the LiteLLM router pool.
Scope: only covers configs with ``router_pool_eligible`` truthy. That
includes static YAML premium configs AND dynamic OpenRouter *premium*
entries (which opt in at generation time). Dynamic OpenRouter *free*
entries are deliberately kept out of the router pool — OpenRouter
enforces free-tier limits globally per account, so per-deployment
router accounting can't represent them correctly — and therefore
return ``False`` here, which matches their ``billing_tier="free"``
(no premium quota).
For per-request premium checks on an arbitrary config (static or
dynamic, pool or non-pool), read ``agent_config.is_premium`` instead;
that reflects the per-config ``billing_tier`` directly and is what
user-facing Auto-mode thread pinning uses to bill correctly.
"""
instance = cls.get_instance()
return model_string in instance._premium_model_strings
@ -573,6 +617,11 @@ class ChatLiteLLMRouter(BaseChatModel):
# Public attributes that Pydantic will manage
model: str = "auto"
streaming: bool = True
# Static kwargs that flow through to ``litellm.completion(...)`` on every
# invocation (e.g. ``cache_control_injection_points`` set by
# ``apply_litellm_prompt_caching``). Per-call ``**kwargs`` from
# ``invoke()`` still take precedence — see ``_generate``/``_astream``.
model_kwargs: dict[str, Any] = Field(default_factory=dict)
# Bound tools and tool choice for tool calling
_bound_tools: list[dict] | None = None
@ -898,13 +947,16 @@ class ChatLiteLLMRouter(BaseChatModel):
logger.warning(f"Failed to convert tool {tool}: {e}")
continue
# Create a new instance with tools bound
# Create a new instance with tools bound. Carry through ``model_kwargs``
# so static settings (e.g. cache_control_injection_points) survive the
# bind_tools rebuild.
return ChatLiteLLMRouter(
router=self._router,
bound_tools=formatted_tools if formatted_tools else None,
tool_choice=tool_choice,
model=self.model,
streaming=self.streaming,
model_kwargs=dict(self.model_kwargs),
**kwargs,
)
@ -929,8 +981,10 @@ class ChatLiteLLMRouter(BaseChatModel):
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
# Add tools if bound
call_kwargs = {**kwargs}
# Merge static model_kwargs (e.g. cache_control_injection_points) under
# per-call kwargs so callers can still override per invocation. Then add
# bound tools.
call_kwargs = {**self.model_kwargs, **kwargs}
if self._bound_tools:
call_kwargs["tools"] = self._bound_tools
if self._tool_choice is not None:
@ -997,8 +1051,10 @@ class ChatLiteLLMRouter(BaseChatModel):
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
# Add tools if bound
call_kwargs = {**kwargs}
# Merge static model_kwargs (e.g. cache_control_injection_points) under
# per-call kwargs so callers can still override per invocation. Then add
# bound tools.
call_kwargs = {**self.model_kwargs, **kwargs}
if self._bound_tools:
call_kwargs["tools"] = self._bound_tools
if self._tool_choice is not None:
@ -1060,8 +1116,10 @@ class ChatLiteLLMRouter(BaseChatModel):
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
# Add tools if bound
call_kwargs = {**kwargs}
# Merge static model_kwargs (e.g. cache_control_injection_points) under
# per-call kwargs so callers can still override per invocation. Then add
# bound tools.
call_kwargs = {**self.model_kwargs, **kwargs}
if self._bound_tools:
call_kwargs["tools"] = self._bound_tools
if self._tool_choice is not None:
@ -1110,8 +1168,10 @@ class ChatLiteLLMRouter(BaseChatModel):
formatted_messages = self._convert_messages(messages)
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
# Add tools if bound
call_kwargs = {**kwargs}
# Merge static model_kwargs (e.g. cache_control_injection_points) under
# per-call kwargs so callers can still override per invocation. Then add
# bound tools.
call_kwargs = {**self.model_kwargs, **kwargs}
if self._bound_tools:
call_kwargs["tools"] = self._bound_tools
if self._tool_choice is not None:

View file

@ -565,32 +565,63 @@ class VercelStreamingService:
# Error Part
# =========================================================================
def format_error(self, error_text: str) -> str:
def format_error(
self,
error_text: str,
error_code: str | None = None,
extra: dict[str, object] | None = None,
) -> str:
"""
Format an error message.
Args:
error_text: The error message text
error_code: Optional machine-readable error code for frontend branching
Returns:
str: SSE formatted error part
Example output:
data: {"type":"error","errorText":"Something went wrong"}
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
"""
return self._format_sse({"type": "error", "errorText": error_text})
payload: dict[str, object] = {"type": "error", "errorText": error_text}
if error_code:
payload["errorCode"] = error_code
if extra:
payload.update(extra)
return self._format_sse(payload)
# =========================================================================
# Tool Parts
# =========================================================================
def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str:
def format_tool_input_start(
self,
tool_call_id: str,
tool_name: str,
*,
langchain_tool_call_id: str | None = None,
) -> str:
"""
Format the start of tool input streaming.
Args:
tool_call_id: The unique tool call identifier
tool_name: The name of the tool being called
tool_call_id: The unique tool call identifier. May be EITHER the
synthetic ``call_<run_id>`` id derived from LangGraph
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
OFF, or the unmatched-fallback path under parity_v2) OR
the authoritative LangChain ``tool_call.id`` (parity_v2
path: when the provider streams ``tool_call_chunks`` we
register the ``index`` and reuse the lc-id as the card
id so live ``tool-input-delta`` events can be routed
without a downstream join). Either way, the same id is
preserved across ``tool-input-start`` / ``-delta`` /
``-available`` / ``tool-output-available`` for one call.
tool_name: The name of the tool being called.
langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id``. When set, surfaces as
``langchainToolCallId`` so the frontend can join this card
to the action-log row written by ``ActionLogMiddleware``.
Returns:
str: SSE formatted tool input start part
@ -598,13 +629,14 @@ class VercelStreamingService:
Example output:
data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}
"""
return self._format_sse(
{
"type": "tool-input-start",
"toolCallId": tool_call_id,
"toolName": tool_name,
}
)
payload: dict[str, Any] = {
"type": "tool-input-start",
"toolCallId": tool_call_id,
"toolName": tool_name,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return self._format_sse(payload)
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
"""
@ -629,7 +661,12 @@ class VercelStreamingService:
)
def format_tool_input_available(
self, tool_call_id: str, tool_name: str, input_data: dict[str, Any]
self,
tool_call_id: str,
tool_name: str,
input_data: dict[str, Any],
*,
langchain_tool_call_id: str | None = None,
) -> str:
"""
Format the completion of tool input.
@ -638,6 +675,8 @@ class VercelStreamingService:
tool_call_id: The tool call identifier
tool_name: The name of the tool
input_data: The complete tool input parameters
langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id`` (see ``format_tool_input_start``).
Returns:
str: SSE formatted tool input available part
@ -645,22 +684,34 @@ class VercelStreamingService:
Example output:
data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}
"""
return self._format_sse(
{
"type": "tool-input-available",
"toolCallId": tool_call_id,
"toolName": tool_name,
"input": input_data,
}
)
payload: dict[str, Any] = {
"type": "tool-input-available",
"toolCallId": tool_call_id,
"toolName": tool_name,
"input": input_data,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return self._format_sse(payload)
def format_tool_output_available(self, tool_call_id: str, output: Any) -> str:
def format_tool_output_available(
self,
tool_call_id: str,
output: Any,
*,
langchain_tool_call_id: str | None = None,
) -> str:
"""
Format tool execution output.
Args:
tool_call_id: The tool call identifier
output: The tool execution result
langchain_tool_call_id: Optional authoritative LangChain
``tool_call.id`` extracted from ``ToolMessage.tool_call_id``.
When set, the frontend can backfill any card whose
``langchainToolCallId`` was not yet known at
``tool-input-start`` time.
Returns:
str: SSE formatted tool output available part
@ -668,13 +719,14 @@ class VercelStreamingService:
Example output:
data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}
"""
return self._format_sse(
{
"type": "tool-output-available",
"toolCallId": tool_call_id,
"output": output,
}
)
payload: dict[str, Any] = {
"type": "tool-output-available",
"toolCallId": tool_call_id,
"output": output,
}
if langchain_tool_call_id:
payload["langchainToolCallId"] = langchain_tool_call_id
return self._format_sse(payload)
# =========================================================================
# Step Parts

View file

@ -11,20 +11,81 @@ this service only manages the catalogue, not the inference path.
"""
import asyncio
import hashlib
import logging
import threading
import time
from typing import Any
import httpx
from app.services.quality_score import (
_HEALTH_BLEND_WEIGHT,
_HEALTH_ENRICH_CONCURRENCY,
_HEALTH_ENRICH_TOP_N_FREE,
_HEALTH_ENRICH_TOP_N_PREMIUM,
_HEALTH_FAIL_RATIO_FALLBACK,
_HEALTH_FETCH_TIMEOUT_SEC,
aggregate_health,
static_score_or,
)
logger = logging.getLogger(__name__)
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
OPENROUTER_ENDPOINTS_URL_TEMPLATE = (
"https://openrouter.ai/api/v1/models/{model_id}/endpoints"
)
# Sentinel value stored on each generated config so we can distinguish
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides
# enough headroom to avoid frequent collisions for OpenRouter's catalogue
# (~300 models) while keeping IDs comfortably within Postgres INTEGER range.
_STABLE_ID_HASH_WIDTH = 9_000_000
def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int:
"""Derive a deterministic negative config ID from ``model_id``.
The same ``model_id`` always hashes to the same base value so thread pins
survive catalogue churn (models appearing/disappearing/reordering between
refreshes). On collision we decrement until we find an unused slot; this
keeps the mapping stable for the first config that claimed a slot and
only shifts collisions, which is much less disruptive than the legacy
index-based scheme that reshuffled every ID when the catalogue changed.
"""
digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest()
base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH)
cid = base
while cid in taken:
cid -= 1
taken.add(cid)
return cid
def _openrouter_tier(model: dict) -> str:
"""Classify an OpenRouter model as ``"free"`` or ``"premium"``.
Per OpenRouter's API contract, a model is free if:
- Its id ends with ``:free`` (OpenRouter's own free-variant convention), or
- Both ``pricing.prompt`` and ``pricing.completion`` are zero strings.
Anything else (missing pricing, non-zero pricing) falls through to
``"premium"`` so we never under-charge users. This derivation runs off the
already-cached /api/v1/models payload, so it adds no network cost.
"""
if model.get("id", "").endswith(":free"):
return "free"
pricing = model.get("pricing") or {}
prompt = str(pricing.get("prompt", "")).strip()
completion = str(pricing.get("completion", "")).strip()
if prompt == "0" and completion == "0":
return "free"
return "premium"
def _is_text_output_model(model: dict) -> bool:
"""Return True if the model produces text output only (skip image/audio generators)."""
@ -56,6 +117,11 @@ _EXCLUDED_MODEL_IDS: set[str] = {
# Deep-research models reject standard params (temperature, etc.)
"openai/o3-deep-research",
"openai/o4-mini-deep-research",
# OpenRouter's own meta-router over free models. We already enumerate every
# concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread
# pinning handles churn via the repair path, so exposing an additional
# indirection layer would only duplicate the capability with an opaque slug.
"openrouter/free",
}
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
@ -113,20 +179,41 @@ def _generate_configs(
raw_models: list[dict],
settings: dict[str, Any],
) -> list[dict]:
"""
Convert raw OpenRouter model entries into global LLM config dicts.
"""Convert raw OpenRouter model entries into global LLM config dicts.
Models are sorted by ID for deterministic, stable ID assignment across
restarts and refreshes.
Tier (``billing_tier``) is derived per-model from OpenRouter's own API
signals via ``_openrouter_tier``; there is no longer a uniform YAML
override. Config IDs are derived via ``_stable_config_id`` so they
survive catalogue churn across refreshes.
Router-pool membership is tier-aware:
- Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``)
so sub-agent ``model="auto"`` flows benefit from load balancing and
failover across the curated YAML configs and the OR premium passthrough.
- Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM
Router tracks rate limits per deployment, but OpenRouter enforces a
single global free-tier quota (~20 RPM + 50-1000 daily requests
account-wide across every ``:free`` model), so rotating across many
free deployments would only burn the shared bucket faster. Free OR
models remain fully available for user-facing Auto-mode thread pinning
via ``auto_model_pin_service``.
OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
because our own Auto (Fastest) pin + 24 h refresh + repair logic already
cover the catalogue-churn case.
"""
id_offset: int = settings.get("id_offset", -10000)
api_key: str = settings.get("api_key", "")
billing_tier: str = settings.get("billing_tier", "premium")
anonymous_enabled: bool = settings.get("anonymous_enabled", False)
seo_enabled: bool = settings.get("seo_enabled", False)
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
rpm: int = settings.get("rpm", 200)
tpm: int = settings.get("tpm", 1000000)
tpm: int = settings.get("tpm", 1_000_000)
free_rpm: int = settings.get("free_rpm", 20)
free_tpm: int = settings.get("free_tpm", 100_000)
anon_paid: bool = settings.get("anonymous_enabled_paid", False)
anon_free: bool = settings.get("anonymous_enabled_free", False)
litellm_params: dict = settings.get("litellm_params") or {}
system_instructions: str = settings.get("system_instructions", "")
use_default: bool = settings.get("use_default_system_instructions", True)
@ -142,19 +229,24 @@ def _generate_configs(
and _is_allowed_model(m)
and "/" in m.get("id", "")
]
text_models.sort(key=lambda m: m["id"])
configs: list[dict] = []
for idx, model in enumerate(text_models):
taken: set[int] = set()
now_ts = int(time.time())
for model in text_models:
model_id: str = model["id"]
name: str = model.get("name", model_id)
tier = _openrouter_tier(model)
static_q = static_score_or(model, now_ts=now_ts)
cfg: dict[str, Any] = {
"id": id_offset - idx,
"id": _stable_config_id(model_id, id_offset, taken),
"name": name,
"description": f"{name} via OpenRouter",
"billing_tier": billing_tier,
"anonymous_enabled": anonymous_enabled,
"billing_tier": tier,
"anonymous_enabled": anon_free if tier == "free" else anon_paid,
"seo_enabled": seo_enabled,
"seo_slug": None,
"quota_reserve_tokens": quota_reserve_tokens,
@ -162,13 +254,28 @@ def _generate_configs(
"model_name": model_id,
"api_key": api_key,
"api_base": "",
"rpm": rpm,
"tpm": tpm,
"rpm": free_rpm if tier == "free" else rpm,
"tpm": free_tpm if tier == "free" else tpm,
"litellm_params": dict(litellm_params),
"system_instructions": system_instructions,
"use_default_system_instructions": use_default,
"citations_enabled": citations_enabled,
# Premium OR deployments join the LiteLLM router pool so sub-agent
# model="auto" flows can load-balance / fail over across them.
# Free OR deployments stay out: OpenRouter's free tier is a single
# account-wide quota, so per-deployment routing can't spread load
# there — it just drains the shared bucket faster.
"router_pool_eligible": tier == "premium",
_OPENROUTER_DYNAMIC_MARKER: True,
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
# to the static score and gets re-blended with health on the next
# ``_enrich_health`` pass (synchronous on refresh, deferred on cold
# start so startup latency is unchanged).
"auto_pin_tier": "B" if tier == "premium" else "C",
"quality_score_static": static_q,
"quality_score_health": None,
"quality_score": static_q,
"health_gated": False,
}
configs.append(cfg)
@ -187,6 +294,12 @@ class OpenRouterIntegrationService:
self._configs_by_id: dict[int, dict] = {}
self._initialized = False
self._refresh_task: asyncio.Task | None = None
# Last-good per-model health snapshot. Survives across refresh
# cycles so a transient OpenRouter /endpoints outage doesn't drop
# every cfg back to static-only scoring.
# Shape: {model_name: {"gated": bool, "score": float | None}}
self._health_cache: dict[str, dict[str, Any]] = {}
self._enrich_task: asyncio.Task | None = None
@classmethod
def get_instance(cls) -> "OpenRouterIntegrationService":
@ -220,12 +333,27 @@ class OpenRouterIntegrationService:
self._configs_by_id = {c["id"]: c for c in self._configs}
self._initialized = True
tier_counts = self._tier_counts(self._configs)
logger.info(
"OpenRouter integration: loaded %d models (IDs %d to %d)",
"OpenRouter integration: loaded %d models (free=%d, premium=%d)",
len(self._configs),
self._configs[0]["id"] if self._configs else 0,
self._configs[-1]["id"] if self._configs else 0,
tier_counts["free"],
tier_counts["premium"],
)
# Schedule the first health-enrichment pass as a deferred task so
# cold-start latency is unchanged. Only valid when an event loop is
# already running (e.g. FastAPI lifespan); Celery worker init is
# fully sync so we silently skip — its first refresh tick (or the
# next refresh from the web process) will populate health data.
try:
loop = asyncio.get_running_loop()
self._enrich_task = loop.create_task(
self._enrich_health_safely(self._configs)
)
except RuntimeError:
pass
return self._configs
# ------------------------------------------------------------------
@ -254,7 +382,225 @@ class OpenRouterIntegrationService:
self._configs = new_configs
self._configs_by_id = new_by_id
logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
# Catalogue churn invalidates per-config "recently healthy" credit
# earned by the previous turn's preflight. Drop the whole table so
# the next turn re-probes against the freshly loaded configs.
try:
from app.services.auto_model_pin_service import clear_healthy
clear_healthy()
except Exception:
logger.debug(
"OpenRouter refresh: clear_healthy import skipped", exc_info=True
)
tier_counts = self._tier_counts(new_configs)
logger.info(
"OpenRouter refresh: updated to %d models (free=%d, premium=%d)",
len(new_configs),
tier_counts["free"],
tier_counts["premium"],
)
# Re-blend health scores against the freshly fetched catalogue. Also
# re-stamps health for any YAML-curated cfg with provider==OPENROUTER
# so a hand-picked dead OR model is gated like a dynamic one.
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
# Rebuild the LiteLLM router so freshly fetched configs flow through
# (dynamic OR premium entries now opt into the pool, free ones stay
# out; a refresh also needs to pick up any static-config edits and
# reset cached context-window profiles).
try:
from app.config import config as _app_config
from app.services.llm_router_service import (
LLMRouterService,
_router_instance_cache as _chat_router_cache,
)
LLMRouterService.rebuild(
_app_config.GLOBAL_LLM_CONFIGS,
getattr(_app_config, "ROUTER_SETTINGS", None),
)
_chat_router_cache.clear()
except Exception as exc:
logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc)
@staticmethod
def _tier_counts(configs: list[dict]) -> dict[str, int]:
counts = {"free": 0, "premium": 0}
for cfg in configs:
tier = str(cfg.get("billing_tier", "")).lower()
if tier in counts:
counts[tier] += 1
return counts
# ------------------------------------------------------------------
# Auto (Fastest) health enrichment
# ------------------------------------------------------------------
async def _enrich_health_safely(
self, configs: list[dict], *, log_summary: bool = True
) -> None:
"""Wrapper around ``_enrich_health`` that swallows all errors.
Health enrichment is best-effort: any failure must leave cfgs in
their static-only state and never break refresh / startup.
"""
try:
await self._enrich_health(configs, log_summary=log_summary)
except Exception:
logger.exception("OpenRouter health enrichment failed")
async def _enrich_health(
self, configs: list[dict], *, log_summary: bool = True
) -> None:
"""Fetch per-model ``/endpoints`` data for the top OR cfgs and blend
the resulting health score into ``cfg["quality_score"]``.
Bounded fan-out: top-N per tier by ``quality_score_static`` only,
with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the
outbound HTTP. Misses fall back to a per-model last-good cache; if
the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep
the entire previous cycle's cache for this run.
Args:
configs: Config dicts to consider. Only entries whose ``provider``
equals ``OPENROUTER`` (case-insensitive) are touched; their
``health_gated``, ``quality_score_health`` and ``quality_score``
keys are updated in place.
log_summary: When true, emit one info-level summary line at the end.
"""
# Only OpenRouter-backed cfgs take part; the provider match ignores case.
or_cfgs = [
c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
]
if not or_cfgs:
return
# Rank each tier by static score (highest first) and cap how many models
# per tier get an /endpoints request.
premium_pool = sorted(
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"],
key=lambda c: -int(c.get("quality_score_static") or 0),
)[:_HEALTH_ENRICH_TOP_N_PREMIUM]
free_pool = sorted(
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"],
key=lambda c: -int(c.get("quality_score_static") or 0),
)[:_HEALTH_ENRICH_TOP_N_FREE]
# De-duplicate while preserving order: a cfg shouldn't fall in both
# tiers, but defensive code is cheap here.
seen_ids: set[int] = set()
selected: list[dict] = []
for cfg in premium_pool + free_pool:
cid = int(cfg.get("id", 0))
if cid in seen_ids:
continue
seen_ids.add(cid)
selected.append(cfg)
if not selected:
return
# One shared HTTP client for the whole batch; the semaphore is handed to
# ``_fetch_endpoints``, which acquires it around each request.
api_key = str(self._settings.get("api_key") or "")
semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)
async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client:
results = await asyncio.gather(
*(
self._fetch_endpoints(client, semaphore, api_key, cfg)
for cfg in selected
)
)
# Each result is ``(cfg, endpoints, err)``; a non-None ``err`` is a failure.
fail_count = sum(1 for _, _, err in results if err is not None)
fail_ratio = fail_count / len(results) if results else 0.0
degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK
if degraded:
logger.warning(
"auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d "
"using_last_good_cache=true",
fail_ratio,
len(results),
)
# Per-cfg health update.
for cfg, endpoints, err in results:
model_name = str(cfg.get("model_name", ""))
# Fresh data is only trusted in a non-degraded cycle; in a degraded one
# even the fetches that succeeded are ignored in favour of the cache.
if not degraded and err is None and endpoints is not None:
gated, h_score = aggregate_health(endpoints)
cfg["health_gated"] = bool(gated)
cfg["quality_score_health"] = h_score
self._health_cache[model_name] = {
"gated": bool(gated),
"score": h_score,
}
else:
cached = self._health_cache.get(model_name)
if cached is not None:
cfg["health_gated"] = bool(cached.get("gated", False))
cfg["quality_score_health"] = cached.get("score")
# else: keep current values (initial defaults from
# _generate_configs / load_global_llm_configs).
# Blend health into the final score for every OR cfg, including
# those outside the enriched top-N (they fall through to static).
gated_count = 0
by_provider: dict[str, int] = {}
for cfg in or_cfgs:
static_q = int(cfg.get("quality_score_static") or 0)
h = cfg.get("quality_score_health")
# A gated cfg, or one without a health score, keeps its static score.
if h is not None and not cfg.get("health_gated"):
blended = (
_HEALTH_BLEND_WEIGHT * float(h)
+ (1 - _HEALTH_BLEND_WEIGHT) * static_q
)
cfg["quality_score"] = round(blended)
else:
cfg["quality_score"] = static_q
if cfg.get("health_gated"):
gated_count += 1
# Provider slug is the part of ``model_name`` before the first "/".
model_id = str(cfg.get("model_name", ""))
provider_slug = (
model_id.split("/", 1)[0] if "/" in model_id else "unknown"
)
by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1
if log_summary:
logger.info(
"auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f "
"total_enriched=%d",
gated_count,
dict(sorted(by_provider.items(), key=lambda kv: -kv[1])),
fail_ratio,
len(selected),
)
@staticmethod
async def _fetch_endpoints(
    client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    api_key: str,
    cfg: dict,
) -> tuple[dict, list[dict] | None, Exception | None]:
    """Request ``/api/v1/models/{id}/endpoints`` for a single cfg.

    Failures are returned rather than raised so that a batch of calls stays
    aligned with its cfgs.

    Args:
        client: HTTP client used for the GET request.
        semaphore: Acquired around the request to bound concurrency.
        api_key: Sent as a Bearer token when non-empty.
        cfg: Config dict; its ``model_name`` selects the model.

    Returns:
        ``(cfg, endpoints, err)``. ``err`` is set (and ``endpoints`` is
        ``None``) when ``model_name`` is missing, the request or JSON
        decoding fails, or the response has no dict under ``data``. When
        ``data.endpoints`` is not a list, an empty list is returned with
        no error.
    """
    slug = str(cfg.get("model_name", ""))
    if not slug:
        return cfg, None, ValueError("missing model_name")

    request_headers = {}
    if api_key:
        request_headers["Authorization"] = f"Bearer {api_key}"
    target = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=slug)

    async with semaphore:
        try:
            response = await client.get(target, headers=request_headers)
            response.raise_for_status()
            body = response.json()
        except Exception as exc:
            return cfg, None, exc

    inner = body.get("data") if isinstance(body, dict) else None
    if not isinstance(inner, dict):
        return cfg, None, ValueError("malformed endpoints payload")

    listing = inner.get("endpoints")
    if isinstance(listing, list):
        return cfg, listing, None
    return cfg, [], None
async def _refresh_loop(self, interval_hours: float) -> None:
interval_sec = interval_hours * 3600

View file

@ -0,0 +1,380 @@
"""Pure-function quality scoring for Auto (Fastest) model selection.
This module is import-free of any service / request-path dependencies. All
numbers are computed once during the OpenRouter refresh tick (or YAML load)
and cached on the cfg dict, so the chat hot path only does a precomputed
sort and a SHA256 pick.
Score components (0-100 scale, higher is better):
* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload
(provider prestige + ``created`` recency + pricing band + context window
+ capabilities + narrow tiny/legacy slug penalty).
* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus
an operator-trust bonus (the operator deliberately picked this model).
* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints``
responses; returns ``(gated, score_or_none)``.
The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in
:mod:`app.services.openrouter_integration_service` because that's the only
caller that sees both halves.
"""
from __future__ import annotations
# ---------------------------------------------------------------------------
# Tunables (constants, not flags)
# ---------------------------------------------------------------------------
# Top-K size for deterministic spread inside the locked tier.
_QUALITY_TOP_K: int = 5
# Hard health gate: any cfg whose best non-null uptime is below this %
# is excluded from Auto-mode selection entirely.
# Compared against a 0-100 percentage in ``aggregate_health``.
_HEALTH_GATE_UPTIME_PCT: float = 90.0
# Health/static blend weight when a cfg has fresh /endpoints data.
# The static score receives the complementary weight ``1 - _HEALTH_BLEND_WEIGHT``.
_HEALTH_BLEND_WEIGHT: float = 0.5
# Static bonus applied to YAML cfgs because the operator hand-picked them.
_OPERATOR_TRUST_BONUS: int = 20
# /endpoints fan-out is bounded per refresh tick.
# TOP_N_* cap how many models per billing tier are fetched, CONCURRENCY caps
# how many requests are in flight at once, and the timeout is in seconds.
_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50
_HEALTH_ENRICH_TOP_N_FREE: int = 30
_HEALTH_ENRICH_CONCURRENCY: int = 15
_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0
# If at least this fraction of /endpoints fetches fail in a refresh cycle,
# fall back to the previous cycle's last-good cache instead of writing
# partial / stale health values.
_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25
# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise
# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with
# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and
# blanket-penalising them suppresses high-quality picks.
# ``slug_penalty`` matches these as substrings of the lower-cased model id;
# one or more hits cost a single flat -10.
_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = (
"-1b-",
"-1.2b-",
"-1.5b-",
"-2b-",
"-3b-",
"gemma-3n",
"lfm-",
"-base",
"-distill",
":nitro",
"-preview",
)
# ---------------------------------------------------------------------------
# Provider prestige tables
# ---------------------------------------------------------------------------
# OpenRouter-side provider slug (the prefix before ``/`` in the model id).
# Tiers are coarse: frontier labs > strong open / fast-moving labs >
# specialist labs > everything else.
# Keys are lower-case; ``_provider_prestige_or`` lower-cases the slug before
# the lookup, scores unlisted slugs 15, and scores ids without a ``/`` 0.
PROVIDER_PRESTIGE_OR: dict[str, int] = {
# Frontier labs
"openai": 50,
"anthropic": 50,
"google": 50,
"x-ai": 50,
# Strong open / fast-moving labs
"deepseek": 38,
"qwen": 38,
"meta-llama": 38,
"mistralai": 38,
"cohere": 38,
"nvidia": 38,
"alibaba": 38,
# Specialist / regional / strong second-tier
"microsoft": 28,
"01-ai": 28,
"minimax": 28,
"moonshot": 28,
"z-ai": 28,
"nousresearch": 28,
"ai21": 28,
"perplexity": 28,
# Smaller / niche providers
"liquid": 18,
"cognitivecomputations": 18,
"venice": 18,
"inflection": 18,
}
# YAML provider field (the upstream API shape the operator selected).
# Keys are upper-case; ``static_score_yaml`` upper-cases the cfg's ``provider``
# before the lookup and scores unlisted providers 15.
PROVIDER_PRESTIGE_YAML: dict[str, int] = {
"AZURE_OPENAI": 50,
"OPENAI": 50,
"ANTHROPIC": 50,
"GOOGLE": 50,
"VERTEX_AI": 50,
"GEMINI": 50,
"XAI": 50,
"MISTRAL": 38,
"DEEPSEEK": 38,
"COHERE": 38,
"GROQ": 30,
"TOGETHER_AI": 28,
"FIREWORKS_AI": 28,
"PERPLEXITY": 28,
"MINIMAX": 28,
"BEDROCK": 28,
"OPENROUTER": 25,
"OLLAMA": 12,
"CUSTOM": 12,
}
# ---------------------------------------------------------------------------
# Pure scoring helpers
# ---------------------------------------------------------------------------
# Calibrated against the live /api/v1/models bulk dump. Frontier models
# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5,
# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band;
# anything older trails off.
_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = (
(60, 20),
(180, 16),
(365, 12),
(540, 9),
(730, 6),
(1095, 3),
)
def created_recency_signal(created_ts: int | None, now_ts: int) -> int:
"""Return 0-20 based on how recently the model was published.
Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for
YAML cfgs). Models without a usable timestamp get 0 (we don't penalise,
we just don't reward).
"""
if created_ts is None or created_ts <= 0 or now_ts <= 0:
return 0
age_days = max(0, (now_ts - int(created_ts)) // 86_400)
for cutoff, score in _RECENCY_BANDS_DAYS:
if age_days <= cutoff:
return score
return 0
def pricing_band(
prompt: str | float | int | None,
completion: str | float | int | None,
) -> int:
"""Return 0-15 based on combined prompt+completion cost per 1M tokens.
Higher-priced models tend to be the larger / more capable ones. A free
model returns 0 (we use other signals to rank free-vs-free instead).
Uncoercible inputs are treated as 0 rather than raising.
"""
def _to_float(value) -> float:
if value is None:
return 0.0
try:
return float(value)
except (TypeError, ValueError):
return 0.0
p = _to_float(prompt)
c = _to_float(completion)
total_per_million = (p + c) * 1_000_000
if total_per_million >= 20.0:
return 15
if total_per_million >= 5.0:
return 12
if total_per_million >= 1.0:
return 9
if total_per_million >= 0.3:
return 6
if total_per_million >= 0.05:
return 4
if total_per_million > 0.0:
return 2
return 0
def context_signal(ctx: int | None) -> int:
"""Return 0-10 based on the model's context window."""
if not ctx or ctx <= 0:
return 0
if ctx >= 1_000_000:
return 10
if ctx >= 400_000:
return 8
if ctx >= 200_000:
return 6
if ctx >= 128_000:
return 4
if ctx >= 100_000:
return 2
return 0
def capabilities_signal(supported_parameters: list[str] | None) -> int:
"""Return 0-5 for capabilities that matter for our agent flows."""
if not supported_parameters:
return 0
params = set(supported_parameters)
score = 0
if "tools" in params:
score += 2
if "structured_outputs" in params or "response_format" in params:
score += 2
if "reasoning" in params or "include_reasoning" in params:
score += 1
return min(score, 5)
def slug_penalty(model_id: str) -> int:
    """Return -10 when the id matches a tiny/legacy pattern, otherwise 0.

    Matching is a case-insensitive substring test against
    ``_TINY_LEGACY_PENALTY_PATTERNS``. The penalty is flat: several matching
    patterns still cost -10 in total. An empty id scores 0.
    """
    if not model_id:
        return 0
    lowered = model_id.lower()
    hit = any(marker in lowered for marker in _TINY_LEGACY_PENALTY_PATTERNS)
    return -10 if hit else 0
def _provider_prestige_or(model_id: str) -> int:
    """Look up the prestige score for the provider prefix of an OpenRouter id.

    The provider is the lower-cased text before the first ``/``. Ids without
    a ``/`` score 0; providers missing from ``PROVIDER_PRESTIGE_OR`` score 15.
    """
    vendor, separator, _ = model_id.partition("/")
    if not separator:
        return 0
    return PROVIDER_PRESTIGE_OR.get(vendor.lower(), 15)
def static_score_or(or_model: dict, *, now_ts: int) -> int:
    """Compute the static 0-100 quality score of a raw OpenRouter model entry.

    Sums provider prestige, ``created`` recency, pricing band, context
    window, capabilities and the (non-positive) slug penalty, then clamps
    the total to the 0-100 range.

    Args:
        or_model: One entry from the ``/api/v1/models`` payload.
        now_ts: Reference "now" in Unix seconds, used for the recency term.
    """
    slug = str(or_model.get("id", ""))
    rates = or_model.get("pricing") or {}
    components = (
        _provider_prestige_or(slug),
        created_recency_signal(or_model.get("created"), now_ts),
        pricing_band(rates.get("prompt"), rates.get("completion")),
        context_signal(or_model.get("context_length")),
        capabilities_signal(or_model.get("supported_parameters")),
        slug_penalty(slug),
    )
    total = int(sum(components))
    return min(100, max(0, total))
def static_score_yaml(cfg: dict) -> int:
    """Compute the static 0-100 quality score of a YAML-curated cfg.

    The score is provider prestige plus ``_OPERATOR_TRUST_BONUS`` (the
    operator chose this model on purpose), plus pricing and context points
    looked up through ``litellm``, plus the slug penalty, clamped to 0-100.

    The ``litellm`` lookup is best-effort: if the import or the lookup
    fails, the pricing/context values gathered so far are kept (0 for any
    that were not reached) and scoring continues.
    """
    provider_key = str(cfg.get("provider", "")).upper()
    prestige = PROVIDER_PRESTIGE_YAML.get(provider_key, 15)

    model_name = cfg.get("model_name") or ""
    params = cfg.get("litellm_params") or {}
    # Prefer the most specific name litellm is likely to recognise.
    lookup_name = params.get("base_model") or params.get("model") or model_name

    window = 0
    input_cost: float = 0.0
    output_cost: float = 0.0
    try:
        from litellm import get_model_info  # lazy: avoid cold-import cost

        details = get_model_info(lookup_name) or {}
        window = int(details.get("max_input_tokens") or details.get("max_tokens") or 0)
        input_cost = float(details.get("input_cost_per_token") or 0.0)
        output_cost = float(details.get("output_cost_per_token") or 0.0)
    except Exception:
        # Model unknown to litellm: prestige and operator bonus still apply.
        pass

    total = sum(
        (
            prestige,
            _OPERATOR_TRUST_BONUS,
            pricing_band(input_cost, output_cost),
            context_signal(window),
            slug_penalty(str(model_name)),
        )
    )
    return min(100, max(0, int(total)))
# ---------------------------------------------------------------------------
# Health aggregation
# ---------------------------------------------------------------------------
def _coerce_pct(value) -> float | None:
try:
if value is None:
return None
f = float(value)
except (TypeError, ValueError):
return None
if f < 0:
return None
# OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it
# as a 0-100 percentage. Normalise.
return f * 100.0 if f <= 1.0 else f
def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]:
    """Return the highest usable uptime across endpoints and its window.

    Windows are tried in the order ``uptime_last_30m``, ``uptime_last_1d``,
    ``uptime_last_5m``. The first window for which at least one endpoint
    has a usable reading decides the result.

    Returns:
        ``(uptime_pct, window_used)``, or ``(None, None)`` when no endpoint
        has a usable reading in any window.
    """
    for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"):
        readings = []
        for endpoint in endpoints:
            pct = _coerce_pct(endpoint.get(window))
            if pct is not None:
                readings.append(pct)
        if readings:
            return max(readings), window
    return None, None
def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]:
    """Aggregate a model's per-endpoint health into ``(gated, score_or_none)``.

    Hard gate (returns ``(True, None)``):
      * ``endpoints`` is empty or contains no dict entries,
      * no endpoint reports ``status == 0`` (OK), or
      * best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``.

    On a pass, returns ``(False, score)`` where ``score`` is a 0-100 blend of
    50% best uptime, 30% status and 20% freshest available uptime sample.

    Malformed entries never raise: non-dict items are ignored, and a missing,
    null or non-numeric ``status`` is treated as "not OK". One bad endpoint
    record therefore cannot abort health enrichment for the other models.
    """
    usable = [ep for ep in (endpoints or []) if isinstance(ep, dict)]
    if not usable:
        return True, None

    def _is_ok(endpoint: dict) -> bool:
        # A missing status defaults to 1 (not OK), as before.
        try:
            return int(endpoint.get("status", 1)) == 0
        except (TypeError, ValueError):
            return False

    if not any(_is_ok(ep) for ep in usable):
        return True, None

    best_uptime, _ = _best_uptime(usable)
    if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT:
        return True, None

    # Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing.
    freshness = None
    for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"):
        readings = [
            pct
            for pct in (_coerce_pct(ep.get(window)) for ep in usable)
            if pct is not None
        ]
        if readings:
            freshness = max(readings)
            break

    uptime_term = best_uptime
    # Past the gate at least one endpoint is OK, so the status term is
    # always at its maximum.
    status_term = 100.0
    freshness_term = freshness if freshness is not None else best_uptime
    score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term
    return False, max(0.0, min(100.0, score))

View file

@ -8,7 +8,9 @@ Operation outcomes mirror the plan:
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
written before the original mutation.
written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row
that was created; everything else is an in-place restore.
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
the inverse tool through the agent's normal permission stack (NOT
bypassed). Out of scope for this PR; returns ``REVERSE_NOT_IMPLEMENTED``.
@ -18,6 +20,11 @@ Operation outcomes mirror the plan:
A successful revert appends a NEW row to ``agent_action_log`` with
``reverse_of=<original_action_id>`` and the requesting user's
``user_id``, preserving an auditable chain.
Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
``"rmdir".startswith("rm")`` would otherwise mis-route directory revert
to the document branch (and ``delete_note`` vs ``delete_folder`` is the
same trap waiting to happen).
"""
from __future__ import annotations
@ -25,17 +32,31 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Literal
from typing import Any, Literal
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.path_resolver import (
DOCUMENTS_ROOT,
safe_filename,
safe_folder_segment,
)
from app.db import (
AgentActionLog,
Chunk,
Document,
DocumentRevision,
DocumentType,
Folder,
FolderRevision,
NewChatThread,
)
from app.utils.document_converters import (
embed_texts,
generate_content_hash,
generate_unique_identifier_hash,
)
logger = logging.getLogger(__name__)
@ -110,14 +131,244 @@ def can_revert(
# ---------------------------------------------------------------------------
# Revert paths
# Helper: reconstruct virtual path from a snapshot
# ---------------------------------------------------------------------------
async def _virtual_path_from_snapshot(
    session: AsyncSession,
    revision: DocumentRevision,
) -> str | None:
    """Work out the virtual path a document had before it was mutated.

    Two sources are tried, in order:

    1. ``metadata_before["virtual_path"]``, when it is a string that starts
       with ``DOCUMENTS_ROOT``.
    2. A path composed from the folder chain and the title: starting at
       ``folder_id_before``, follow ``parent_id`` upwards, then append the
       sanitised ``title_before``.

    Returns:
        The reconstructed path, or ``None`` when ``title_before`` is missing
        or empty, or when a folder in the chain no longer exists.
    """
    snapshot_meta = revision.metadata_before or {}
    if isinstance(snapshot_meta, dict):
        recorded = snapshot_meta.get("virtual_path")
        if isinstance(recorded, str) and recorded.startswith(DOCUMENTS_ROOT):
            return recorded

    title = revision.title_before
    if not (isinstance(title, str) and title):
        return None

    # Collect folder names from the leaf up to the root; ``seen`` stops the
    # walk if the parent chain loops back on itself.
    segments: list[str] = []
    seen: set[int] = set()
    folder_id: int | None = revision.folder_id_before
    while folder_id is not None and folder_id not in seen:
        seen.add(folder_id)
        folder = await session.get(Folder, folder_id)
        if folder is None:
            return None
        segments.insert(0, safe_folder_segment(str(folder.name or "")))
        folder_id = folder.parent_id

    if segments:
        base = f"{DOCUMENTS_ROOT}/" + "/".join(segments)
    else:
        base = DOCUMENTS_ROOT
    return f"{base}/{safe_filename(title)}"
# ---------------------------------------------------------------------------
# Document revision restore (write/edit/move/rm)
# ---------------------------------------------------------------------------
def _set_field(target: Any, field: str, value: Any) -> None:
if value is not None:
setattr(target, field, value)
async def _restore_in_place_document(
session: AsyncSession,
*,
revision: DocumentRevision,
) -> RevertOutcome:
"""Apply an in-place restore to an existing :class:`Document`.
Copies the snapshot's content, title, folder and metadata back onto the
live row, recomputes its content hash and identifier hash, replaces its
chunks with re-embedded snapshot chunks and refreshes the document
embedding. Nothing is committed here; changes are only staged on
``session``.
Args:
session: Session used to load the document and stage the changes.
revision: Snapshot row holding the ``*_before`` values to restore.
Returns:
``RevertOutcome`` with status ``"tool_unavailable"`` when the revision
has no ``document_id`` or the document row no longer exists, otherwise
status ``"ok"``.
"""
if revision.document_id is None:
return RevertOutcome(
status="tool_unavailable",
message=(
"Original document was hard-deleted; in-place restore is not possible."
),
)
doc = await session.get(Document, revision.document_id)
if doc is None:
return RevertOutcome(
status="tool_unavailable",
message="Original document has been deleted; revert cannot proceed.",
)
# ``_set_field`` skips ``None``, so a value the snapshot did not capture
# leaves the live field untouched.
# NOTE(review): this also means a ``folder_id_before`` of ``None`` cannot move
# the document back to the root folder — confirm whether that case matters.
_set_field(doc, "content", revision.content_before)
_set_field(doc, "source_markdown", revision.content_before)
_set_field(doc, "title", revision.title_before)
_set_field(doc, "folder_id", revision.folder_id_before)
# Metadata is only overwritten when the snapshot holds a non-empty dict.
metadata_before = revision.metadata_before or {}
if isinstance(metadata_before, dict) and metadata_before:
doc.document_metadata = dict(metadata_before)
if isinstance(revision.content_before, str):
doc.content_hash = generate_content_hash(
revision.content_before, doc.search_space_id
)
# Re-derive the identifier hash from the pre-mutation path so it matches the
# restored location.
# NOTE(review): the hash is always built with ``DocumentType.NOTE`` —
# presumably only NOTE documents reach this path; confirm.
virtual_path = await _virtual_path_from_snapshot(session, revision)
if virtual_path:
doc.unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
virtual_path,
doc.search_space_id,
)
# Chunks are replaced wholesale: delete the current rows, then re-embed the
# snapshot texts. Snapshot entries that are not dicts with a string
# ``content`` are dropped.
# NOTE(review): ``embed_texts`` is called without ``await`` inside this
# coroutine — if it blocks, it may stall the event loop; confirm.
chunks_before = revision.chunks_before
if isinstance(chunks_before, list):
await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
chunk_texts = [
str(c.get("content"))
for c in chunks_before
if isinstance(c, dict) and isinstance(c.get("content"), str)
]
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
session.add_all(
[
Chunk(document_id=doc.id, content=text, embedding=embedding)
for text, embedding in zip(
chunk_texts, chunk_embeddings, strict=True
)
]
)
if isinstance(revision.content_before, str):
doc.embedding = embed_texts([revision.content_before])[0]
doc.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Document restored from snapshot.")
async def _reinsert_document_from_revision(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Recreate a deleted :class:`Document` from its snapshot row (``rm`` revert).

    Args:
        session: Async DB session used for the collision check and the INSERT.
        revision: Snapshot row; its ``document_id`` is repointed at the new row.

    Returns:
        ``"not_reversible"`` when the snapshot lacks a title, string content or
        a resolvable virtual path; ``"tool_unavailable"`` when a live document
        already has the same identifier hash; ``"ok"`` after the re-insert.
    """
    title = revision.title_before
    if not (isinstance(title, str) and title):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks title_before; cannot recreate document.",
        )

    body = revision.content_before
    if not isinstance(body, str):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks content_before; cannot recreate document.",
        )

    path = await _virtual_path_from_snapshot(session, revision)
    if not path:
        return RevertOutcome(
            status="not_reversible",
            message=(
                "Snapshot is missing both metadata_before['virtual_path'] AND "
                "a resolvable (folder_id_before, title_before) pair."
            ),
        )

    space_id = revision.search_space_id
    identifier_hash = generate_unique_identifier_hash(
        DocumentType.NOTE,
        path,
        space_id,
    )

    # Refuse to recreate on top of a live document with the same identifier.
    lookup = select(Document.id).where(
        Document.search_space_id == space_id,
        Document.unique_identifier_hash == identifier_hash,
    )
    existing_id = (await session.execute(lookup)).scalar_one_or_none()
    if existing_id is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                f"A document already exists at '{path}'; revert would "
                "collide. Move the live doc out of the way first."
            ),
        )

    # Copy the snapshot metadata (non-dict values collapse to an empty dict)
    # and pin the resolved path into it.
    raw_meta = revision.metadata_before
    restored_meta = dict(raw_meta) if isinstance(raw_meta, dict) else {}
    restored_meta["virtual_path"] = path

    recreated = Document(
        title=title,
        document_type=DocumentType.NOTE,
        document_metadata=restored_meta,
        content=body,
        content_hash=generate_content_hash(body, space_id),
        unique_identifier_hash=identifier_hash,
        source_markdown=body,
        search_space_id=space_id,
        folder_id=revision.folder_id_before,
        updated_at=datetime.now(UTC),
    )
    session.add(recreated)
    # Flush before the chunks are built, which reference recreated.id.
    await session.flush()
    recreated.embedding = embed_texts([body])[0]

    texts = []
    snapshot_chunks = revision.chunks_before
    if isinstance(snapshot_chunks, list):
        for entry in snapshot_chunks:
            if isinstance(entry, dict) and isinstance(entry.get("content"), str):
                texts.append(str(entry.get("content")))
    if texts:
        vectors = embed_texts(texts)
        session.add_all(
            [
                Chunk(document_id=recreated.id, content=text, embedding=vector)
                for text, vector in zip(texts, vectors, strict=True)
            ]
        )

    # Repoint the snapshot at the recreated row so a follow-up revert of
    # the same row works as expected.
    revision.document_id = recreated.id
    return RevertOutcome(
        status="ok",
        message=f"Re-inserted document '{title}' from snapshot.",
    )
async def _delete_created_document(
    session: AsyncSession,
    *,
    revision: DocumentRevision,
) -> RevertOutcome:
    """Undo a ``write_file`` creation by deleting the document it produced.

    Args:
        session: Async DB session used to issue the DELETE.
        revision: Snapshot row whose ``document_id`` names the created document.

    Returns:
        Always an ``"ok"`` outcome; a missing ``document_id`` is reported as
        already removed.
    """
    created_id = revision.document_id
    if created_id is None:
        return RevertOutcome(
            status="ok",
            message="No live row to delete (already removed elsewhere).",
        )

    removal = delete(Document).where(Document.id == created_id)
    await session.execute(removal)
    return RevertOutcome(
        status="ok",
        message="Deleted the document that was created by this action.",
    )
async def _restore_document_revision(
session: AsyncSession, *, action: AgentActionLog
) -> RevertOutcome:
"""Restore the most recent :class:`DocumentRevision` for ``action``."""
"""Dispatch document-level revert based on ``action.tool_name``."""
stmt = (
select(DocumentRevision)
.where(DocumentRevision.agent_action_id == action.id)
@ -132,23 +383,111 @@ async def _restore_document_revision(
message="No document_revisions row tied to this action.",
)
from app.db import Document # late import to avoid cycles at module load
tool_name = (action.tool_name or "").lower()
doc = await session.get(Document, revision.document_id)
if doc is None:
if tool_name == "rm":
return await _reinsert_document_from_revision(session, revision=revision)
if tool_name == "write_file" and revision.content_before is None:
return await _delete_created_document(session, revision=revision)
return await _restore_in_place_document(session, revision=revision)
# ---------------------------------------------------------------------------
# Folder revision restore (mkdir/rmdir/rename/move)
# ---------------------------------------------------------------------------
async def _restore_in_place_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Apply an in-place restore to an existing :class:`Folder`.

    Args:
        session: Async DB session used to load the folder.
        revision: Snapshot row holding ``name_before``, ``parent_id_before``
            and ``position_before``.

    Returns:
        ``"tool_unavailable"`` when the snapshot has no ``folder_id`` or the
        folder row cannot be loaded, ``"ok"`` once the fields are re-applied.
    """
    if revision.folder_id is None:
        # Exactly one ``message`` keyword: a repeated keyword argument is a
        # SyntaxError, and this path concerns a folder, not a document.
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder was hard-deleted; in-place restore is impossible.",
        )

    folder = await session.get(Folder, revision.folder_id)
    if folder is None:
        return RevertOutcome(
            status="tool_unavailable",
            message="Original folder has been deleted; revert cannot proceed.",
        )

    # NOTE(review): _set_field appears to assign only non-None values, so a
    # snapshot of None leaves the live value in place — confirm in the helper.
    _set_field(folder, "name", revision.name_before)
    _set_field(folder, "parent_id", revision.parent_id_before)
    _set_field(folder, "position", revision.position_before)
    folder.updated_at = datetime.now(UTC)
    return RevertOutcome(status="ok", message="Folder restored from snapshot.")
async def _reinsert_folder_from_revision(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Recreate a deleted :class:`Folder` from its snapshot row.

    Args:
        session: Async DB session used for the INSERT.
        revision: Snapshot row; its ``folder_id`` is repointed at the new folder.

    Returns:
        ``"not_reversible"`` when ``name_before`` is missing or empty,
        ``"ok"`` after the folder has been re-inserted.
    """
    folder_name = revision.name_before
    if not (isinstance(folder_name, str) and folder_name):
        return RevertOutcome(
            status="not_reversible",
            message="Snapshot lacks name_before; cannot recreate folder.",
        )

    recreated = Folder(
        name=folder_name,
        parent_id=revision.parent_id_before,
        position=revision.position_before,
        search_space_id=revision.search_space_id,
        updated_at=datetime.now(UTC),
    )
    session.add(recreated)
    # Flush before recreated.id is copied onto the snapshot row.
    await session.flush()
    revision.folder_id = recreated.id

    return RevertOutcome(
        status="ok",
        message=f"Re-inserted folder '{folder_name}' from snapshot.",
    )
async def _delete_created_folder(
    session: AsyncSession,
    *,
    revision: FolderRevision,
) -> RevertOutcome:
    """Undo a ``mkdir`` by deleting the folder it created, if it is still empty.

    Args:
        session: Async DB session used for the emptiness checks and the DELETE.
        revision: Snapshot row whose ``folder_id`` names the created folder.

    Returns:
        ``"ok"`` when there is nothing to delete or the folder was deleted;
        ``"tool_unavailable"`` when the folder now contains a document or a
        sub-folder.
    """
    if revision.folder_id is None:
        return RevertOutcome(
            status="ok",
            message="No live folder row to delete (already removed elsewhere).",
        )
    folder_id = revision.folder_id

    # Refuse to delete a folder that has gained documents since creation.
    has_doc = await session.execute(
        select(Document.id).where(Document.folder_id == folder_id).limit(1)
    )
    if has_doc.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Folder is no longer empty (documents have been added since "
                "mkdir); cannot revert."
            ),
        )

    # Likewise for sub-folders.
    has_child = await session.execute(
        select(Folder.id).where(Folder.parent_id == folder_id).limit(1)
    )
    if has_child.scalar_one_or_none() is not None:
        return RevertOutcome(
            status="tool_unavailable",
            message=(
                "Folder is no longer empty (sub-folders have been added "
                "since mkdir); cannot revert."
            ),
        )

    # The folder is empty: delete it. No document fields are touched here;
    # this function only ever operates on the folder row.
    await session.execute(delete(Folder).where(Folder.id == folder_id))
    return RevertOutcome(
        status="ok",
        message="Deleted the folder that was created by this action.",
    )
async def _restore_folder_revision(
@ -168,41 +507,44 @@ async def _restore_folder_revision(
message="No folder_revisions row tied to this action.",
)
from app.db import Folder
tool_name = (action.tool_name or "").lower()
folder = await session.get(Folder, revision.folder_id)
if folder is None:
return RevertOutcome(
status="tool_unavailable",
message="Original folder has been deleted; revert cannot proceed.",
)
if tool_name == "rmdir":
return await _reinsert_folder_from_revision(session, revision=revision)
if revision.name_before is not None:
folder.name = revision.name_before
if revision.parent_id_before is not None:
folder.parent_id = revision.parent_id_before
if revision.position_before is not None:
folder.position = revision.position_before
folder.updated_at = datetime.now(UTC)
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
if tool_name == "mkdir":
return await _delete_created_folder(session, revision=revision)
return await _restore_in_place_folder(session, revision=revision)
# Tool-name prefixes that route to KB document / folder revert paths. Kept
# as data so a future PR adding new KB-owned tools doesn't have to touch
# this module's control flow.
_DOC_TOOL_PREFIXES: tuple[str, ...] = (
"edit_file",
"write_file",
"update_memory",
"create_note",
"update_note",
"delete_note",
# ---------------------------------------------------------------------------
# Dispatch
# ---------------------------------------------------------------------------
#
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
# ``delete_note``/``delete_folder``.
_DOC_TOOLS: frozenset[str] = frozenset(
{
"edit_file",
"write_file",
"move_file",
"rm",
"update_memory",
"create_note",
"update_note",
"delete_note",
}
)
_FOLDER_TOOL_PREFIXES: tuple[str, ...] = (
"mkdir",
"move_file",
"rename_folder",
"delete_folder",
_FOLDER_TOOLS: frozenset[str] = frozenset(
{
"mkdir",
"rmdir",
"rename_folder",
"delete_folder",
}
)
@ -220,9 +562,9 @@ async def revert_action(
"""
tool_name = (action.tool_name or "").lower()
if tool_name.startswith(_DOC_TOOL_PREFIXES):
if tool_name in _DOC_TOOLS:
outcome = await _restore_document_revision(session, action=action)
elif tool_name.startswith(_FOLDER_TOOL_PREFIXES):
elif tool_name in _FOLDER_TOOLS:
outcome = await _restore_folder_revision(session, action=action)
elif action.reverse_descriptor:
# Connector-owned reversibles run through the normal permission

File diff suppressed because it is too large Load diff

View file

@ -74,7 +74,7 @@ dependencies = [
"deepagents>=0.4.12",
"stripe>=15.0.0",
"azure-ai-documentintelligence>=1.0.2",
"litellm>=1.83.4",
"litellm>=1.83.7",
"langchain-litellm>=0.6.4",
]

View file

@ -226,6 +226,31 @@ class TestCompose:
# Default block should NOT be present
assert "<knowledge_base_only_policy>" not in prompt
def test_provider_hints_render_with_custom_system_instructions(
    self, fixed_today: datetime
) -> None:
    """Provider hints are still appended, AFTER a custom system prompt.

    Regression guard for the always-append decision. Provider hints are
    stylistic nudges (parallel tool-call rules, formatting guidance, etc.)
    that help the model regardless of what the system instructions say.
    Suppressing them when a custom prompt is set would partially defeat
    the per-family prompt machinery.
    """
    custom_text = "You are a custom assistant."
    hints_tag = "<provider_hints>"

    prompt = compose_system_prompt(
        today=fixed_today,
        custom_system_instructions=custom_text,
        model_name="anthropic/claude-3-5-sonnet",
    )

    assert custom_text in prompt
    assert hints_tag in prompt
    # The custom prompt must come BEFORE the provider hints so the
    # user's framing isn't drowned out by the stylistic nudges.
    assert prompt.index(custom_text) < prompt.index(hints_tag)
def test_use_default_false_with_no_custom_yields_no_system_block(
self, fixed_today: datetime
) -> None:

View file

@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
from app.agents.new_chat.tools.registry import ToolDefinition
@dataclass
class _FakeRuntime:
"""Minimal stand-in for ``ToolRuntime`` used in unit tests.
``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']``
to populate the new ``chat_turn_id`` column (see migration 135).
"""
config: dict[str, Any] | None = None
@dataclass
class _FakeRequest:
"""Minimal stand-in for ToolCallRequest used in unit tests."""
@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence:
"args": {"color": "red", "size": 3},
"id": "tc-abc",
},
runtime=_FakeRuntime(
config={"configurable": {"turn_id": "42:1700000000000"}}
),
)
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
handler = AsyncMock(return_value=result_msg)
@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence:
assert row.error is None
assert row.reverse_descriptor is None
assert row.reversible is False
# Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``;
# ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``.
assert row.tool_call_id == "tc-abc"
assert row.turn_id == "tc-abc"
assert row.chat_turn_id == "42:1700000000000"
@pytest.mark.asyncio
async def test_chat_turn_id_none_when_runtime_missing(
    self, patch_get_flags, fake_session_factory
) -> None:
    """``chat_turn_id`` falls back to NULL when ``runtime.config`` is absent."""
    captured, factory = fake_session_factory
    middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
    tool_call = {"name": "make_widget", "args": {}, "id": "tc-1"}
    # No runtime at all: the middleware has nowhere to read a turn id from.
    request = _FakeRequest(tool_call=tool_call, runtime=None)
    tool_result = ToolMessage(content="ok", tool_call_id="tc-1")
    handler = AsyncMock(return_value=tool_result)

    with (
        patch_get_flags(_enabled_flags()),
        patch("app.db.shielded_async_session", side_effect=lambda: factory()),
    ):
        await middleware.awrap_tool_call(request, handler)

    persisted = captured["rows"][0]
    assert persisted.tool_call_id == "tc-1"
    assert persisted.chat_turn_id is None
@pytest.mark.asyncio
async def test_writes_row_on_failure_and_reraises(
@ -293,6 +333,76 @@ class TestReverseDescriptor:
assert row.reversible is False
class TestActionLogDispatch:
    """Verify ``adispatch_custom_event`` fires after commit."""

    @pytest.mark.asyncio
    async def test_dispatches_action_log_event_on_success(
        self, patch_get_flags, fake_session_factory
    ) -> None:
        """A persisted tool call emits one ``action_log`` event with the row's key fields."""
        _rows, factory = fake_session_factory
        middleware = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
        tool_call = {
            "name": "make_widget",
            "args": {"color": "red"},
            "id": "tc-evt",
        }
        runtime = _FakeRuntime(config={"configurable": {"turn_id": "42:1700000000000"}})
        request = _FakeRequest(tool_call=tool_call, runtime=runtime)
        handler = AsyncMock(
            return_value=ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42")
        )
        dispatch_mock = AsyncMock()

        with (
            patch_get_flags(_enabled_flags()),
            patch("app.db.shielded_async_session", side_effect=lambda: factory()),
            patch(
                "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
                dispatch_mock,
            ),
        ):
            await middleware.awrap_tool_call(request, handler)

        dispatch_mock.assert_awaited_once()
        awaited = dispatch_mock.await_args
        assert awaited is not None
        assert awaited.args[0] == "action_log"

        payload = awaited.args[1]
        for key, expected in (
            ("lc_tool_call_id", "tc-evt"),
            ("chat_turn_id", "42:1700000000000"),
            ("tool_name", "make_widget"),
        ):
            assert payload[key] == expected
        for flag in ("reversible", "reverse_descriptor_present", "error"):
            assert payload[flag] is False

    @pytest.mark.asyncio
    async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None:
        """If commit fails the dispatch is suppressed (no row to surface)."""
        middleware = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
        request = _FakeRequest(
            tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
        )
        tool_result = ToolMessage(content="ok", tool_call_id="tc1")
        handler = AsyncMock(return_value=tool_result)
        dispatch_mock = AsyncMock()

        def _exploding_session():
            # Simulate the session factory itself failing.
            raise RuntimeError("DB is down")

        with (
            patch_get_flags(_enabled_flags()),
            patch("app.db.shielded_async_session", side_effect=_exploding_session),
            patch(
                "app.agents.new_chat.middleware.action_log.adispatch_custom_event",
                dispatch_mock,
            ),
        ):
            await middleware.awrap_tool_call(request, handler)

        dispatch_mock.assert_not_awaited()
class TestArgsTruncation:
@pytest.mark.asyncio
async def test_huge_args_payload_is_truncated(

View file

@ -7,7 +7,9 @@ import pytest
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import (
BusyMutexMiddleware,
end_turn,
get_cancel_event,
is_cancel_requested,
manager,
request_cancel,
reset_cancel,
@ -88,3 +90,65 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
def test_reset_cancel_idempotent() -> None:
    """``reset_cancel`` must not raise for a thread whose event was never created."""
    unseen_thread = "never-seen"
    reset_cancel(unseen_thread)
def test_request_cancel_creates_event_for_unseen_thread() -> None:
    """Requesting a cancel on a freshly reset thread reports success and sets its event."""
    tid = "never-seen-cancel"
    # Start from a clean slate so earlier tests cannot leak cancel state.
    reset_cancel(tid)

    accepted = request_cancel(tid)

    assert accepted is True
    assert get_cancel_event(tid).is_set()
    assert is_cancel_requested(tid) is True
@pytest.mark.asyncio
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
    """``end_turn`` releases the thread lock and clears a pending cancel request."""
    tid = "forced-end-turn"
    middleware = BusyMutexMiddleware()

    # Acquire the per-thread lock the way a real turn would.
    await middleware.abefore_agent({}, _Runtime(tid))
    assert manager.lock_for(tid).locked()

    request_cancel(tid)
    assert is_cancel_requested(tid) is True

    end_turn(tid)

    assert not manager.lock_for(tid).locked()
    assert not get_cancel_event(tid).is_set()
    assert is_cancel_requested(tid) is False
@pytest.mark.asyncio
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
    """A late ``aafter_agent`` from attempt A must leave attempt B's lock held.

    Sequence exercised:
        1) attempt A acquires the thread lock
        2) a forced ``end_turn`` clears A so a retry can proceed
        3) attempt B acquires the same thread lock
        4) attempt A's stale ``aafter_agent`` finally runs

    Expected: the lock is still held (by B) after step 4.
    """
    tid = "stale-aafter-lock"
    runtime = _Runtime(tid)
    first_attempt = BusyMutexMiddleware()
    retry_attempt = BusyMutexMiddleware()

    await first_attempt.abefore_agent({}, runtime)
    thread_lock = manager.lock_for(tid)
    assert thread_lock.locked()

    end_turn(tid)
    assert not thread_lock.locked()

    await retry_attempt.abefore_agent({}, runtime)
    assert thread_lock.locked()

    # Stale cleanup from attempt A must not release attempt B's lock.
    await first_attempt.aafter_agent({}, runtime)
    assert thread_lock.locked()

    # Let the retry run its own cleanup at the end of the test.
    await retry_attempt.aafter_agent({}, runtime)

View file

@ -0,0 +1,122 @@
"""Tests for the desktop-mode safety ruleset.
In desktop mode the agent operates against the user's real disk with no
revision history, so destructive filesystem operations must require
explicit approval. These tests pin the set of tools that get the ``ask``
gate so it cannot silently regress.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.middleware.permission import PermissionMiddleware
from app.agents.new_chat.permissions import (
Rule,
Ruleset,
aggregate_action,
evaluate_many,
)
pytestmark = pytest.mark.unit
# Local copy of the ruleset that ``chat_deepagent._build_compiled_agent_blocking``
# builds when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``.
# Duplicating it here gives the rule contract a focused regression test even
# when the larger graph-build helper is hard to instantiate in unit tests.
DESKTOP_SAFETY_RULESET = Ruleset(
    rules=[
        Rule(permission=gated_tool, pattern="*", action="ask")
        for gated_tool in ("rm", "rmdir", "move_file", "edit_file", "write_file")
    ],
    origin="desktop_safety",
)

# Baseline layer: a single catch-all rule that allows every tool.
SURFSENSE_DEFAULTS = Ruleset(
    rules=[Rule(permission="*", pattern="*", action="allow")],
    origin="surfsense_defaults",
)
def _action_for(tool_name: str, *rulesets: Ruleset) -> str:
    """Resolve ``tool_name`` against ``rulesets`` (in the given order) to one action.

    The tool name is used both as the permission and as the only pattern
    passed to ``evaluate_many``; the matches are folded by ``aggregate_action``.
    """
    return aggregate_action(evaluate_many(tool_name, [tool_name], *rulesets))
class TestDesktopSafetyRulesGateDestructiveOps:
    """Pins which tools are gated behind ``ask`` and which remain ``allow``."""

    @pytest.mark.parametrize(
        "tool_name",
        ["rm", "rmdir", "move_file", "edit_file", "write_file"],
    )
    def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None:
        """Each destructive tool resolves to ``ask`` once desktop_safety is layered on."""
        # surfsense_defaults says "allow */*"; desktop_safety must override
        # because it's layered later (last-match-wins).
        resolved = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        failure = (
            f"{tool_name} must require approval in desktop mode "
            f"(no revert path on real disk); got {resolved!r}"
        )
        assert resolved == "ask", failure

    @pytest.mark.parametrize(
        "tool_name",
        ["read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"],
    )
    def test_safe_ops_remain_allowed(self, tool_name: str) -> None:
        """Read-only and trivially reversible tools still resolve to ``allow``."""
        # Read-only and trivially-reversible tools must NOT get gated —
        # otherwise every navigation in desktop mode pops an interrupt.
        resolved = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        failure = f"{tool_name} should not be gated in desktop mode; got {resolved!r}"
        assert resolved == "allow", failure
class TestDesktopSafetyOverridesAllowDefault:
    """Documents that the order in which rulesets are layered decides the outcome."""

    def test_layer_order_last_match_wins(self) -> None:
        """``rm`` resolves to ``allow`` or ``ask`` depending on which ruleset comes last."""
        # If desktop_safety is layered BEFORE surfsense_defaults, the allow
        # default would win and the safety net would be inert. This test
        # protects against accidentally swapping the rulesets in
        # ``_build_compiled_agent_blocking``.
        wrong_order = _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS)
        # Layered "wrong way" — the broad allow now wins.
        assert wrong_order == "allow"

        # Correct order: defaults < desktop_safety -> ask wins.
        right_order = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
        assert right_order == "ask"
class TestPermissionMiddlewareIntegration:
    """Drives ``PermissionMiddleware.after_model`` with the desktop rulesets."""

    def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None:
        """An ``rm`` tool call takes the ask path; a stubbed reject raises ``RejectedError``."""
        from langchain_core.messages import AIMessage

        from app.agents.new_chat.errors import RejectedError

        def _always_reject(**kw):
            return {"decision_type": "reject"}

        middleware = PermissionMiddleware(
            rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET]
        )
        # Stub the interrupt to a "reject" decision so we can assert the
        # ask path was taken without spinning up the LangGraph runtime.
        middleware._raise_interrupt = _always_reject  # type: ignore[assignment]

        rm_call = {
            "name": "rm",
            "args": {"path": "/Users/me/Documents/important.docx"},
            "id": "tc-rm",
        }
        state = {"messages": [AIMessage(content="", tool_calls=[rm_call])]}

        class _FakeRuntime:
            config: dict = {"configurable": {"thread_id": "test"}}

        with pytest.raises(RejectedError):
            middleware.after_model(state, _FakeRuntime())

View file

@ -0,0 +1,111 @@
"""Tests for the default auto-approval list in ``hitl.request_approval``.
These pin the policy that low-stakes connector creation tools (drafts,
new-file creates) skip the HITL interrupt by default. Without this set,
every "draft my newsletter" turn used to fire ~3 interrupts before any
useful work happened.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.tools.hitl import (
DEFAULT_AUTO_APPROVED_TOOLS,
HITLResult,
request_approval,
)
pytestmark = pytest.mark.unit
class TestDefaultAutoApprovedToolsList:
    """Pins the exact contents and type of ``DEFAULT_AUTO_APPROVED_TOOLS``."""

    def test_set_contains_expected_creation_tools(self) -> None:
        """The policy list equals the expected set of creation tools exactly."""
        # If anyone changes the policy list, we want a single test to
        # update so the contract is explicit. Keep this in sync with
        # ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``.
        gmail_tools = {"create_gmail_draft", "update_gmail_draft"}
        page_tools = {"create_notion_page", "create_confluence_page"}
        file_tools = {
            "create_google_drive_file",
            "create_dropbox_file",
            "create_onedrive_file",
        }
        expected = gmail_tools | page_tools | file_tools
        assert expected == DEFAULT_AUTO_APPROVED_TOOLS

    def test_set_is_immutable(self) -> None:
        """The policy list is a ``frozenset``."""
        # frozenset prevents accidental at-runtime mutation that would
        # silently widen the auto-approval surface.
        assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)

    def test_send_tools_are_not_auto_approved(self) -> None:
        """None of the listed send/delete/calendar tools appear in the policy list."""
        # External-broadcast tools must always prompt.
        must_prompt = (
            "send_gmail_email",
            "send_discord_message",
            "send_teams_message",
            "delete_notion_page",
            "create_calendar_event",
            "delete_calendar_event",
        )
        for gated in must_prompt:
            assert gated not in DEFAULT_AUTO_APPROVED_TOOLS, (
                f"{gated} must remain HITL-gated"
            )
class TestRequestApprovalAutoBypass:
    """Exercises the bypass branches of ``request_approval``."""

    def test_auto_approved_tool_skips_interrupt(self) -> None:
        """A default-listed tool returns an ``auto_approved`` result with its params intact."""
        expected_params = {
            "to": "alice@example.com",
            "subject": "hi",
            "body": "hey",
        }
        # No interrupt mock set up — if the function attempted to call
        # ``langgraph.types.interrupt`` it would raise GraphInterrupt.
        # The fact that we get a clean HITLResult proves the bypass.
        outcome = request_approval(
            action_type="gmail_draft_creation",
            tool_name="create_gmail_draft",
            # Pass a copy so the comparison below would catch in-place edits.
            params=dict(expected_params),
        )
        assert isinstance(outcome, HITLResult)
        assert outcome.rejected is False
        assert outcome.decision_type == "auto_approved"
        # Original params are preserved untouched (no user edits possible).
        assert outcome.params == expected_params

    def test_non_listed_tool_still_attempts_interrupt(self) -> None:
        """A tool outside the default list reaches the interrupt call and raises."""
        # A tool NOT in the default list must reach ``langgraph.interrupt``.
        # Outside a runnable context that call raises a RuntimeError —
        # which is exactly the signal we want: the bypass did NOT fire.
        send_params = {"to": "alice@example.com", "subject": "hi", "body": "hey"}
        with pytest.raises(RuntimeError, match="runnable context"):
            request_approval(
                action_type="gmail_email_send",
                tool_name="send_gmail_email",
                params=send_params,
            )

    def test_user_trusted_tools_still_take_precedence(self) -> None:
        """A tool named in ``trusted_tools`` yields a ``trusted`` decision."""
        # ``trusted_tools`` (per-connector "always allow" from MCP/UI)
        # was checked BEFORE the default list and must keep working
        # for tools outside the default list.
        custom_tool = "my_custom_mcp_tool"
        outcome = request_approval(
            action_type="mcp_tool_call",
            tool_name=custom_tool,
            params={"x": 1},
            trusted_tools=[custom_tool],
        )
        assert outcome.decision_type == "trusted"
        assert outcome.rejected is False

    def test_auto_approved_overrides_no_trusted_tools(self) -> None:
        """An empty ``trusted_tools`` list does not block the default-list bypass."""
        # When trusted_tools is empty and tool is in the default list,
        # we should still bypass — proves the order in request_approval.
        outcome = request_approval(
            action_type="notion_page_creation",
            tool_name="create_notion_page",
            params={"title": "Plan"},
            trusted_tools=[],
        )
        assert outcome.decision_type == "auto_approved"

View file

@ -0,0 +1,350 @@
"""Tests for ``apply_litellm_prompt_caching`` in
:mod:`app.agents.new_chat.prompt_caching`.
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
never activated for our LiteLLM stack) with LiteLLM-native multi-provider
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
``litellm.completion(...)``. The tests below pin its public contract:
1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so
savings compound across multi-turn conversations on Anthropic-family
providers.
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
prompt-cache surface is available).
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no
OpenAI-only kwargs — because the router fans out across providers.
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
5. Defensive: LLMs without a writable ``model_kwargs`` are silently
skipped rather than raising.
"""
from __future__ import annotations
from typing import Any
import pytest
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------
class _FakeLLM:
"""Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``.
The helper only inspects ``getattr(llm, "model_kwargs", None)``,
``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple
object suffices we don't need to spin up real LangChain/LiteLLM
machinery for unit tests of the helper's logic.
"""
def __init__(
self,
model: str = "openai/gpt-4o",
model_kwargs: dict[str, Any] | None = None,
) -> None:
self.model = model
self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
class ChatLiteLLMRouter:
    """Test double that shares only its class name with the real router.

    The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"``
    (a deliberate stringly-typed check to avoid an import cycle with
    ``app.services.llm_router_service``). Declaring a class with the same
    name here exercises that code path without instantiating a real ``Router``.
    """

    def __init__(self) -> None:
        # Starts with no kwargs, so tests can assert on what the helper adds.
        self.model_kwargs: dict[str, Any] = {}
        self.model = "auto"
def _make_cfg(**overrides: Any) -> AgentConfig:
    """Return an ``AgentConfig`` from OpenAI/gpt-4o defaults, with ``overrides`` applied on top."""
    params: dict[str, Any] = {
        "provider": "OPENAI",
        "model_name": "gpt-4o",
        "api_key": "k",
    }
    # Overrides take precedence over the defaults above.
    params.update(overrides)
    return AgentConfig(**params)
# ---------------------------------------------------------------------------
# (a) Universal injection points
# ---------------------------------------------------------------------------
def test_sets_both_cache_control_injection_points_with_no_config() -> None:
    """A bare call (no agent_config, no thread_id) still sets the two
    universal breakpoints — these cost nothing on providers that don't
    consume them and unlock caching on every supported provider."""
    fake = _FakeLLM()

    apply_litellm_prompt_caching(fake)

    injected = fake.model_kwargs["cache_control_injection_points"]
    for expected_point in (
        {"location": "message", "role": "system"},
        {"location": "message", "index": -1},
    ):
        assert expected_point in injected
    assert len(injected) == 2
def test_injection_points_set_for_anthropic_config() -> None:
    """Anthropic-family configs need the marker; check that it is present."""
    fake = _FakeLLM(model="anthropic/claude-3-5-sonnet")
    anthropic_cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")

    apply_litellm_prompt_caching(fake, agent_config=anthropic_cfg)

    assert "cache_control_injection_points" in fake.model_kwargs
# ---------------------------------------------------------------------------
# (b) Idempotency / user override wins
# ---------------------------------------------------------------------------
def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None:
    """Injection points the user set themselves (e.g. with ``ttl: "1h"``
    via ``litellm_params``) survive the helper: the very same list object
    is still in ``model_kwargs`` afterwards."""
    user_points = [{"location": "message", "role": "system", "ttl": "1h"}]
    fake = _FakeLLM(model_kwargs={"cache_control_injection_points": user_points})

    apply_litellm_prompt_caching(fake)

    # Identity, not just equality: the user's list must not be replaced.
    assert fake.model_kwargs["cache_control_injection_points"] is user_points
def test_idempotent_when_called_multiple_times() -> None:
    """Build-time + thread-time double-call must be a no-op the second time."""
    openai_cfg = _make_cfg(provider="OPENAI")
    fake = _FakeLLM()

    apply_litellm_prompt_caching(fake, agent_config=openai_cfg, thread_id=1)
    kwargs = fake.model_kwargs
    after_first_call = {
        # Copy the list so an in-place append on the second call is detected.
        "cache_control_injection_points": list(
            kwargs["cache_control_injection_points"]
        ),
        "prompt_cache_key": kwargs["prompt_cache_key"],
        "prompt_cache_retention": kwargs["prompt_cache_retention"],
    }

    apply_litellm_prompt_caching(fake, agent_config=openai_cfg, thread_id=1)

    for key in (
        "cache_control_injection_points",
        "prompt_cache_key",
        "prompt_cache_retention",
    ):
        assert fake.model_kwargs[key] == after_first_call[key]
def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
    """A ``prompt_cache_key`` that is already set (e.g. a tenant-aware
    override via ``litellm_params``) beats the default per-thread key."""
    fake = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"})
    openai_cfg = _make_cfg(provider="OPENAI")

    apply_litellm_prompt_caching(fake, agent_config=openai_cfg, thread_id=42)

    assert fake.model_kwargs["prompt_cache_key"] == "tenant-abc"
# ---------------------------------------------------------------------------
# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
def test_sets_openai_family_extras(provider: str) -> None:
    """OpenAI-style providers receive ``prompt_cache_key`` (raises hit rate
    via routing affinity) and ``prompt_cache_retention="24h"`` (extends
    cache TTL beyond the default 5-10 min)."""
    fake = _FakeLLM()

    apply_litellm_prompt_caching(
        fake, agent_config=_make_cfg(provider=provider), thread_id=42
    )

    extras = fake.model_kwargs
    assert extras["prompt_cache_key"] == "surfsense-thread-42"
    assert extras["prompt_cache_retention"] == "24h"
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
    """With no thread id there is no per-thread key to build, so the key is
    omitted. Retention is still useful, so it is set anyway (it's free)."""
    fake = _FakeLLM()
    openai_cfg = _make_cfg(provider="OPENAI")

    apply_litellm_prompt_caching(fake, agent_config=openai_cfg, thread_id=None)

    extras = fake.model_kwargs
    assert "prompt_cache_key" not in extras
    assert extras["prompt_cache_retention"] == "24h"
@pytest.mark.parametrize(
    "provider",
    ["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
)
def test_no_openai_extras_for_other_providers(provider: str) -> None:
    """Providers outside the OpenAI family get neither ``prompt_cache_key``
    nor ``prompt_cache_retention``; the universal
    ``cache_control_injection_points`` entry is still written."""
    fake_llm = _FakeLLM()

    apply_litellm_prompt_caching(
        fake_llm,
        agent_config=_make_cfg(provider=provider),
        thread_id=42,
    )

    kwargs = fake_llm.model_kwargs
    assert "prompt_cache_key" not in kwargs
    assert "prompt_cache_retention" not in kwargs
    assert "cache_control_injection_points" in kwargs
def test_no_openai_extras_in_auto_mode() -> None:
    """An auto-mode config may fan out across mixed providers, so the
    OpenAI-only kwargs are not applied; only the universal injection
    points are."""
    fake_llm = _FakeLLM()
    auto_config = AgentConfig.from_auto_mode()

    apply_litellm_prompt_caching(fake_llm, agent_config=auto_config, thread_id=42)

    kwargs = fake_llm.model_kwargs
    assert "prompt_cache_key" not in kwargs
    assert "prompt_cache_retention" not in kwargs
    assert "cache_control_injection_points" in kwargs
def test_no_openai_extras_for_custom_provider() -> None:
    """When ``custom_provider`` is set, the request goes through a
    user-supplied prefix, so OpenAI-family compatibility is not assumed
    even though ``provider`` says ``OPENAI``."""
    fake_llm = _FakeLLM()
    config = _make_cfg(provider="OPENAI", custom_provider="my_proxy")

    apply_litellm_prompt_caching(fake_llm, agent_config=config, thread_id=42)

    kwargs = fake_llm.model_kwargs
    assert "prompt_cache_key" not in kwargs
    assert "prompt_cache_retention" not in kwargs
# ---------------------------------------------------------------------------
# (d) ChatLiteLLMRouter — universal injection points only
# ---------------------------------------------------------------------------
def test_router_llm_gets_only_universal_injection_points() -> None:
    """A ``ChatLiteLLMRouter`` receives only the universal injection
    points, even when the config is OpenAI-flavoured: its requests are
    dispatched across provider deployments, and OpenAI-only kwargs would
    be wasted (or stripped by ``drop_params``) on non-OpenAI legs."""
    config = _make_cfg(provider="OPENAI")
    router = ChatLiteLLMRouter()

    apply_litellm_prompt_caching(router, agent_config=config, thread_id=42)

    kwargs = router.model_kwargs
    assert "cache_control_injection_points" in kwargs
    assert "prompt_cache_key" not in kwargs
    assert "prompt_cache_retention" not in kwargs
# ---------------------------------------------------------------------------
# (e) Defensive paths
# ---------------------------------------------------------------------------
def test_handles_llm_with_no_writable_model_kwargs() -> None:
    """An LLM object without a writable ``model_kwargs`` (e.g. a fake or a
    minimal subclass) must be skipped silently. Raising here would crash
    the whole LLM build path over a non-critical optimisation."""

    class _ImmutableLLM:
        # ``__slots__`` blocks attribute creation, so ``setattr`` raises.
        __slots__ = ("model",)

        def __init__(self) -> None:
            self.model = "openai/gpt-4o"

    # No assertion: the test passes as long as the call returns normally.
    apply_litellm_prompt_caching(_ImmutableLLM())
def test_initialises_missing_model_kwargs_dict() -> None:
    """If ``model_kwargs`` exists but is ``None``, the helper must replace
    it with a dict before writing the injection points into it."""

    class _LazyLLM:
        def __init__(self) -> None:
            self.model = "openai/gpt-4o"
            self.model_kwargs: dict[str, Any] | None = None

    lazy_llm = _LazyLLM()

    apply_litellm_prompt_caching(lazy_llm)

    kwargs = lazy_llm.model_kwargs
    assert isinstance(kwargs, dict)
    assert "cache_control_injection_points" in kwargs
def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None:
    """Direct-caller path (e.g. ``create_chat_litellm_from_config`` for
    YAML configs that have no structured ``AgentConfig``): with
    ``agent_config=None`` only the universal injection points are set, and
    no OpenAI-family extras are added even though the model prefix is
    ``openai/``. This is deliberately conservative: missing the speedup is
    preferred over silently misrouting."""
    fake_llm = _FakeLLM(model="openai/gpt-4o")

    apply_litellm_prompt_caching(fake_llm, agent_config=None, thread_id=99)

    kwargs = fake_llm.model_kwargs
    assert "cache_control_injection_points" in kwargs
    assert "prompt_cache_key" not in kwargs
    assert "prompt_cache_retention" not in kwargs
# ---------------------------------------------------------------------------
# (f) drop_params safety net (regression guard for #19346)
# ---------------------------------------------------------------------------
def test_litellm_drop_params_is_globally_enabled() -> None:
    """Pin the ``litellm.drop_params=True`` global.

    Importing :mod:`app.services.llm_service` is expected to set the flag.
    With it on, any ``prompt_cache_key`` / ``prompt_cache_retention`` set
    for an OpenAI-family config is stripped automatically if the request
    later routes to a provider that does not support it (e.g. via
    auto-mode router fallback). Losing the flag would mean Bedrock/Vertex
    400s on ``prompt_cache_key``.
    """
    import litellm

    # Imported purely for its side effect of configuring litellm globals.
    import app.services.llm_service  # noqa: F401 (side-effect: sets globals)

    flag = litellm.drop_params
    assert flag is True
# ---------------------------------------------------------------------------
# Regression note: LiteLLM #15696 (multi-content-block last message)
# ---------------------------------------------------------------------------
#
# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]``
# would get ``cache_control`` applied to *every* content block instead
# of only the last one — wasting cache breakpoints and triggering 400s
# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in
# https://github.com/BerriAI/litellm/pull/15699.
#
# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix).
# An end-to-end behavioural test would need to run ``litellm.completion``
# through the Anthropic transformer, which is integration territory and
# better covered by LiteLLM's own test suite. The unit guard here is the
# version pin plus the build-time ``model_kwargs`` shape we verify above.

View file

@ -0,0 +1,117 @@
"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`.
The helper picks the model id fed to ``detect_provider_variant`` so the
right ``<provider_hints>`` block lands in the system prompt. The tests
below pin its preference order:
1. ``agent_config.litellm_params["base_model"]`` (Azure-correct).
2. ``agent_config.model_name``.
3. ``getattr(llm, "model", None)``.
Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would
silently miss every provider regex.
"""
from __future__ import annotations
import pytest
from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name
from app.agents.new_chat.llm_config import AgentConfig
pytestmark = pytest.mark.unit
def _make_cfg(**overrides) -> AgentConfig:
    """Return an ``AgentConfig`` built from minimal defaults (``provider``,
    ``model_name``, ``api_key``), with any keyword in ``overrides``
    replacing or extending those defaults."""
    kwargs = {"provider": "OPENAI", "model_name": "x", "api_key": "k"}
    kwargs.update(overrides)
    return AgentConfig(**kwargs)
class _FakeLLM:
"""Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance.
The resolver only reads the ``.model`` attribute via ``getattr``,
matching the established idiom in ``knowledge_search.py`` /
``stream_new_chat.py`` / ``document_summarizer.py``.
"""
def __init__(self, model: str | None) -> None:
self.model = model
def test_prefers_litellm_params_base_model_over_deployment_name() -> None:
    """``litellm_params["base_model"]`` outranks the deployment slug.

    An Azure-style deployment name such as ``"azure/prod-chat-001"`` says
    nothing about the model family, so the resolver must return the
    ``base_model`` value (``"gpt-4o"``) instead of the slug.
    """
    deployment = "azure/prod-chat-001"
    cfg = _make_cfg(
        model_name=deployment,
        litellm_params={"base_model": "gpt-4o"},
    )

    resolved = _resolve_prompt_model_name(cfg, _FakeLLM(deployment))

    assert resolved == "gpt-4o"
def test_falls_back_to_model_name_when_litellm_params_is_none() -> None:
    """With ``litellm_params=None`` the resolver returns ``model_name``."""
    model_id = "anthropic/claude-3-5-sonnet"
    cfg = _make_cfg(model_name=model_id, litellm_params=None)

    resolved = _resolve_prompt_model_name(cfg, _FakeLLM(model_id))

    assert resolved == "anthropic/claude-3-5-sonnet"
def test_handles_litellm_params_without_base_model_key() -> None:
    """A ``litellm_params`` dict lacking ``base_model`` falls through to
    ``model_name``."""
    model_id = "openai/gpt-4o"
    cfg = _make_cfg(model_name=model_id, litellm_params={"temperature": 0.5})

    resolved = _resolve_prompt_model_name(cfg, _FakeLLM(model_id))

    assert resolved == "openai/gpt-4o"
def test_ignores_blank_base_model() -> None:
    """A ``base_model`` made up only of whitespace must not take
    precedence over ``model_name``."""
    model_id = "openai/gpt-4o"
    cfg = _make_cfg(model_name=model_id, litellm_params={"base_model": " "})

    resolved = _resolve_prompt_model_name(cfg, _FakeLLM(model_id))

    assert resolved == "openai/gpt-4o"
def test_ignores_non_string_base_model() -> None:
    """Defensive case: a ``base_model`` that is not a string (here an
    ``int``) is ignored rather than crashing the resolver."""
    model_id = "openai/gpt-4o"
    cfg = _make_cfg(model_name=model_id, litellm_params={"base_model": 42})

    resolved = _resolve_prompt_model_name(cfg, _FakeLLM(model_id))

    assert resolved == "openai/gpt-4o"
def test_falls_back_to_llm_model_when_no_agent_config() -> None:
    """When no ``agent_config`` is given, the resolver reads ``llm.model``.
    This is the defensive path for direct callers; per the original note,
    production callers always pass a config."""
    fake_llm = _FakeLLM("openai/gpt-4o-mini")

    resolved = _resolve_prompt_model_name(None, fake_llm)

    assert resolved == "openai/gpt-4o-mini"
def test_returns_none_when_nothing_available() -> None:
    """With neither a config nor an ``llm.model`` the resolver yields
    ``None``. Per the original note, ``compose_system_prompt`` treats that
    as the ``"default"`` variant and emits no provider block."""
    resolved = _resolve_prompt_model_name(None, _FakeLLM(None))

    assert resolved is None
def test_auto_mode_resolves_to_auto_string() -> None:
    """An auto-mode config resolves to the literal ``"auto"``. Per the
    original note, ``detect_provider_variant("auto")`` then returns
    ``"default"``, which is right because the concrete child model is
    unknown until the LiteLLM Router dispatches."""
    auto_cfg = AgentConfig.from_auto_mode()

    resolved = _resolve_prompt_model_name(auto_cfg, _FakeLLM("auto"))

    assert resolved == "auto"

View file

@ -0,0 +1,333 @@
"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools.
The tools build ``Command(update=...)`` payloads that the persistence
middleware applies at end of turn. These tests stub out the backend and
runtime to assert the staging payload shape:
* ``rm`` queues into ``pending_deletes`` and tombstones state files.
* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc.
* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs.
* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete.
* ``rmdir`` refuses to drop the cwd or any of its ancestors.
* ``KBPostgresBackend`` view-helpers honor staged deletes.
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock
import pytest
from app.agents.new_chat.filesystem_selection import FilesystemMode
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
pytestmark = pytest.mark.unit
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
    """Create a ``SurfSenseFilesystemMiddleware`` without running its
    ``__init__``, then set only the two attributes assigned here: the
    filesystem mode and an empty custom-tool-description map."""
    instance = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
    instance._custom_tool_descriptions = {}
    instance._filesystem_mode = mode
    return instance
def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"):
state = state or {}
state.setdefault("cwd", "/documents")
return SimpleNamespace(state=state, tool_call_id=tool_call_id)
class _KBBackendStub(KBPostgresBackend):
    """Test-only :class:`KBPostgresBackend` subclass that is cheap to build.

    The parent initialiser is deliberately not called; only the two
    attributes the rm/rmdir tools touch are provided, as ``AsyncMock``
    objects. Subclassing (rather than using a bare mock) keeps
    ``isinstance(backend, KBPostgresBackend)`` checks inside the tools
    passing.
    """

    def __init__(self, *, children=None, file_data=None) -> None:
        # Directory listing returned for every ``als_info`` call.
        listing = children or []
        # ``_load_file_data`` yields ``(file_data, 17)`` for a file and
        # ``None`` when no file data was supplied.
        loaded = None if file_data is None else (file_data, 17)
        self.als_info = AsyncMock(return_value=listing)
        self._load_file_data = AsyncMock(return_value=loaded)
def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend:
    """Factory wrapper that forwards both keyword arguments to
    ``_KBBackendStub`` and returns the resulting stub."""
    stub = _KBBackendStub(children=children, file_data=file_data)
    return stub
def _bind_backend(middleware, backend):
"""Inject a backend resolver onto the middleware test instance."""
middleware._get_backend = lambda runtime: backend
return backend
# ---------------------------------------------------------------------------
# rm
# ---------------------------------------------------------------------------
class TestRmStaging:
    """Payload-shape tests for the cloud-mode ``rm`` tool."""

    @pytest.mark.asyncio
    async def test_stages_delete_and_tombstones_state(self):
        """A successful ``rm`` queues the path in ``pending_deletes`` and
        tombstones it (``None``) in ``files`` and ``doc_id_by_path``."""
        target = "/documents/notes.md"
        middleware = _make_middleware()
        _bind_backend(
            middleware,
            _make_backend_stub(children=[], file_data={"content": ["x"]}),
        )
        runtime = _runtime(
            {
                "cwd": "/documents",
                "files": {target: {"content": ["hello"]}},
                "doc_id_by_path": {target: 17},
            },
            tool_call_id="tc-1",
        )
        rm_tool = middleware._create_rm_tool()

        outcome = await rm_tool.coroutine(target, runtime=runtime)

        assert hasattr(outcome, "update"), f"expected Command, got {outcome!r}"
        staged = outcome.update
        assert staged["pending_deletes"] == [
            {"path": "/documents/notes.md", "tool_call_id": "tc-1"}
        ]
        assert staged["files"] == {"/documents/notes.md": None}
        assert staged["doc_id_by_path"] == {"/documents/notes.md": None}

    @pytest.mark.asyncio
    async def test_rejects_documents_root(self):
        """``rm /documents`` is refused with a plain-string message."""
        rm_tool = _make_middleware()._create_rm_tool()

        outcome = await rm_tool.coroutine("/documents", runtime=_runtime())

        assert isinstance(outcome, str)
        assert "refusing to rm" in outcome

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        """``rm /`` is refused with a plain-string message."""
        rm_tool = _make_middleware()._create_rm_tool()

        outcome = await rm_tool.coroutine("/", runtime=_runtime())

        assert isinstance(outcome, str)
        assert "refusing to rm" in outcome

    @pytest.mark.asyncio
    async def test_rejects_directory_via_staged_dirs(self):
        """A path listed in ``staged_dirs`` is a directory: ``rm`` refuses
        and points the caller at ``rmdir``."""
        runtime = _runtime({"staged_dirs": ["/documents/team-x"]})
        rm_tool = _make_middleware()._create_rm_tool()

        outcome = await rm_tool.coroutine("/documents/team-x", runtime=runtime)

        assert isinstance(outcome, str)
        assert "directory" in outcome.lower()
        assert "rmdir" in outcome

    @pytest.mark.asyncio
    async def test_rejects_directory_via_listing(self):
        """A path whose backend listing has children is treated as a
        directory and rejected."""
        middleware = _make_middleware()
        listing = [{"path": "/documents/foo/x.md", "is_dir": False}]
        _bind_backend(middleware, _make_backend_stub(children=listing))
        rm_tool = middleware._create_rm_tool()

        outcome = await rm_tool.coroutine("/documents/foo", runtime=_runtime())

        assert isinstance(outcome, str)
        assert "directory" in outcome.lower()

    @pytest.mark.asyncio
    async def test_rejects_anonymous_doc(self):
        """The anonymous uploaded document (``kb_anon_doc``) is read-only."""
        anon_doc = {
            "path": "/documents/uploaded.xml",
            "title": "uploaded",
            "content": "",
            "chunks": [],
        }
        runtime = _runtime({"kb_anon_doc": anon_doc})
        rm_tool = _make_middleware()._create_rm_tool()

        outcome = await rm_tool.coroutine("/documents/uploaded.xml", runtime=runtime)

        assert isinstance(outcome, str)
        assert "read-only" in outcome

    @pytest.mark.asyncio
    async def test_drops_path_from_dirty_paths(self):
        """After ``rm`` the removed path must not remain in the emitted
        ``dirty_paths`` payload."""
        target = "/documents/notes.md"
        middleware = _make_middleware()
        _bind_backend(
            middleware,
            _make_backend_stub(children=[], file_data={"content": ["x"]}),
        )
        runtime = _runtime(
            {
                "files": {target: {"content": ["x"]}},
                "doc_id_by_path": {target: 17},
                "dirty_paths": [target],
            }
        )
        rm_tool = middleware._create_rm_tool()

        outcome = await rm_tool.coroutine(target, runtime=runtime)

        staged = outcome.update
        # First element is _CLEAR sentinel; the rest must NOT contain the
        # rm'd path.
        dirty = staged.get("dirty_paths") or []
        assert "/documents/notes.md" not in dirty[1:]
# ---------------------------------------------------------------------------
# rmdir
# ---------------------------------------------------------------------------
class TestRmdirStaging:
    """Payload-shape tests for the cloud-mode ``rmdir`` tool."""

    @pytest.mark.asyncio
    async def test_stages_dir_delete_when_empty_and_db_backed(self):
        """An empty folder that exists in the backend is queued in
        ``pending_dir_deletes``."""
        middleware = _make_middleware()
        backend = _bind_backend(middleware, _make_backend_stub(children=[]))
        # Override _load_file_data to return None (folder, not a file) and
        # parent listing to claim the folder exists.
        backend._load_file_data = AsyncMock(return_value=None)
        children_of_target: list = []  # children of /documents/proj
        parent_listing = [{"path": "/documents/proj", "is_dir": True}]
        backend.als_info = AsyncMock(side_effect=[children_of_target, parent_listing])
        runtime = _runtime({"cwd": "/documents"}, tool_call_id="tc-rd")
        rmdir_tool = middleware._create_rmdir_tool()

        outcome = await rmdir_tool.coroutine("/documents/proj", runtime=runtime)

        assert hasattr(outcome, "update")
        staged = outcome.update
        assert staged["pending_dir_deletes"] == [
            {"path": "/documents/proj", "tool_call_id": "tc-rd"}
        ]

    @pytest.mark.asyncio
    async def test_rejects_non_empty(self):
        """A folder that still has children is refused as not empty."""
        middleware = _make_middleware()
        listing = [{"path": "/documents/proj/x.md", "is_dir": False}]
        _bind_backend(middleware, _make_backend_stub(children=listing))
        rmdir_tool = middleware._create_rmdir_tool()

        outcome = await rmdir_tool.coroutine("/documents/proj", runtime=_runtime())

        assert isinstance(outcome, str)
        assert "not empty" in outcome

    @pytest.mark.asyncio
    async def test_unstages_same_turn_mkdir(self):
        """Removing a directory created earlier in the same turn un-stages
        the ``mkdir`` instead of queuing a delete."""
        middleware = _make_middleware()
        _bind_backend(middleware, _make_backend_stub(children=[]))
        runtime = _runtime(
            {"cwd": "/documents", "staged_dirs": ["/documents/scratch"]},
            tool_call_id="tc-rd",
        )
        rmdir_tool = middleware._create_rmdir_tool()

        outcome = await rmdir_tool.coroutine("/documents/scratch", runtime=runtime)

        assert hasattr(outcome, "update")
        staged = outcome.update
        assert "pending_dir_deletes" not in staged
        # _CLEAR sentinel + remaining items (in this case, none).
        remaining = staged["staged_dirs"]
        assert remaining[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
        assert "/documents/scratch" not in remaining[1:]

    @pytest.mark.asyncio
    async def test_rejects_root(self):
        """Both ``/`` and ``/documents`` are refused."""
        rmdir_tool = _make_middleware()._create_rmdir_tool()
        runtime = _runtime()

        for victim in ("/", "/documents"):
            outcome = await rmdir_tool.coroutine(victim, runtime=runtime)
            assert isinstance(outcome, str)
            assert "refusing to rmdir" in outcome

    @pytest.mark.asyncio
    async def test_rejects_cwd(self):
        """The current working directory itself cannot be removed."""
        runtime = _runtime({"cwd": "/documents/proj"})
        rmdir_tool = _make_middleware()._create_rmdir_tool()

        outcome = await rmdir_tool.coroutine("/documents/proj", runtime=runtime)

        assert isinstance(outcome, str)
        assert "cwd" in outcome.lower()

    @pytest.mark.asyncio
    async def test_rejects_ancestor_of_cwd(self):
        """An ancestor of the current working directory cannot be removed."""
        runtime = _runtime({"cwd": "/documents/proj/sub"})
        rmdir_tool = _make_middleware()._create_rmdir_tool()

        outcome = await rmdir_tool.coroutine("/documents/proj", runtime=runtime)

        assert isinstance(outcome, str)
        assert "cwd" in outcome.lower()

    @pytest.mark.asyncio
    async def test_rejects_files(self):
        """A path that resolves to a file is refused by ``rmdir``."""
        middleware = _make_middleware()
        _bind_backend(
            middleware,
            _make_backend_stub(children=[], file_data={"content": ["x"]}),
        )
        rmdir_tool = middleware._create_rmdir_tool()

        outcome = await rmdir_tool.coroutine("/documents/notes.md", runtime=_runtime())

        assert isinstance(outcome, str)
        assert "is a file" in outcome
# ---------------------------------------------------------------------------
# KBPostgresBackend view filter
# ---------------------------------------------------------------------------
class TestKBPostgresBackendDeleteFilter:
    """Paths queued for delete should be hidden by als_info / glob / grep.

    The tests here exercise the two helpers those views rely on:
    ``_pending_filesystem_view`` and ``_is_dir_suppressed``.
    """

    def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend:
        """Build a real backend around a namespace that only carries
        ``state``."""
        return KBPostgresBackend(
            search_space_id=1,
            runtime=SimpleNamespace(state=state),
        )

    def test_pending_filesystem_view_returns_deleted_paths(self):
        """Staged file and directory deletes show up in the pending view,
        and no aliases are reported."""
        staged_state = {
            "pending_deletes": [
                {"path": "/documents/x.md", "tool_call_id": "t1"},
            ],
            "pending_dir_deletes": [
                {"path": "/documents/d1", "tool_call_id": "t2"},
            ],
        }
        backend = self._make_backend(staged_state)

        removed, alias, deleted_dirs = backend._pending_filesystem_view({})

        assert "/documents/x.md" in removed
        assert "/documents/d1" in deleted_dirs
        assert alias == {}

    def test_dir_suppressed_covers_descendants(self):
        """A deleted directory suppresses itself and everything beneath
        it, but nothing outside it."""
        backend = self._make_backend({})
        deleted_dirs = {"/documents/d"}

        for hidden in (
            "/documents/d",
            "/documents/d/x.md",
            "/documents/d/sub/y.md",
        ):
            assert backend._is_dir_suppressed(hidden, deleted_dirs)
        assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs)

View file

@ -98,10 +98,54 @@ class TestInitialFilesystemState:
state = _initial_filesystem_state()
assert state["cwd"] == "/documents"
assert state["staged_dirs"] == []
assert state["staged_dir_tool_calls"] == {}
assert state["pending_moves"] == []
assert state["pending_deletes"] == []
assert state["pending_dir_deletes"] == []
assert state["doc_id_by_path"] == {}
assert state["dirty_paths"] == []
assert state["dirty_path_tool_calls"] == {}
assert state["kb_priority"] == []
assert state["kb_matched_chunk_ids"] == {}
assert state["kb_anon_doc"] is None
assert state["tree_version"] == 0
class TestMultiEditSamePathCoalescing:
    """Several edits to one path in a single turn collapse into ONE binding.

    The persistence body looks up ``dirty_path_tool_calls[path]`` to find
    the tool_call_id that produced the current on-disk state.
    ``dirty_paths`` is deduplicated by :func:`_add_unique_reducer`, so a
    second edit adds no new path entry, and
    ``_dict_merge_with_tombstones_reducer`` lets the right-hand side
    overwrite, so the LATEST tool_call_id wins. That is the desired
    snapshot behaviour: revert restores the pre-mutation state, and
    back-to-back edits in one turn coalesce into a single reversible op
    (the user sees ONE Revert button per turn-per-path, not N).
    """

    def test_dirty_paths_dedupes_repeated_writes(self):
        # ``_add_unique_reducer`` is applied to ``dirty_paths``. Two writes
        # to the same path produce one entry, not two.
        path = "/documents/a.md"
        after_first = _add_unique_reducer([], [path])
        after_second = _add_unique_reducer(after_first, [path])
        assert after_second == ["/documents/a.md"]

    def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self):
        path = "/documents/a.md"
        # First write tags the path with tcid-1.
        bindings = _dict_merge_with_tombstones_reducer({}, {path: "tcid-1"})
        # Second write to the same path tags it with tcid-2 (latest wins).
        bindings = _dict_merge_with_tombstones_reducer(bindings, {path: "tcid-2"})
        assert bindings == {"/documents/a.md": "tcid-2"}

    def test_rm_tombstones_dirty_path_tool_call(self):
        # ``rm`` writes ``{path: None}`` into dirty_path_tool_calls to
        # prevent a stale binding from leaking past the delete.
        existing = {"/documents/a.md": "tcid-1"}
        tombstone = {"/documents/a.md": None}
        bindings = _dict_merge_with_tombstones_reducer(existing, tombstone)
        assert bindings == {}

View file

@ -0,0 +1,83 @@
"""Smoke test for the ``134_relax_revision_fks`` Alembic migration.
A full apply/rollback test would require a live Postgres; here we verify
the migration module's static contract:
* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``.
* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'``
(one for ``document_revisions.document_id``, one for
``folder_revisions.folder_id``).
* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining
orphaned revisions.
If any of these invariants regress, the snapshot/revert pipeline silently
loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the
migration "down" or never ran it at all.
"""
from __future__ import annotations
import importlib.util
import inspect
from pathlib import Path
import pytest
pytestmark = pytest.mark.unit
# Location of the migration under test: three directory levels above this
# file's folder, then ``alembic/versions/``.
_MIGRATION_PATH = Path(__file__).resolve().parents[3].joinpath(
    "alembic", "versions", "134_relax_revision_fks.py"
)


def _load_migration():
    """Import the migration file directly from ``_MIGRATION_PATH`` and
    return the executed module object (no package import is involved)."""
    spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH)
    assert spec and spec.loader, "could not load migration spec"
    migration = importlib.util.module_from_spec(spec)
    # Executing the module defines ``revision``, ``upgrade`` and friends.
    spec.loader.exec_module(migration)
    return migration
def test_migration_chain_revision_ids() -> None:
    """The migration declares revision ``134`` as the successor of ``133``."""
    migration = _load_migration()
    # The migration file uses short numeric revision IDs to match the
    # in-tree convention (cf. ``133`` -> ``134``); the ``134_<slug>.py``
    # filename is documentation, not the canonical revision string.
    declared_revision = getattr(migration, "revision", None)
    declared_parent = getattr(migration, "down_revision", None)
    assert declared_revision == "134"
    assert declared_parent == "133"
def test_migration_exposes_upgrade_and_downgrade() -> None:
    """Both ``upgrade`` and ``downgrade`` must exist and be callable."""
    migration = _load_migration()
    assert callable(getattr(migration, "upgrade", None)), "upgrade() is required"
    assert callable(getattr(migration, "downgrade", None)), "downgrade() is required"
def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None:
    """Static check on the source text of ``upgrade()``: it touches both
    revision tables, declares at least two ``SET NULL`` foreign keys, and
    makes a column nullable."""
    upgrade_src = inspect.getsource(_load_migration().upgrade)
    for table in ("document_revisions", "folder_revisions"):
        assert table in upgrade_src
    # Both new FKs MUST be ON DELETE SET NULL — that's the entire point
    # of the migration: snapshots must outlive their parent row.
    set_null_count = upgrade_src.count('ondelete="SET NULL"')
    assert set_null_count >= 2
    # And the ``document_id`` / ``folder_id`` columns become nullable.
    assert "nullable=True" in upgrade_src
def test_downgrade_drains_orphans_then_restores_cascade() -> None:
    """Static check on the source text of ``downgrade()``: it contains
    both orphan-draining DELETE statements, at least two ``CASCADE``
    foreign keys, and a ``nullable=False`` change."""
    downgrade_src = inspect.getsource(_load_migration().downgrade)
    # Drain orphaned rows BEFORE we can re-impose NOT NULL.
    drain_statements = (
        "DELETE FROM document_revisions WHERE document_id IS NULL",
        "DELETE FROM folder_revisions WHERE folder_id IS NULL",
    )
    for statement in drain_statements:
        assert statement in downgrade_src
    # Then restore the original CASCADE/NOT NULL contract.
    cascade_count = downgrade_src.count('ondelete="CASCADE"')
    assert cascade_count >= 2
    assert "nullable=False" in downgrade_src

View file

@ -168,6 +168,8 @@ class TestModeSpecificPrompts:
"edit_file",
"move_file",
"mkdir",
"rm",
"rmdir",
"list_tree",
"grep",
):
@ -182,6 +184,8 @@ class TestModeSpecificPrompts:
"edit_file",
"move_file",
"mkdir",
"rm",
"rmdir",
"list_tree",
"grep",
):
@ -190,6 +194,18 @@ class TestModeSpecificPrompts:
assert "/documents/" not in text, f"{name} mentions cloud namespace"
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
def test_cloud_descs_include_rm_and_rmdir(self):
    """Cloud-mode tool descriptions include ``rm`` and ``rmdir`` entries
    with the expected wording."""
    descs = _build_tool_descriptions(FilesystemMode.CLOUD)
    assert "rm" in descs
    assert "rmdir" in descs
    rm_text = descs["rm"]
    rmdir_text = descs["rmdir"]
    assert "Deletes a single file" in rm_text
    assert "Deletes an empty directory" in rmdir_text
    assert "rmdir" in rmdir_text
    assert "POSIX" in rmdir_text
def test_desktop_descs_warn_about_irreversibility(self):
    """In desktop local-folder mode, both delete tools state that the
    operation is NOT reversible."""
    descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
    for tool_name in ("rm", "rmdir"):
        assert "NOT reversible" in descs[tool_name]
def test_sandbox_addendum_appended_when_available(self):
prompt = _build_filesystem_system_prompt(
FilesystemMode.CLOUD, sandbox_available=True

View file

@ -0,0 +1,309 @@
"""Unit tests for the kb_persistence snapshot helpers.
The full ``commit_staged_filesystem_state`` body exercises a real session
in integration tests; here we verify the building blocks used by the
snapshot/revert pipeline:
* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids
(regression guard against the N+1 lookup pattern).
* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``.
* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the
shape the snapshot helpers consume.
These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so
the assertions run in milliseconds and don't require Postgres.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.agents.new_chat.middleware import kb_persistence
pytestmark = pytest.mark.unit
class _FakeResult:
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
self._rows = rows or []
self._scalar = scalar
def all(self) -> list[Any]:
return list(self._rows)
def scalar_one_or_none(self) -> Any:
return self._scalar
class _FakeSession:
def __init__(self) -> None:
self.execute = AsyncMock()
@pytest.mark.asyncio
async def test_find_action_ids_batch_issues_single_query() -> None:
    """Resolving N tool_call_ids must cost exactly one ``execute`` call
    (a single batched SELECT), never one query per id."""
    canned_rows = [
        MagicMock(id=11, tool_call_id="tc-a"),
        MagicMock(id=22, tool_call_id="tc-b"),
        MagicMock(id=33, tool_call_id="tc-c"),
    ]
    session = _FakeSession()
    session.execute.return_value = _FakeResult(rows=canned_rows)

    mapping = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=1,
        tool_call_ids={"tc-a", "tc-b", "tc-c"},
    )

    assert mapping == {"tc-a": 11, "tc-b": 22, "tc-c": 33}
    assert session.execute.await_count == 1, (
        "Snapshot binding must batch into ONE query; got "
        f"{session.execute.await_count} (regression: N+1 lookup pattern)."
    )
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None:
    """With ``thread_id=None`` the lookup returns an empty mapping and
    never touches the session."""
    session = _FakeSession()

    result = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=None,
        tool_call_ids={"tc-a"},
    )

    assert session.execute.await_count == 0
    assert result == {}
@pytest.mark.asyncio
async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None:
    """An empty ``tool_call_ids`` set returns an empty mapping without
    issuing a query."""
    session = _FakeSession()

    result = await kb_persistence._find_action_ids_batch(
        session,  # type: ignore[arg-type]
        thread_id=42,
        tool_call_ids=set(),
    )

    assert session.execute.await_count == 0
    assert result == {}
@pytest.mark.asyncio
async def test_mark_action_reversible_is_noop_for_null_id() -> None:
    """``action_id=None`` must not trigger any statement on the session."""
    session = _FakeSession()

    await kb_persistence._mark_action_reversible(session, action_id=None)  # type: ignore[arg-type]

    executed = session.execute.await_count
    assert executed == 0
@pytest.mark.asyncio
async def test_mark_action_reversible_runs_update_for_real_id() -> None:
    """A concrete ``action_id`` results in exactly one awaited
    ``execute`` call."""
    session = _FakeSession()

    await kb_persistence._mark_action_reversible(session, action_id=99)  # type: ignore[arg-type]

    executed = session.execute.await_count
    assert executed == 1
def test_doc_revision_payload_captures_metadata_virtual_path() -> None:
    """The payload carries every ``*_before`` field, including
    ``metadata_before``, which the revert path reuses."""
    doc = MagicMock()
    doc.title = "notes.md"
    doc.content = "body"
    doc.folder_id = 7
    doc.document_metadata = {"virtual_path": "/documents/team/notes.md"}

    payload = kb_persistence._doc_revision_payload(
        doc, chunks_before=[{"content": "x"}]
    )

    expected = {
        "title_before": "notes.md",
        "folder_id_before": 7,
        "content_before": "body",
        "chunks_before": [{"content": "x"}],
        "metadata_before": {"virtual_path": "/documents/team/notes.md"},
    }
    for field, value in expected.items():
        assert payload[field] == value
def test_doc_revision_payload_handles_missing_metadata() -> None:
    """A document whose ``document_metadata`` is ``None`` yields
    ``metadata_before=None`` instead of failing."""
    doc = MagicMock()
    doc.document_metadata = None
    doc.folder_id = None
    doc.title = ""
    doc.content = ""

    payload = kb_persistence._doc_revision_payload(doc)

    assert payload["metadata_before"] is None
@pytest.mark.asyncio
async def test_load_chunks_for_snapshot_returns_content_only() -> None:
    """Snapshot chunks intentionally omit embeddings (regenerated on revert).

    The fake session returns two rows that only expose ``content``; the
    helper must turn them into ``{"content": ...}`` dicts in row order.
    """
    session = _FakeSession()
    session.execute.return_value = _FakeResult(
        rows=[
            MagicMock(content="alpha"),
            MagicMock(content="beta"),
        ]
    )
    chunks = await kb_persistence._load_chunks_for_snapshot(
        # The ignore belongs on the fake-session argument (the one whose
        # type does not match), as in the sibling tests of this module.
        session,  # type: ignore[arg-type]
        doc_id=42,
    )
    assert chunks == [{"content": "alpha"}, {"content": "beta"}]
# ---------------------------------------------------------------------------
# Deferred reversibility-flip dispatches.
#
# The snapshot helpers used to dispatch ``action_log_updated`` directly
# from inside the SAVEPOINT block. That meant the SSE side-channel
# could tell the UI a row was reversible while the OUTER transaction
# was still pending — and if the outer commit failed, every SAVEPOINT
# rolled back too, leaving the UI in a state inconsistent with
# durable storage. The deferred-dispatch contract fixes that:
#
# • when a ``deferred_dispatches`` list is provided, the helper
# APPENDS the action_id and does NOT dispatch;
# • the caller (``commit_staged_filesystem_state``) flushes the list
# only AFTER ``await session.commit()`` succeeds; on rollback it
# clears the list so nothing is emitted.
# ---------------------------------------------------------------------------
class _NestedCtx:
"""Async context manager mimicking ``session.begin_nested()``."""
async def __aenter__(self) -> _NestedCtx:
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
@pytest.mark.asyncio
async def test_pre_write_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When ``deferred_dispatches`` is passed, the helper must append the
    action id to that list and must NOT call the dispatcher itself."""
    inline_calls: list[int] = []
    queued: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is None:
            return
        inline_calls.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    def _assign_id(rev: Any) -> None:
        # Give the added revision an id so the helper has one to return.
        rev.id = 17

    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    session.flush = AsyncMock()
    session.add = MagicMock(side_effect=_assign_id)

    doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"})
    doc.content = "body"
    doc.folder_id = None
    doc.title = "x.md"

    rev_id = await kb_persistence._snapshot_document_pre_write(
        session,  # type: ignore[arg-type]
        doc=doc,
        action_id=42,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=queued,
    )

    assert rev_id == 17
    # Inline dispatch must NOT have fired; the action_id is queued.
    assert inline_calls == []
    assert queued == [42]
@pytest.mark.asyncio
async def test_pre_write_snapshot_dispatches_inline_when_list_omitted(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without ``deferred_dispatches`` the helper keeps the legacy inline dispatch."""
    inline_calls: list[int] = []

    async def _record_dispatch(action_id: int | None) -> None:
        if action_id is None:
            return
        inline_calls.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _record_dispatch
    )

    def _assign_revision_id(rev: Any) -> None:
        rev.id = 7

    fake_session = MagicMock()
    fake_session.begin_nested = MagicMock(return_value=_NestedCtx())
    fake_session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
    fake_session.flush = AsyncMock()
    fake_session.add = MagicMock(side_effect=_assign_revision_id)

    document = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"})
    document.title = "y.md"
    document.folder_id = None
    document.content = "body"

    await kb_persistence._snapshot_document_pre_write(
        fake_session,  # type: ignore[arg-type]
        doc=document,
        action_id=88,
        search_space_id=1,
        turn_id="t-1",
        # ``deferred_dispatches`` deliberately omitted: inline dispatch expected.
    )

    assert inline_calls == [88]
@pytest.mark.asyncio
async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Folder mkdir snapshots honour the same deferred-dispatch contract.

    When ``deferred_dispatches`` is passed, the helper must append the
    action id to that list and must not call
    ``_dispatch_reversibility_update`` itself.
    """
    session = MagicMock()
    session.begin_nested = MagicMock(return_value=_NestedCtx())
    session.execute = AsyncMock()  # _mark_action_reversible calls execute
    session.flush = AsyncMock()

    def _add(rev: Any) -> None:
        # Gives the added revision an id, standing in for the database.
        rev.id = 3

    session.add = MagicMock(side_effect=_add)

    dispatched: list[int] = []

    async def _fake_dispatch(action_id: int | None) -> None:
        if action_id is not None:
            dispatched.append(int(action_id))

    monkeypatch.setattr(
        kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
    )

    deferred: list[int] = []
    folder = MagicMock(id=2, parent_id=None, position="a0")
    # ``name`` is reserved by the Mock constructor (it only labels the mock's
    # repr), so passing ``name="f"`` would leave ``folder.name`` as a child
    # mock. Assign it afterwards so the helper sees a real string.
    folder.name = "f"

    await kb_persistence._snapshot_folder_pre_mkdir(
        session,  # type: ignore[arg-type]
        folder=folder,
        action_id=55,
        search_space_id=1,
        turn_id="t-1",
        deferred_dispatches=deferred,
    )

    # Nothing dispatched inline; the action id was queued for the caller.
    assert dispatched == []
    assert deferred == [55]

View file

@ -0,0 +1,139 @@
"""Unit tests for ``KnowledgeTreeMiddleware`` rendering.
The empty-folder marker is critical UX: without it, the LLM cannot
distinguish a leaf folder containing one document from a leaf folder
that has no descendants at all, and ends up firing ``rmdir`` on
non-empty folders. These tests pin the rendering contract so that
contract cannot silently regress.
"""
from __future__ import annotations
from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT
def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]:
    """Shorthand for ``KnowledgeTreeMiddleware._compute_non_empty_folders``."""
    non_empty = KnowledgeTreeMiddleware._compute_non_empty_folders(
        folder_paths, doc_paths
    )
    return non_empty
class TestComputeNonEmptyFolders:
    """Pins which folders ``_compute_non_empty_folders`` reports as non-empty."""

    def test_folder_with_direct_document_is_non_empty(self):
        boarding_pass = f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"
        result = _compute(
            [boarding_pass],
            [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml"],
        )
        assert boarding_pass in result

    def test_truly_empty_leaf_folder_is_not_non_empty(self):
        no_docs: list[str] = []
        result = _compute([f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"], no_docs)
        assert result == set()

    def test_documents_propagate_up_to_all_ancestors(self):
        chain = [
            f"{DOCUMENTS_ROOT}/A",
            f"{DOCUMENTS_ROOT}/A/B",
            f"{DOCUMENTS_ROOT}/A/B/C",
        ]
        result = _compute(chain, [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"])
        # A single deep document marks every folder on its path.
        assert result == set(chain)

    def test_chain_with_subfolders_marks_only_leaf_empty(self):
        # POSIX-like semantic: a folder counts as "empty" only when it has
        # no immediate children (documents OR sub-folders). The model
        # relies on this because parallel ``rmdir`` calls all see the same
        # starting state, so removing a parent before its children is
        # never safe.
        chain = [
            f"{DOCUMENTS_ROOT}/X",
            f"{DOCUMENTS_ROOT}/X/Y",
            f"{DOCUMENTS_ROOT}/X/Y/Z",
        ]
        result = _compute(chain, [])
        # ``X/Y/Z`` is the only empty folder. ``X`` and ``X/Y`` each hold a
        # sub-folder, so they are non-empty and must NOT carry the
        # ``(empty)`` marker.
        assert result == {f"{DOCUMENTS_ROOT}/X", f"{DOCUMENTS_ROOT}/X/Y"}

    def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self):
        # Mirrors a real DB layout where every intermediate folder is
        # materialized in the ``folders`` table.
        travel = f"{DOCUMENTS_ROOT}/Travel"
        boarding_pass = f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"
        notes = f"{DOCUMENTS_ROOT}/Travel/Notes"
        result = _compute(
            [travel, boarding_pass, notes],
            [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"],
        )
        # ``Travel`` has children and ``Notes`` holds the doc, so both are
        # non-empty; the sibling leaf ``Boarding Pass`` stays empty.
        assert travel in result
        assert notes in result
        assert boarding_pass not in result
class TestFormatTreeRendering:
    """Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't."""

    def _render(
        self,
        folder_paths: list[str],
        doc_specs: list[dict],
    ) -> str:
        """Render the tree for the given folders and documents.

        Folder ids are assigned 1..N in list order; each entry of
        ``doc_specs`` becomes a row object exposing its keys as attributes.
        """
        from app.agents.new_chat.path_resolver import PathIndex

        class _Row:
            def __init__(self, **kw):
                vars(self).update(kw)

        path_index = PathIndex(
            folder_paths=dict(enumerate(folder_paths, start=1)),
        )
        doc_rows = [_Row(**spec) for spec in doc_specs]
        middleware = KnowledgeTreeMiddleware(
            search_space_id=1,
            filesystem_mode=None,  # type: ignore[arg-type]
        )
        return middleware._format_tree(path_index, doc_rows)

    def test_renders_empty_marker_only_for_truly_empty_folders(self):
        # Reproduces the failure scenario from the bug report:
        # ``Boarding Pass`` is empty (its only doc was just deleted) while
        # ``Tax Returns`` still holds ``federal.pdf``. Every intermediate
        # folder is present in the index, mirroring the real DB layout.
        tax_returns = "/documents/File Upload/2026-04-15/Finance/Tax Returns"
        folder_paths = [
            "/documents/File Upload",
            "/documents/File Upload/2026-04-08",
            "/documents/File Upload/2026-04-08/Travel",
            "/documents/File Upload/2026-04-08/Travel/Boarding Pass",
            "/documents/File Upload/2026-04-15",
            "/documents/File Upload/2026-04-15/Finance",
            tax_returns,
        ]
        # ``_render`` numbers folders from 1 in list order.
        tax_returns_folder_id = folder_paths.index(tax_returns) + 1
        federal_pdf = {
            "id": 100,
            "title": "federal.pdf",
            "folder_id": tax_returns_folder_id,
        }

        rendered = self._render(folder_paths=folder_paths, doc_specs=[federal_pdf])

        assert "Boarding Pass/ (empty)" in rendered
        assert "Tax Returns/ (empty)" not in rendered
        # Intermediate ancestors of the doc must NOT be marked empty.
        assert "Finance/ (empty)" not in rendered
        assert "2026-04-15/ (empty)" not in rendered

View file

@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path):
assert write.error is not None
assert "parent directory" in write.error
assert not (tmp_path / "tempoo").exists()
def test_local_backend_delete_file_success(tmp_path: Path):
    """Deleting an existing file reports its virtual path and removes it from disk."""
    backend = LocalFolderBackend(str(tmp_path))
    target = tmp_path / "delete-me.md"
    target.write_text("bye")

    result = backend.delete_file("/delete-me.md")

    assert result.error is None
    assert result.path == "/delete-me.md"
    assert not target.exists()
def test_local_backend_delete_file_rejects_directory(tmp_path: Path):
    """``delete_file`` on a directory returns an error and leaves it in place."""
    backend = LocalFolderBackend(str(tmp_path))
    directory = tmp_path / "subdir"
    directory.mkdir()

    result = backend.delete_file("/subdir")

    assert result.error is not None
    assert "directory" in result.error
    assert directory.exists()
def test_local_backend_delete_file_missing_returns_error(tmp_path: Path):
    """Deleting a path that does not exist yields a "not found" error."""
    result = LocalFolderBackend(str(tmp_path)).delete_file("/nope.md")

    assert result.error is not None
    assert "not found" in result.error
def test_local_backend_rmdir_success(tmp_path: Path):
    """``rmdir`` on an empty directory succeeds and removes it from disk."""
    backend = LocalFolderBackend(str(tmp_path))
    empty_dir = tmp_path / "empty"
    empty_dir.mkdir()

    result = backend.rmdir("/empty")

    assert result.error is None
    assert result.path == "/empty"
    assert not empty_dir.exists()
def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path):
    """``rmdir`` refuses a directory that still has a child and leaves the child intact."""
    backend = LocalFolderBackend(str(tmp_path))
    parent = tmp_path / "withkid"
    parent.mkdir()
    child = parent / "child.md"
    child.write_text("x")

    result = backend.rmdir("/withkid")

    assert result.error is not None
    assert "not empty" in result.error
    assert child.exists()
def test_local_backend_rmdir_rejects_file(tmp_path: Path):
    """``rmdir`` on a regular file reports "not a directory"."""
    backend = LocalFolderBackend(str(tmp_path))
    regular_file = tmp_path / "f.md"
    regular_file.write_text("x")

    result = backend.rmdir("/f.md")

    assert result.error is not None
    assert "not a directory" in result.error
def test_local_backend_rmdir_rejects_root(tmp_path: Path):
    """``rmdir /`` MUST fail.

    The exact wording comes from ``_resolve_virtual`` (root resolves to
    outside the sandbox), so either accepted phrase passes; what matters
    is that an error is returned and the sandbox root on disk is NOT
    deleted.
    """
    backend = LocalFolderBackend(str(tmp_path))

    result = backend.rmdir("/")

    assert result.error is not None
    assert any(marker in result.error for marker in ("Invalid path", "root"))
    assert tmp_path.exists()

View file

@ -0,0 +1,143 @@
"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``.
The regenerate route's edit-from-position path introduces:
* ``_find_pre_turn_checkpoint_id`` takes LangGraph checkpoint tuples
(listed newest-first) and returns the most recent checkpoint that
precedes the edited turn's first checkpoint, never a checkpoint from a
later turn. That checkpoint is the rewind target
(state immediately before the edited turn started).
* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions``
with a validator that prevents callers from requesting a revert pass
without specifying which turn to roll back.
These are pure-Python helpers that don't need a live DB, so we exercise
them with a small ``CheckpointTuple``-shaped namespace and direct
schema instantiation.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id
from app.schemas.new_chat import RegenerateRequest
def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace:
"""Build a fake ``CheckpointTuple`` with the metadata shape we read."""
return SimpleNamespace(
config={"configurable": {"checkpoint_id": checkpoint_id}},
metadata={"turn_id": turn_id} if turn_id is not None else {},
)
class TestFindPreTurnCheckpointId:
    """Rewind-target selection over a newest-first checkpoint history."""

    def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None:
        # Newest-first: T2 is the most recent turn. The latest non-T2
        # checkpoint (cp2) is the rewind target, i.e. the state
        # immediately before T2 began.
        history = [
            _cp(cid, tid)
            for cid, tid in (("cp4", "T2"), ("cp3", "T2"), ("cp2", "T1"), ("cp1", "T1"))
        ]
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") == "cp2"

    def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None:
        # Regression: walking newest-first and returning the FIRST
        # checkpoint with ``turn_id != target`` picks a later-turn
        # checkpoint, NOT the pre-turn boundary. Editing T2 must rewind to
        # the latest T1 checkpoint (cp2), not the latest T3 one (cp6).
        history = [
            _cp(cid, tid)
            for cid, tid in (
                ("cp6", "T3"),
                ("cp5", "T3"),
                ("cp4", "T2"),
                ("cp3", "T2"),
                ("cp2", "T1"),
                ("cp1", "T1"),
            )
        ]
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") == "cp2"

    def test_returns_none_when_editing_first_turn(self) -> None:
        # No pre-turn boundary exists; the caller is expected to fall back
        # to the oldest checkpoint or special-case "first turn of the
        # thread".
        history = [
            _cp(cid, tid)
            for cid, tid in (("cp4", "T2"), ("cp3", "T2"), ("cp2", "T1"), ("cp1", "T1"))
        ]
        assert _find_pre_turn_checkpoint_id(history, turn_id="T1") is None

    def test_returns_none_when_only_edited_turn_present(self) -> None:
        history = [_cp("cp2", "T2"), _cp("cp1", "T2")]
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") is None

    def test_returns_none_for_empty_history(self) -> None:
        assert _find_pre_turn_checkpoint_id([], turn_id="T1") is None

    def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None:
        # Checkpoints written before migration 136 have no
        # ``metadata.turn_id``. They should remain eligible rewind targets
        # because they came before the edited turn began.
        legacy = SimpleNamespace(
            config={"configurable": {"checkpoint_id": "cp2"}},
            metadata=None,
        )
        history = [_cp("cp3", "T2"), legacy, _cp("cp1", "T1")]
        # Walking oldest-first: cp1(T1) tracked, cp2(legacy/None) tracked,
        # then cp3(T2) crosses the boundary -> return cp2.
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") == "cp2"

    def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None:
        # When a tuple's ``config["configurable"]`` lacks the
        # ``checkpoint_id`` key (corrupt / partial), the last known good
        # target is kept instead of crashing.
        corrupt = SimpleNamespace(
            config={"configurable": {}}, metadata={"turn_id": "T1"}
        )
        history = [_cp("cp3", "T2"), corrupt, _cp("cp1", "T1")]
        # cp1(T1) tracked, corrupt entry skipped, cp3(T2) -> return cp1.
        assert _find_pre_turn_checkpoint_id(history, turn_id="T2") == "cp1"
class TestRegenerateRequestValidation:
    """``RegenerateRequest`` rules linking ``revert_actions`` to ``from_message_id``."""

    def test_revert_actions_requires_from_message_id(self) -> None:
        # A revert pass without naming the turn to roll back is rejected,
        # and the error message mentions the missing field.
        with pytest.raises(Exception) as raised:
            RegenerateRequest(
                search_space_id=1,
                user_query="hi",
                revert_actions=True,
            )
        assert "from_message_id" in str(raised.value).lower()

    def test_from_message_id_without_revert_is_allowed(self) -> None:
        request = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
        )
        assert request.from_message_id == 42
        assert request.revert_actions is False

    def test_revert_actions_with_from_message_id_passes(self) -> None:
        request = RegenerateRequest(
            search_space_id=1,
            user_query="hi",
            from_message_id=42,
            revert_actions=True,
        )
        assert request.revert_actions is True

View file

@ -0,0 +1,530 @@
"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
The per-turn batch revert route walks rows in reverse ``created_at``
order, reverts each independently, and returns a per-action result
list. Partial success is normal: the response status
is ``"partial"`` whenever any row could not be reverted, but we never
collapse the whole batch into a 4xx.
These tests stub ``load_thread`` / ``revert_action`` and feed a fake
session, so they exercise the route's dispatch logic without a real DB.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.routes import agent_revert_route
from app.services.revert_service import RevertOutcome
@dataclass
class _FakeAction:
id: int
tool_name: str
user_id: str | None = "u1"
reverse_of: int | None = None
error: dict | None = None
@dataclass
class _FakeUser:
id: str = "u1"
@dataclass
class _ScalarResult:
rows: list[Any]
def first(self) -> Any:
return self.rows[0] if self.rows else None
def all(self) -> list[Any]:
return list(self.rows)
@dataclass
class _Result:
rows: list[Any] = field(default_factory=list)
def scalars(self) -> _ScalarResult:
return _ScalarResult(self.rows)
def all(self) -> list[Any]:
# ``_was_already_reverted_batch`` calls ``.all()`` directly on
# the row-tuple result (no ``.scalars()`` indirection). The
# rows queued for that helper are list[(revert_id, original_id)].
return list(self.rows)
class _FakeNestedCtx:
"""Async context manager that mimics ``session.begin_nested()``.
The route raises a sentinel exception inside this block to roll back
bad rows. We just pass the exception through.
"""
async def __aenter__(self) -> _FakeNestedCtx:
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
# Returning False (or None) propagates the exception; the route
# catches its own sentinel above this layer.
return False
class _FakeSession:
"""Minimal AsyncSession stand-in for the revert-turn route.
Holds a queue of result objects; each ``execute(...)`` pops the next
one. The route calls ``execute`` exactly once per query so this maps
cleanly onto the assertion order of the test.
"""
def __init__(self) -> None:
self._results: list[_Result] = []
self.committed = False
self.rolled_back = False
# Count execute() calls to assert "no N+1 reverts".
self.execute_call_count = 0
def queue(self, *results: _Result) -> None:
self._results.extend(results)
async def execute(self, _stmt: Any) -> _Result:
self.execute_call_count += 1
if not self._results:
return _Result(rows=[])
return self._results.pop(0)
def begin_nested(self) -> _FakeNestedCtx:
return _FakeNestedCtx()
async def commit(self) -> None:
self.committed = True
async def rollback(self) -> None:
self.rolled_back = True
def _enabled_flags() -> AgentFeatureFlags:
    """Feature flags with the action log and the revert route switched on."""
    flag_values = {
        "disable_new_agent_stack": False,
        "enable_action_log": True,
        "enable_revert_route": True,
    }
    return AgentFeatureFlags(**flag_values)
@pytest.fixture
def patch_get_flags():
    """Factory fixture: ``patch_get_flags(flags)`` builds a patcher for the route's ``get_flags``.

    The returned patcher makes ``get_flags`` return ``flags`` while active.
    """
    target = "app.routes.agent_revert_route.get_flags"
    return lambda flags: patch(target, return_value=flags)
class TestFlagGuard:
    """The route refuses to run while ``enable_revert_route`` is off."""

    @pytest.mark.asyncio
    async def test_returns_503_when_revert_route_disabled(
        self, patch_get_flags
    ) -> None:
        disabled = AgentFeatureFlags(
            disable_new_agent_stack=False,
            enable_action_log=True,
            enable_revert_route=False,
        )
        with patch_get_flags(disabled), pytest.raises(Exception) as raised:
            await agent_revert_route.revert_agent_turn(
                thread_id=1,
                chat_turn_id="42:1700000000000",
                session=_FakeSession(),
                user=_FakeUser(),
            )
        # The raised error must carry HTTP status 503.
        assert getattr(raised.value, "status_code", None) == 503
class TestRevertTurnDispatch:
    """Dispatch behaviour of ``revert_agent_turn`` against a fake session."""

    @staticmethod
    def _stub_load_thread():
        """Patcher that replaces the route's ``load_thread`` with an async stub."""
        return patch.object(
            agent_revert_route, "load_thread", AsyncMock(return_value=object())
        )

    @staticmethod
    def _stub_revert_action(side_effect=None):
        """Patcher that replaces ``revert_action`` with an ``AsyncMock``."""
        return patch.object(
            agent_revert_route, "revert_action", AsyncMock(side_effect=side_effect)
        )

    @staticmethod
    async def _revert_turn(session, chat_turn_id, user=None):
        """Call the route for thread 1; ``user`` defaults to ``_FakeUser()``."""
        if user is None:
            user = _FakeUser()
        return await agent_revert_route.revert_agent_turn(
            thread_id=1,
            chat_turn_id=chat_turn_id,
            session=session,
            user=user,
        )

    @pytest.mark.asyncio
    async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None:
        db = _FakeSession()
        db.queue(_Result(rows=[]))  # the rows query finds nothing for this turn

        with patch_get_flags(_enabled_flags()), self._stub_load_thread():
            response = await self._revert_turn(db, "ct-empty")

        assert response.status == "ok"
        assert response.total == 0
        assert response.results == []
        assert db.committed is True

    @pytest.mark.asyncio
    async def test_walks_rows_in_reverse_and_reverts_each(
        self, patch_get_flags
    ) -> None:
        turn_rows = [
            _FakeAction(id=10, tool_name="rm"),
            _FakeAction(id=9, tool_name="write_file"),
            _FakeAction(id=8, tool_name="mkdir"),
        ]
        db = _FakeSession()
        # Rows query first, then the single batched
        # ``_was_already_reverted_batch`` probe that replaced the previous
        # N per-row SELECTs.
        db.queue(_Result(rows=turn_rows), _Result(rows=[]))

        async def _always_ok(_session, *, action, requester_user_id):
            return RevertOutcome(
                status="ok",
                message=f"reverted-{action.id}",
                new_action_id=100 + action.id,
            )

        with (
            patch_get_flags(_enabled_flags()),
            self._stub_load_thread(),
            self._stub_revert_action(_always_ok),
        ):
            response = await self._revert_turn(db, "ct-3")

        assert response.status == "ok"
        assert response.total == 3
        assert response.reverted == 3
        assert [r.action_id for r in response.results] == [10, 9, 8]
        assert all(r.status == "reverted" for r in response.results)
        assert response.results[0].new_action_id == 110
        # Exactly TWO ``execute`` calls whatever the row count: the rows
        # query plus the batched ``_was_already_reverted_batch`` probe.
        # Regression guard against re-introducing the per-row N+1 lookup.
        assert db.execute_call_count == 2, (
            "revert-turn loop must batch idempotency probes; got "
            f"{db.execute_call_count} execute() calls (expected 2)."
        )

    @pytest.mark.asyncio
    async def test_already_reverted_rows_are_marked_idempotent(
        self, patch_get_flags
    ) -> None:
        db = _FakeSession()
        # The batch probe answers with ``[(revert_id, original_id)]``.
        db.queue(
            _Result(rows=[_FakeAction(id=5, tool_name="edit_file")]),
            _Result(rows=[(42, 5)]),
        )

        with (
            patch_get_flags(_enabled_flags()),
            self._stub_load_thread(),
            self._stub_revert_action() as revert,
        ):
            response = await self._revert_turn(db, "ct-i")

        assert response.status == "ok"
        assert response.already_reverted == 1
        first = response.results[0]
        assert first.status == "already_reverted"
        assert first.new_action_id == 42
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_revert_action_skips_existing_revert_rows(
        self, patch_get_flags
    ) -> None:
        db = _FakeSession()
        db.queue(
            _Result(
                rows=[_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)]
            )
        )

        with (
            patch_get_flags(_enabled_flags()),
            self._stub_load_thread(),
            self._stub_revert_action() as revert,
        ):
            response = await self._revert_turn(db, "ct-rev")

        assert response.status == "ok"
        assert response.results[0].status == "skipped"
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_partial_success_when_some_rows_not_reversible(
        self, patch_get_flags
    ) -> None:
        turn_rows = [
            _FakeAction(id=2, tool_name="send_email"),
            _FakeAction(id=1, tool_name="edit_file"),
        ]
        db = _FakeSession()
        # Rows query, then the single batched idempotency probe.
        db.queue(_Result(rows=turn_rows), _Result(rows=[]))

        async def _email_not_reversible(_session, *, action, requester_user_id):
            if action.tool_name != "send_email":
                return RevertOutcome(status="ok", message="ok", new_action_id=500)
            return RevertOutcome(
                status="not_reversible",
                message="connector revert not yet implemented",
            )

        with (
            patch_get_flags(_enabled_flags()),
            self._stub_load_thread(),
            self._stub_revert_action(_email_not_reversible),
        ):
            response = await self._revert_turn(db, "ct-mix")

        assert response.status == "partial"
        assert response.reverted == 1
        assert response.not_reversible == 1
        observed = sorted(r.status for r in response.results)
        assert observed == ["not_reversible", "reverted"]

    @pytest.mark.asyncio
    async def test_unexpected_exception_marks_row_failed_not_batch(
        self, patch_get_flags
    ) -> None:
        turn_rows = [
            _FakeAction(id=20, tool_name="edit_file"),
            _FakeAction(id=21, tool_name="edit_file"),
        ]
        db = _FakeSession()
        # Rows query, then the single batched idempotency probe.
        db.queue(_Result(rows=turn_rows), _Result(rows=[]))

        async def _row_20_explodes(_session, *, action, requester_user_id):
            if action.id == 20:
                raise RuntimeError("disk on fire")
            return RevertOutcome(status="ok", message="ok", new_action_id=999)

        with (
            patch_get_flags(_enabled_flags()),
            self._stub_load_thread(),
            self._stub_revert_action(_row_20_explodes),
        ):
            response = await self._revert_turn(db, "ct-fail")

        assert response.status == "partial"
        assert response.failed == 1
        assert response.reverted == 1
        by_id = {20: None, 21: None}
        for entry in reversed(response.results):
            if entry.action_id in by_id:
                by_id[entry.action_id] = entry
        bad = next(r for r in response.results if r.action_id == 20)
        assert by_id[20] is bad
        assert bad.status == "failed"
        assert "disk on fire" in (bad.error or "")
        good = next(r for r in response.results if r.action_id == 21)
        assert good.status == "reverted"

    @pytest.mark.asyncio
    async def test_permission_denied_when_other_user_owns_action(
        self, patch_get_flags
    ) -> None:
        db = _FakeSession()
        # Rows query, then the batch idempotency probe (no prior reverts).
        db.queue(
            _Result(
                rows=[_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")]
            ),
            _Result(rows=[]),
        )

        with (
            patch_get_flags(_enabled_flags()),
            self._stub_load_thread(),
            self._stub_revert_action() as revert,
        ):
            response = await self._revert_turn(
                db, "ct-perm", user=_FakeUser(id="not-owner")
            )

        assert response.status == "partial"
        assert response.results[0].status == "permission_denied"
        # ``permission_denied`` has a dedicated counter so the response
        # invariant ``total == sum(counters)`` always holds without
        # overloading ``not_reversible`` (which historically absorbed this
        # case and confused frontend toasts).
        assert response.permission_denied == 1
        assert response.not_reversible == 0
        revert.assert_not_called()

    @pytest.mark.asyncio
    async def test_counter_invariant_holds_across_mixed_outcomes(
        self, patch_get_flags
    ) -> None:
        """Every row is accounted for in EXACTLY ONE counter.

        Mixes one of every supported outcome (reverted, already_reverted,
        not_reversible, permission_denied, failed, skipped) and asserts
        that the sum of counters equals ``response.total``.
        """
        turn_rows = [
            _FakeAction(id=10, tool_name="edit_file"),  # ok
            _FakeAction(id=9, tool_name="edit_file"),  # already_reverted
            _FakeAction(id=8, tool_name="send_email"),  # not_reversible
            _FakeAction(id=7, tool_name="rm", user_id="other"),  # permission_denied
            _FakeAction(id=6, tool_name="edit_file"),  # failed
            _FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99),  # skipped
        ]
        db = _FakeSession()
        # Rows query, then the single batched probe; only id=9 has a prior
        # revert. Probe schema: list[(revert_id, original_id)].
        db.queue(_Result(rows=turn_rows), _Result(rows=[(42, 9)]))

        async def _scripted_revert(_session, *, action, requester_user_id):
            if action.id == 10:
                return RevertOutcome(status="ok", message="ok", new_action_id=500)
            if action.id == 8:
                return RevertOutcome(
                    status="not_reversible",
                    message="connector revert not yet implemented",
                )
            if action.id == 6:
                raise RuntimeError("boom")
            raise AssertionError(f"unexpected revert call for {action.id}")

        with (
            patch_get_flags(_enabled_flags()),
            self._stub_load_thread(),
            self._stub_revert_action(_scripted_revert),
        ):
            # Default ``_FakeUser``: only id=7 belongs to a different user.
            response = await self._revert_turn(db, "ct-mixed-all")

        assert response.total == len(turn_rows) == 6
        bucket_sum = sum(
            (
                response.reverted,
                response.already_reverted,
                response.not_reversible,
                response.permission_denied,
                response.failed,
                response.skipped,
            )
        )
        assert bucket_sum == response.total, (
            "Counter invariant broken: total "
            f"({response.total}) != sum of counters ({bucket_sum}). "
            f"Counters: reverted={response.reverted}, "
            f"already_reverted={response.already_reverted}, "
            f"not_reversible={response.not_reversible}, "
            f"permission_denied={response.permission_denied}, "
            f"failed={response.failed}, skipped={response.skipped}"
        )
        assert response.reverted == 1
        assert response.already_reverted == 1
        assert response.not_reversible == 1
        assert response.permission_denied == 1
        assert response.failed == 1
        assert response.skipped == 1

    @pytest.mark.asyncio
    async def test_integrity_error_translates_to_already_reverted(
        self, patch_get_flags
    ) -> None:
        """The partial unique index on ``reverse_of`` raises
        ``IntegrityError`` when a concurrent revert wins the race against
        the pre-flight ``_was_already_reverted`` SELECT. The route MUST
        recover by re-querying for the winning revert id and returning
        ``status="already_reverted"`` (not ``"failed"``) so racing
        clients see consistent idempotent semantics.
        """
        from sqlalchemy.exc import IntegrityError

        db = _FakeSession()
        db.queue(
            _Result(rows=[_FakeAction(id=33, tool_name="edit_file")]),
            # Batch pre-flight probe: nothing yet (the race is still open).
            _Result(rows=[]),
            # Post-IntegrityError fallback uses the SCALAR
            # ``_was_already_reverted`` (single-id lookup), so it reads
            # ``[777]`` via ``.scalars().first()``.
            _Result(rows=[777]),
        )

        async def _lose_race(_session, *, action, requester_user_id):
            raise IntegrityError("INSERT", {}, Exception("dup reverse_of"))

        with (
            patch_get_flags(_enabled_flags()),
            self._stub_load_thread(),
            self._stub_revert_action(_lose_race),
        ):
            response = await self._revert_turn(db, "ct-race")

        assert response.failed == 0, (
            "IntegrityError must NOT surface as a failed row; the unique "
            "index is the durable expression of idempotency."
        )
        assert response.already_reverted == 1
        first = response.results[0]
        assert first.status == "already_reverted"
        assert first.new_action_id == 777

View file

@ -0,0 +1,921 @@
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
import pytest
from app.services.auto_model_pin_service import (
clear_healthy,
clear_runtime_cooldown,
is_recently_healthy,
mark_healthy,
mark_runtime_cooldown,
resolve_or_get_pinned_llm_config_id,
)
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _clear_runtime_cooldown_map():
    """Reset the cooldown and healthy state before and after every test."""

    def _reset() -> None:
        clear_runtime_cooldown()
        clear_healthy()

    _reset()
    yield
    _reset()
@dataclass
class _FakeQuotaResult:
allowed: bool
class _FakeExecResult:
def __init__(self, thread):
self._thread = thread
def unique(self):
return self
def scalar_one_or_none(self):
return self._thread
class _FakeSession:
def __init__(self, thread):
self.thread = thread
self.commit_count = 0
async def execute(self, _stmt):
return _FakeExecResult(self.thread)
async def commit(self):
self.commit_count += 1
def _thread(
*,
search_space_id: int = 10,
pinned_llm_config_id: int | None = None,
):
return SimpleNamespace(
id=1,
search_space_id=search_space_id,
pinned_llm_config_id=pinned_llm_config_id,
)
@pytest.mark.asyncio
async def test_auto_first_turn_pins_one_model(monkeypatch):
    """First auto turn on an unpinned thread pins one configured model and commits once."""
    from app.config import config

    free_cfg = {"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"}
    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg, premium_cfg])

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    fake_session = _FakeSession(_thread())
    outcome = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    # Either configured model is acceptable; the pick must be persisted.
    assert outcome.resolved_llm_config_id in {-1, -2}
    assert fake_session.thread.pinned_llm_config_id == outcome.resolved_llm_config_id
    assert fake_session.commit_count == 1
@pytest.mark.asyncio
async def test_next_turn_reuses_existing_pin(monkeypatch):
    """A thread with a valid pin reuses it: no quota lookup and no commit."""
    from app.config import config

    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [premium_cfg])

    async def _forbid_quota_lookup(*_args, **_kwargs):
        # Reaching the quota service means the existing pin was not reused.
        raise AssertionError(
            "premium_get_usage should not be called for valid pin reuse"
        )

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _forbid_quota_lookup,
    )

    fake_session = _FakeSession(_thread(pinned_llm_config_id=-1))
    outcome = await resolve_or_get_pinned_llm_config_id(
        fake_session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -1
    assert outcome.from_existing_pin is True
    assert fake_session.commit_count == 0
@pytest.mark.asyncio
async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
    """With the quota stub reporting ``allowed=True`` and only a premium
    config available, resolution lands on that premium config."""
    from app.config import config

    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [premium_cfg])

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    session = _FakeSession(_thread())
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -1
    assert outcome.resolved_tier == "premium"
@pytest.mark.asyncio
async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
    """With the quota stub reporting ``allowed=False``, the free config is
    chosen even though a premium config is also listed."""
    from app.config import config

    free_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
        "billing_tier": "free",
    }
    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg, premium_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    session = _FakeSession(_thread())
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -2
    assert outcome.resolved_tier == "free"
@pytest.mark.asyncio
async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
    """A thread pinned to the premium config keeps it even when the quota
    stub would answer ``allowed=False``."""
    from app.config import config

    free_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
        "billing_tier": "free",
    }
    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg, premium_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -1
    assert outcome.from_existing_pin is True
@pytest.mark.asyncio
async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
    """Passing ``force_repin_free=True`` moves a premium-pinned thread onto
    the free config and records the new pin on the thread."""
    from app.config import config

    free_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
        "billing_tier": "free",
    }
    premium_cfg = {
        "id": -1,
        "provider": "OPENAI",
        "model_name": "gpt-prem",
        "api_key": "k2",
        "billing_tier": "premium",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [free_cfg, premium_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
        force_repin_free=True,
    )

    assert outcome.resolved_llm_config_id == -2
    assert outcome.resolved_tier == "free"
    assert outcome.from_existing_pin is False
    assert session.thread.pinned_llm_config_id == -2
@pytest.mark.asyncio
async def test_explicit_user_model_change_clears_pin(monkeypatch):
    """Selecting a concrete config id (7) on a pinned thread resolves to that
    id, clears the stored pin, and commits once."""
    from app.config import config

    only_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [only_cfg])

    session = _FakeSession(_thread(pinned_llm_config_id=-2))
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=7,
    )

    assert outcome.resolved_llm_config_id == 7
    assert session.thread.pinned_llm_config_id is None
    assert session.commit_count == 1
@pytest.mark.asyncio
async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
    """A pin pointing at an id absent from the global list (-999) is replaced
    by the only listed config, with a single commit."""
    from app.config import config

    only_cfg = {
        "id": -2,
        "provider": "OPENAI",
        "model_name": "gpt-free",
        "api_key": "k1",
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [only_cfg])

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    session = _FakeSession(_thread(pinned_llm_config_id=-999))
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -2
    assert session.thread.pinned_llm_config_id == -2
    assert session.commit_count == 1
# ---------------------------------------------------------------------------
# Quality-aware pin selection (Auto Fastest upgrade)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
    """A config marked ``health_gated`` is skipped even though its
    ``quality_score`` (95) beats the other candidate's (60)."""
    from app.config import config

    gated_cfg = {
        "id": -1,
        "provider": "OPENROUTER",
        "model_name": "venice/dead-model",
        "api_key": "k1",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 95,
        "health_gated": True,
    }
    healthy_cfg = {
        "id": -2,
        "provider": "OPENROUTER",
        "model_name": "google/gemini-flash",
        "api_key": "k1",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 60,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [gated_cfg, healthy_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    session = _FakeSession(_thread())
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -2
@pytest.mark.asyncio
async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
    """When the quota stub allows premium and a Tier A config exists, the
    Tier A config wins over a Tier B config with a higher ``quality_score``."""
    from app.config import config

    tier_a_cfg = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "api_key": "k-yaml",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 70,
        "health_gated": False,
    }
    tier_b_cfg = {
        "id": -2,
        "provider": "OPENROUTER",
        "model_name": "openai/gpt-5",
        "api_key": "k-or",
        "billing_tier": "premium",
        "auto_pin_tier": "B",
        "quality_score": 95,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [tier_a_cfg, tier_b_cfg])

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    session = _FakeSession(_thread())
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -1
    assert outcome.resolved_tier == "premium"
@pytest.mark.asyncio
async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch):
    """When the quota stub denies premium and the only Tier A config is
    premium, selection falls through to the free Tier C config."""
    from app.config import config

    premium_tier_a = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "api_key": "k-yaml",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 100,
        "health_gated": False,
    }
    free_tier_c = {
        "id": -2,
        "provider": "OPENROUTER",
        "model_name": "google/gemini-flash:free",
        "api_key": "k-or",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 60,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [premium_tier_a, free_tier_c])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    session = _FakeSession(_thread())
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -2
@pytest.mark.asyncio
async def test_top_k_picks_only_high_score_models(monkeypatch):
    """Across thread ids 1..49 the pick always comes from the five score-90
    configs, never the score-10 one, and more than one distinct config is used."""
    from app.config import config

    def _premium_tier_a(cfg_id, model_name, score):
        # All candidates share provider/tier; only id, name and score vary.
        return {
            "id": cfg_id,
            "provider": "AZURE_OPENAI",
            "model_name": model_name,
            "api_key": "k",
            "billing_tier": "premium",
            "auto_pin_tier": "A",
            "quality_score": score,
            "health_gated": False,
        }

    # Five high-quality Tier A cfgs with ids -1 .. -5.
    strong_cfgs = [_premium_tier_a(-n, f"gpt-x-{n}", 90) for n in range(1, 6)]
    weak_trap = _premium_tier_a(-99, "tiny-legacy", 10)
    monkeypatch.setattr(
        config,
        "GLOBAL_LLM_CONFIGS",
        [*strong_cfgs, weak_trap],
    )

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    strong_ids = {cfg["id"] for cfg in strong_cfgs}
    picked_ids = set()
    for thread_id in range(1, 50):
        # Fresh, unpinned session per thread so every iteration is a first turn.
        session = _FakeSession(_thread())
        outcome = await resolve_or_get_pinned_llm_config_id(
            session,
            thread_id=thread_id,
            search_space_id=10,
            user_id="00000000-0000-0000-0000-000000000001",
            selected_llm_config_id=0,
        )
        picked_ids.add(outcome.resolved_llm_config_id)
        assert outcome.resolved_llm_config_id != -99, (
            "low-score trap cfg should never be picked"
        )
        assert outcome.resolved_llm_config_id in strong_ids

    # Spread across at least a couple of top-K cfgs.
    assert len(picked_ids) > 1
@pytest.mark.asyncio
async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
    """A pinned cfg that is now ``health_gated`` must not be reused.

    The thread starts pinned to the gated config (-1); resolution is expected
    to move to the healthy config (-2) and report ``from_existing_pin`` as
    False, i.e. a known-broken model is not kept merely because the thread
    was pinned to it before the gate fired.
    """
    from app.config import config

    gated_pinned_cfg = {
        "id": -1,
        "provider": "OPENROUTER",
        "model_name": "venice/dead-model",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "B",
        "quality_score": 50,
        "health_gated": True,
    }
    healthy_cfg = {
        "id": -2,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 90,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [gated_pinned_cfg, healthy_cfg])

    async def _quota_ok(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=True)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_ok,
    )

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -2
    assert outcome.from_existing_pin is False
@pytest.mark.asyncio
async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
    """A healthy existing pin is reused as-is, even when another listed
    config has a higher score; no quota lookup and no commit take place."""
    from app.config import config

    pinned_cfg = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 50,  # lower than -2
        "health_gated": False,
    }
    higher_scored_cfg = {
        "id": -2,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5-pro",
        "api_key": "k",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score": 99,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [pinned_cfg, higher_scored_cfg])

    async def _quota_lookup_forbidden(*_args, **_kwargs):
        raise AssertionError("premium_get_usage should not run on pin reuse")

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_lookup_forbidden,
    )

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -1
    assert outcome.from_existing_pin is True
    assert session.commit_count == 0
@pytest.mark.asyncio
async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
    """A pinned config placed under runtime cooldown is not reused.

    After ``mark_runtime_cooldown(-1, ...)`` the next resolution is expected
    to repair the pin to the other eligible config (-2), which is how a
    transient provider rate-limit burst on the pinned model is recovered from.
    """
    from app.config import config

    cooled_cfg = {
        "id": -1,
        "provider": "OPENROUTER",
        "model_name": "google/gemma-4-26b-a4b-it:free",
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 90,
        "health_gated": False,
    }
    fallback_cfg = {
        "id": -2,
        "provider": "OPENROUTER",
        "model_name": "google/gemini-2.5-flash:free",
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 80,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [cooled_cfg, fallback_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -2
    assert outcome.from_existing_pin is False
@pytest.mark.asyncio
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
    """Marking a config as cooled down and then clearing that entry leaves
    the existing pin reusable without any quota lookup."""
    from app.config import config

    pinned_cfg = {
        "id": -1,
        "provider": "OPENROUTER",
        "model_name": "google/gemma-4-26b-a4b-it:free",
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 90,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [pinned_cfg])

    async def _quota_lookup_forbidden(*_args, **_kwargs):
        raise AssertionError("premium_get_usage should not run on healthy pin reuse")

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_lookup_forbidden,
    )

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    # Cool the config down, then immediately lift the cooldown again.
    mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
    clear_runtime_cooldown(-1)

    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
    )

    assert outcome.resolved_llm_config_id == -1
    assert outcome.from_existing_pin is True
@pytest.mark.asyncio
async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch):
    """Passing ``exclude_config_ids={-1}`` keeps the currently pinned config
    out of the result; the other candidate (-2) is chosen instead."""
    from app.config import config

    failed_cfg = {
        "id": -1,
        "provider": "OPENROUTER",
        "model_name": "google/gemma-4-26b-a4b-it:free",
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 90,
        "health_gated": False,
    }
    alternative_cfg = {
        "id": -2,
        "provider": "OPENROUTER",
        "model_name": "google/gemini-2.5-flash:free",
        "api_key": "k",
        "billing_tier": "free",
        "auto_pin_tier": "C",
        "quality_score": 80,
        "health_gated": False,
    }
    monkeypatch.setattr(config, "GLOBAL_LLM_CONFIGS", [failed_cfg, alternative_cfg])

    async def _quota_denied(*_args, **_kwargs):
        return _FakeQuotaResult(allowed=False)

    monkeypatch.setattr(
        "app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
        _quota_denied,
    )

    session = _FakeSession(_thread(pinned_llm_config_id=-1))
    outcome = await resolve_or_get_pinned_llm_config_id(
        session,
        thread_id=1,
        search_space_id=10,
        user_id="00000000-0000-0000-0000-000000000001",
        selected_llm_config_id=0,
        exclude_config_ids={-1},
    )

    assert outcome.resolved_llm_config_id == -2
    assert outcome.from_existing_pin is False
# ---------------------------------------------------------------------------
# Healthy-status cache (preflight TTL companion)
# ---------------------------------------------------------------------------
def test_mark_healthy_then_is_recently_healthy_true_within_ttl():
    """An id marked healthy with a 60-second TTL reads back as healthy right away."""
    cfg_id = -42
    mark_healthy(cfg_id, ttl_seconds=60)
    assert is_recently_healthy(cfg_id) is True
def test_healthy_expires_after_ttl(monkeypatch):
    """A healthy mark with a 10-second TTL is gone once the clock reads 11 s later."""
    import app.services.auto_model_pin_service as svc

    # Freeze the clock the service module sees, then advance it by hand.
    start = svc.time.time()
    clock = {"now": start}
    monkeypatch.setattr(svc.time, "time", lambda: clock["now"])

    mark_healthy(-7, ttl_seconds=10)
    assert is_recently_healthy(-7) is True

    clock["now"] = start + 11
    assert is_recently_healthy(-7) is False
def test_mark_runtime_cooldown_invalidates_healthy_cache():
    """Putting a config under runtime cooldown drops its healthy mark."""
    cfg_id = -9
    mark_healthy(cfg_id, ttl_seconds=60)
    assert is_recently_healthy(cfg_id) is True

    mark_runtime_cooldown(cfg_id, reason="test", cooldown_seconds=60)
    assert is_recently_healthy(cfg_id) is False
def test_clear_healthy_removes_single_entry():
    """``clear_healthy(id)`` removes that id only; other marks stay in place."""
    cleared_id, kept_id = -11, -12
    for cfg_id in (cleared_id, kept_id):
        mark_healthy(cfg_id, ttl_seconds=60)

    clear_healthy(cleared_id)

    assert is_recently_healthy(cleared_id) is False
    assert is_recently_healthy(kept_id) is True
def test_clear_healthy_no_args_drops_all_entries():
    """Calling ``clear_healthy()`` with no argument removes every healthy mark."""
    marked_ids = (-21, -22)
    for cfg_id in marked_ids:
        mark_healthy(cfg_id, ttl_seconds=60)

    clear_healthy()

    assert is_recently_healthy(-21) is False
    assert is_recently_healthy(-22) is False

View file

@ -0,0 +1,226 @@
"""LLMRouterService pool-filter / rebuild tests.
These tests focus on the *config plumbing* (which configs enter the router
pool, rebuild resets state correctly). They stub out the underlying
``litellm.Router`` so we don't need real API keys or network access.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
from app.services.llm_router_service import LLMRouterService
pytestmark = pytest.mark.unit
def _fake_yaml_config(
*,
id: int,
model_name: str,
billing_tier: str = "free",
) -> dict:
return {
"id": id,
"name": f"yaml-{id}",
"provider": "OPENAI",
"model_name": model_name,
"api_key": "sk-test",
"api_base": "",
"billing_tier": billing_tier,
"rpm": 100,
"tpm": 100_000,
"litellm_params": {},
}
def _fake_openrouter_config(
*,
id: int,
model_name: str,
billing_tier: str,
router_pool_eligible: bool | None = None,
) -> dict:
"""Build a synthetic dynamic-OR config dict for router-pool tests.
Defaults mirror Strategy 3: premium OR enters the pool, free OR stays
out. Callers can override ``router_pool_eligible`` to simulate legacy
configs or to regression-test the filter mechanics directly.
"""
if router_pool_eligible is None:
router_pool_eligible = billing_tier == "premium"
return {
"id": id,
"name": f"or-{id}",
"provider": "OPENROUTER",
"model_name": model_name,
"api_key": "sk-or-test",
"api_base": "",
"billing_tier": billing_tier,
"rpm": 20 if billing_tier == "free" else 200,
"tpm": 100_000 if billing_tier == "free" else 1_000_000,
"litellm_params": {},
"router_pool_eligible": router_pool_eligible,
}
def _reset_router_singleton() -> None:
    """Return the shared ``LLMRouterService`` instance to an uninitialised state.

    Clears the initialised flag, the router object, the model list and the
    premium-model fingerprint set so the next ``initialize`` call starts fresh.
    """
    service = LLMRouterService.get_instance()
    # Fresh list/set objects are created on every call on purpose.
    blank_state = {
        "_initialized": False,
        "_router": None,
        "_model_list": [],
        "_premium_model_strings": set(),
    }
    for attr_name, blank_value in blank_state.items():
        setattr(service, attr_name, blank_value)
def test_router_pool_includes_or_premium_excludes_or_free():
    """Premium OpenRouter configs join the router pool; free ones do not.

    Both YAML-style configs and the premium dynamic OpenRouter config are
    expected in the pool after ``initialize``. The free dynamic OpenRouter
    config is expected to be left out; per the original rationale, the
    OpenRouter free tier uses one account-wide quota bucket that
    per-deployment router accounting cannot represent.
    """
    _reset_router_singleton()

    yaml_premium = _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium")
    yaml_free = _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free")
    or_premium = _fake_openrouter_config(
        id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
    )
    or_free = _fake_openrouter_config(
        id=-10_002,
        model_name="meta-llama/llama-3.3-70b:free",
        billing_tier="free",
    )
    all_configs = [yaml_premium, yaml_free, or_premium, or_free]

    with (
        patch("app.services.llm_router_service.Router") as router_cls,
        patch(
            "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
        ) as ctx_fallback,
    ):
        # Pass the model list through untouched and skip real router construction.
        ctx_fallback.side_effect = lambda ml: (ml, None)
        router_cls.return_value = object()
        LLMRouterService.initialize(all_configs)

    service = LLMRouterService.get_instance()
    pooled_models = {dep["litellm_params"]["model"] for dep in service._model_list}

    # YAML premium + YAML free + dynamic OR premium are all in the pool.
    # Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced).
    assert pooled_models == {
        "openai/gpt-4o",
        "openai/gpt-4o-mini",
        "openrouter/openai/gpt-4o",
    }

    premium_strings = service._premium_model_strings
    # The YAML premium config must be fingerprinted as premium (existing
    # behaviour that must not regress).
    assert "openai/gpt-4o" in premium_strings
    # The dynamic OR premium config is fingerprinted as premium too, so
    # pool-level calls through the router count against premium quota.
    assert "openrouter/openai/gpt-4o" in premium_strings
    assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True
    # Dynamic OR free never enters the pool, so it's never counted as premium.
    assert (
        LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free")
        is False
    )
def test_router_pool_filter_mechanics_respect_override():
    """An explicit ``router_pool_eligible=False`` wins over the billing tier.

    Regression guard: a premium config that opts out of the pool (for
    example during maintenance) must be skipped by ``initialize`` and must
    not be reported as a premium pool model.
    """
    _reset_router_singleton()

    yaml_premium = _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium")
    opted_out_premium = _fake_openrouter_config(
        id=-10_001,
        model_name="openai/gpt-4o",
        billing_tier="premium",
        router_pool_eligible=False,  # opt out despite being premium
    )

    with (
        patch("app.services.llm_router_service.Router") as router_cls,
        patch(
            "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
        ) as ctx_fallback,
    ):
        ctx_fallback.side_effect = lambda ml: (ml, None)
        router_cls.return_value = object()
        LLMRouterService.initialize([yaml_premium, opted_out_premium])

    pooled_models = {
        dep["litellm_params"]["model"]
        for dep in LLMRouterService.get_instance()._model_list
    }
    assert pooled_models == {"openai/gpt-4o"}
    assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False
def test_rebuild_refreshes_pool_after_configs_change():
    """``initialize`` ignores a second call, while ``rebuild`` picks up new configs."""
    _reset_router_singleton()

    first_cfg = _fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium")
    added_cfg = _fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free")
    configs_v1 = [first_cfg]
    configs_v2 = [*configs_v1, added_cfg]

    def _pool_size() -> int:
        return len(LLMRouterService.get_instance()._model_list)

    with (
        patch("app.services.llm_router_service.Router") as router_cls,
        patch(
            "app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
        ) as ctx_fallback,
    ):
        ctx_fallback.side_effect = lambda ml: (ml, None)
        router_cls.return_value = object()

        LLMRouterService.initialize(configs_v1)
        assert _pool_size() == 1

        # ``initialize`` should be a no-op here (already initialized).
        LLMRouterService.initialize(configs_v2)
        assert _pool_size() == 1

        # ``rebuild`` must clear the guard and re-run with the new configs.
        LLMRouterService.rebuild(configs_v2)
        assert _pool_size() == 2
def test_auto_model_pin_candidates_include_dynamic_openrouter():
    """Both premium and free dynamic OpenRouter configs stay pin candidates.

    Protects against someone applying the ``router_pool_eligible`` filter
    inside ``auto_model_pin_service._global_candidates``: the free config has
    that flag set to False here and must still be returned.
    """
    from app.config import config
    from app.services.auto_model_pin_service import _global_candidates

    premium_or = _fake_openrouter_config(
        id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
    )
    free_or = _fake_openrouter_config(
        id=-10_002,
        model_name="meta-llama/llama-3.3-70b:free",
        billing_tier="free",
    )

    saved_configs = config.GLOBAL_LLM_CONFIGS
    try:
        config.GLOBAL_LLM_CONFIGS = [premium_or, free_or]
        found_ids = set()
        for candidate in _global_candidates():
            found_ids.add(candidate["id"])
        assert found_ids == {-10_001, -10_002}
    finally:
        # Always put the real global config list back.
        config.GLOBAL_LLM_CONFIGS = saved_configs

View file

@ -0,0 +1,216 @@
"""Unit tests for the dynamic OpenRouter integration."""
from __future__ import annotations
import pytest
from app.services.openrouter_integration_service import (
_OPENROUTER_DYNAMIC_MARKER,
_generate_configs,
_openrouter_tier,
_stable_config_id,
)
pytestmark = pytest.mark.unit
def _minimal_openrouter_model(
*,
model_id: str,
pricing: dict | None = None,
name: str | None = None,
) -> dict:
"""Return a synthetic OpenRouter /api/v1/models entry.
The real API payload includes a lot of fields; we only populate what
``_generate_configs`` actually inspects (architecture, tool support,
context, pricing, id).
"""
return {
"id": model_id,
"name": name or model_id,
"architecture": {"output_modalities": ["text"]},
"supported_parameters": ["tools"],
"context_length": 200_000,
"pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"},
}
# ---------------------------------------------------------------------------
# _openrouter_tier
# ---------------------------------------------------------------------------
def test_openrouter_tier_free_suffix():
    """A model id ending in ``:free`` is classified as the free tier."""
    model = {"id": "foo/bar:free"}
    assert _openrouter_tier(model) == "free"
def test_openrouter_tier_zero_pricing():
    """Zero prompt and completion prices classify a model as free."""
    zero_priced = {"id": "foo/bar", "pricing": {"prompt": "0", "completion": "0"}}
    tier = _openrouter_tier(zero_priced)
    assert tier == "free"
def test_openrouter_tier_paid():
    """Non-zero prices classify a model as premium."""
    paid = {
        "id": "foo/bar",
        "pricing": {"prompt": "0.000003", "completion": "0.000015"},
    }
    tier = _openrouter_tier(paid)
    assert tier == "premium"
def test_openrouter_tier_missing_pricing_is_premium():
    """A model with no ``pricing`` key, or an empty one, is treated as premium."""
    without_pricing = {"id": "foo/bar"}
    empty_pricing = {"id": "foo/bar", "pricing": {}}
    assert _openrouter_tier(without_pricing) == "premium"
    assert _openrouter_tier(empty_pricing) == "premium"
# ---------------------------------------------------------------------------
# _stable_config_id
# ---------------------------------------------------------------------------
def test_stable_config_id_deterministic():
    """The same model id and offset give the same negative id on every call."""
    first_taken: set[int] = set()
    second_taken: set[int] = set()
    first = _stable_config_id("openai/gpt-4o", -10_000, first_taken)
    second = _stable_config_id("openai/gpt-4o", -10_000, second_taken)
    assert first == second
    assert first < 0
def test_stable_config_id_collision_decrements():
    """If the preferred slot is already taken, the next id handed out is one lower."""
    preferred = _stable_config_id("openai/gpt-4o", -10_000, set())
    # Force a collision: ask for the same model again with its preferred
    # slot already marked as taken.
    occupied = {preferred}
    bumped = _stable_config_id("openai/gpt-4o", -10_000, occupied)
    assert bumped != preferred
    assert bumped == preferred - 1
    # The newly assigned id is recorded in the caller's set.
    assert bumped in occupied
def test_stable_config_id_different_models_different_ids():
    """Three different model ids sharing one ``taken`` set get three distinct ids."""
    taken: set[int] = set()
    model_ids = (
        "openai/gpt-4o",
        "anthropic/claude-3.5-sonnet",
        "google/gemini-2.0-flash",
    )
    assigned = {_stable_config_id(mid, -10_000, taken) for mid in model_ids}
    assert len(assigned) == 3
def test_stable_config_id_survives_catalogue_churn():
    """Dropping one model from the catalogue leaves the other models' ids unchanged."""
    # First pass: three models, in order.
    full_taken: set[int] = set()
    gpt_before = _stable_config_id("openai/gpt-4o", -10_000, full_taken)
    _stable_config_id("anthropic/claude-3-haiku", -10_000, full_taken)
    gemini_before = _stable_config_id("google/gemini-2.0-flash", -10_000, full_taken)

    # Second pass: the middle model has disappeared.
    reduced_taken: set[int] = set()
    gpt_after = _stable_config_id("openai/gpt-4o", -10_000, reduced_taken)
    gemini_after = _stable_config_id("google/gemini-2.0-flash", -10_000, reduced_taken)

    assert gpt_before == gpt_after
    assert gemini_before == gemini_after
# ---------------------------------------------------------------------------
# _generate_configs
# ---------------------------------------------------------------------------
# Baseline settings dict passed (as a copy) to ``_generate_configs`` by the
# tests below. The tests in this module assert that paid models come out with
# rpm/tpm of 200 / 1_000_000 and ``anonymous_enabled`` False, while free models
# come out with 20 / 100_000 and ``anonymous_enabled`` True.
_SETTINGS_BASE: dict = {
    "api_key": "sk-or-test",
    "id_offset": -10_000,
    # Limits observed on generated paid (premium) configs.
    "rpm": 200,
    "tpm": 1_000_000,
    # Limits observed on generated free configs.
    "free_rpm": 20,
    "free_tpm": 100_000,
    "anonymous_enabled_paid": False,
    "anonymous_enabled_free": True,
    # NOTE(review): not asserted on by any test in this module.
    "quota_reserve_tokens": 4000,
}
def test_generate_configs_respects_tier():
    """Generated configs carry tier-specific limits, anonymity and pool flags.

    The paid model is expected to come out premium and router-pool eligible
    with the paid rpm/tpm; the ``:free`` model is expected to come out free,
    not pool eligible, with the free rpm/tpm. Per the original rationale,
    free models stay out of the pool because OpenRouter's free tier shares
    one global bucket that per-deployment accounting cannot represent.
    """
    paid_id = "openai/gpt-4o"
    free_id = "meta-llama/llama-3.3-70b-instruct:free"
    catalogue = [
        _minimal_openrouter_model(model_id=paid_id),
        _minimal_openrouter_model(
            model_id=free_id,
            pricing={"prompt": "0", "completion": "0"},
        ),
    ]

    generated = _generate_configs(catalogue, dict(_SETTINGS_BASE))
    by_model = {}
    for cfg in generated:
        by_model[cfg["model_name"]] = cfg

    paid = by_model[paid_id]
    assert paid["billing_tier"] == "premium"
    assert paid["rpm"] == 200
    assert paid["tpm"] == 1_000_000
    assert paid["anonymous_enabled"] is False
    assert paid["router_pool_eligible"] is True
    assert paid[_OPENROUTER_DYNAMIC_MARKER] is True

    free = by_model[free_id]
    assert free["billing_tier"] == "free"
    assert free["rpm"] == 20
    assert free["tpm"] == 100_000
    assert free["anonymous_enabled"] is True
    assert free["router_pool_eligible"] is False
def test_generate_configs_excludes_upstream_openrouter_free_router():
    """The upstream ``openrouter/free`` meta-router never becomes a config.

    The synthetic entry here is zero-priced, text-output and tool-capable, so
    it would pass the other filters; the original notes attribute its
    removal to an explicit entry in ``_EXCLUDED_MODEL_IDS``. A regular model
    in the same catalogue must still come through.
    """
    catalogue = [
        _minimal_openrouter_model(model_id="openai/gpt-4o"),
        _minimal_openrouter_model(
            model_id="openrouter/free",
            pricing={"prompt": "0", "completion": "0"},
        ),
    ]

    generated = _generate_configs(catalogue, dict(_SETTINGS_BASE))
    generated_names = set()
    for cfg in generated:
        generated_names.add(cfg["model_name"])

    assert "openrouter/free" not in generated_names
    assert "openai/gpt-4o" in generated_names
def test_generate_configs_drops_non_text_and_non_tool_models():
    """Image-output models and models without tool support are filtered out."""
    image_only_model = {
        "id": "openai/dall-e",
        "architecture": {"output_modalities": ["image"]},
        "supported_parameters": ["tools"],
        "context_length": 200_000,
        "pricing": {"prompt": "0.01", "completion": "0.01"},
    }
    # Text output, but ``supported_parameters`` lacks "tools".
    no_tools_model = {
        "id": "openai/completion-only",
        "architecture": {"output_modalities": ["text"]},
        "supported_parameters": [],
        "context_length": 200_000,
        "pricing": {"prompt": "0.01", "completion": "0.01"},
    }
    catalogue = [
        _minimal_openrouter_model(model_id="openai/gpt-4o"),
        image_only_model,
        no_tools_model,
    ]

    generated = _generate_configs(catalogue, dict(_SETTINGS_BASE))
    generated_names = [cfg["model_name"] for cfg in generated]

    assert "openai/gpt-4o" in generated_names
    assert "openai/dall-e" not in generated_names
    assert "openai/completion-only" not in generated_names

View file

@ -0,0 +1,108 @@
"""Tests for deprecated-key warnings and back-compat in
``load_openrouter_integration_settings``.
"""
from __future__ import annotations
from pathlib import Path
import pytest
pytestmark = pytest.mark.unit
def _write_yaml(tmp_path: Path, body: str) -> Path:
cfg_dir = tmp_path / "app" / "config"
cfg_dir.mkdir(parents=True)
cfg_path = cfg_dir / "global_llm_config.yaml"
cfg_path.write_text(body, encoding="utf-8")
return cfg_path
def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Point ``app.config.BASE_DIR`` at ``tmp_path`` for the current test.

    The change goes through ``monkeypatch``, so pytest undoes it when the
    test finishes.
    """
    from app import config as config_module

    monkeypatch.setattr(target=config_module, name="BASE_DIR", value=tmp_path)
def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys):
    """A config carrying the legacy ``billing_tier`` key still loads, and a
    deprecation message is written to stdout."""
    # The settings must be nested under ``openrouter_integration``; without
    # the indentation YAML parses them as unrelated top-level keys and the
    # integration section is empty.
    _write_yaml(
        tmp_path,
        """
openrouter_integration:
  enabled: true
  api_key: "sk-or-test"
  billing_tier: "premium"
""".lstrip(),
    )
    _patch_base_dir(monkeypatch, tmp_path)

    from app.config import load_openrouter_integration_settings

    settings = load_openrouter_integration_settings()
    captured = capsys.readouterr().out

    assert settings is not None
    assert "billing_tier is deprecated" in captured
def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys):
    """Legacy ``anonymous_enabled: true`` is expected to populate both
    ``anonymous_enabled_paid`` and ``anonymous_enabled_free`` and to print a
    deprecation message to stdout."""
    # Keys are indented so they belong to the ``openrouter_integration``
    # mapping rather than becoming top-level siblings.
    _write_yaml(
        tmp_path,
        """
openrouter_integration:
  enabled: true
  api_key: "sk-or-test"
  anonymous_enabled: true
""".lstrip(),
    )
    _patch_base_dir(monkeypatch, tmp_path)

    from app.config import load_openrouter_integration_settings

    settings = load_openrouter_integration_settings()
    captured = capsys.readouterr().out

    assert settings is not None
    assert settings["anonymous_enabled_paid"] is True
    assert settings["anonymous_enabled_free"] is True
    assert "anonymous_enabled is" in captured
    assert "deprecated" in captured
def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys):
    """If both legacy and new keys are present, new keys win (setdefault)."""
    # Keys are indented so they belong to the ``openrouter_integration``
    # mapping rather than becoming top-level siblings.
    _write_yaml(
        tmp_path,
        """
openrouter_integration:
  enabled: true
  api_key: "sk-or-test"
  anonymous_enabled: true
  anonymous_enabled_paid: false
  anonymous_enabled_free: false
""".lstrip(),
    )
    _patch_base_dir(monkeypatch, tmp_path)

    from app.config import load_openrouter_integration_settings

    settings = load_openrouter_integration_settings()
    # Drain captured output; this test does not inspect the warning text.
    capsys.readouterr()

    assert settings is not None
    assert settings["anonymous_enabled_paid"] is False
    assert settings["anonymous_enabled_free"] is False
def test_disabled_integration_returns_none(monkeypatch, tmp_path):
    """``enabled: false`` is expected to make the loader return ``None``."""
    # ``enabled`` must be nested under ``openrouter_integration``; otherwise
    # the section is empty and the test would pass without ever exercising
    # the ``enabled: false`` branch.
    _write_yaml(
        tmp_path,
        """
openrouter_integration:
  enabled: false
  api_key: "sk-or-test"
""".lstrip(),
    )
    _patch_base_dir(monkeypatch, tmp_path)

    from app.config import load_openrouter_integration_settings

    assert load_openrouter_integration_settings() is None

View file

@ -0,0 +1,331 @@
"""Unit tests for the OpenRouter ``_enrich_health`` background task."""
from __future__ import annotations
from typing import Any
import pytest
from app.services.openrouter_integration_service import (
OpenRouterIntegrationService,
)
from app.services.quality_score import (
_HEALTH_FAIL_RATIO_FALLBACK,
)
pytestmark = pytest.mark.unit
def _or_cfg(
*,
cid: int,
model_name: str,
tier: str = "premium",
static_score: int = 50,
) -> dict:
return {
"id": cid,
"provider": "OPENROUTER",
"model_name": model_name,
"billing_tier": tier,
"auto_pin_tier": "B" if tier == "premium" else "C",
"quality_score_static": static_score,
"quality_score_health": None,
"quality_score": static_score,
"health_gated": False,
}
class _StubResponse:
def __init__(self, *, payload: dict, status_code: int = 200):
self._payload = payload
self.status_code = status_code
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise RuntimeError(f"HTTP {self.status_code}")
def json(self) -> dict:
return self._payload
class _StubAsyncClient:
"""Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``."""
def __init__(self, responder):
self._responder = responder
self.requests: list[str] = []
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def get(self, url: str, headers: dict | None = None) -> _StubResponse:
self.requests.append(url)
return self._responder(url)
def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient:
    """Replace ``httpx.AsyncClient`` for the duration of the test.

    Every construction of ``AsyncClient`` inside the service module yields
    the same stub, which is also returned so callers can inspect it.
    """
    stub = _StubAsyncClient(responder)

    def _factory(*_args, **_kwargs):
        return stub

    monkeypatch.setattr(
        "app.services.openrouter_integration_service.httpx.AsyncClient",
        _factory,
    )
    return stub
def _healthy_payload() -> dict:
return {
"data": {
"endpoints": [
{
"status": 0,
"uptime_last_30m": 0.99,
"uptime_last_1d": 0.995,
"uptime_last_5m": 0.99,
}
]
}
}
def _unhealthy_payload() -> dict:
return {
"data": {
"endpoints": [
{
"status": 0,
"uptime_last_30m": 0.55,
"uptime_last_1d": 0.62,
"uptime_last_5m": 0.50,
}
]
}
}
# ---------------------------------------------------------------------------
# Bounded fan-out + happy path
# ---------------------------------------------------------------------------
async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch):
    """A healthy endpoint gets a health score; an unhealthy one is gated and
    keeps its static-only ``quality_score``."""
    cfgs = [
        _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
        _or_cfg(cid=-2, model_name="venice/dead-model", static_score=60),
    ]

    def responder(url: str) -> _StubResponse:
        # Only the anthropic model is served the healthy payload.
        payload = _healthy_payload() if "anthropic" in url else _unhealthy_payload()
        return _StubResponse(payload=payload)

    _patch_async_client(monkeypatch, responder)
    service = OpenRouterIntegrationService()
    service._settings = {"api_key": ""}

    await service._enrich_health(cfgs)

    healthy = next(entry for entry in cfgs if entry["id"] == -1)
    gated = next(entry for entry in cfgs if entry["id"] == -2)

    assert healthy["health_gated"] is False
    assert healthy["quality_score_health"] is not None
    assert healthy["quality_score"] >= healthy["quality_score_static"]

    assert gated["health_gated"] is True
    assert gated["quality_score"] == gated["quality_score_static"]
async def test_enrich_health_only_touches_or_provider(monkeypatch):
    """YAML cfgs that aren't OPENROUTER must be skipped entirely."""
    yaml_cfg = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "billing_tier": "premium",
        "auto_pin_tier": "A",
        "quality_score_static": 80,
        "quality_score": 80,
        "health_gated": False,
    }
    or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku")

    requests: list[str] = []

    def responder(url: str) -> _StubResponse:
        requests.append(url)
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, responder)
    service = OpenRouterIntegrationService()
    service._settings = {}

    await service._enrich_health([yaml_cfg, or_cfg])

    # ``all()`` over an empty list is vacuously true, so first make sure the
    # OR cfg actually triggered a fetch.
    assert requests, "expected at least one health fetch for the OPENROUTER cfg"
    assert all("anthropic/claude-haiku" in r for r in requests)
    # YAML cfg is untouched.
    assert yaml_cfg["quality_score"] == 80
    assert yaml_cfg["health_gated"] is False
# ---------------------------------------------------------------------------
# Failure ratio fallback
# ---------------------------------------------------------------------------
async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high(
    monkeypatch,
):
    """With every fetch failing, the pre-seeded last-good cache entry is
    applied to its cfg rather than partial data (fallback kicks in once
    >= 25% of fetches fail)."""
    specs = [
        (-1, "anthropic/claude-haiku", 70),
        (-2, "openai/gpt-5", 80),
        (-3, "google/gemini-flash", 65),
        (-4, "venice/something", 50),
    ]
    cfgs = [
        _or_cfg(cid=cid, model_name=name, static_score=score)
        for cid, name, score in specs
    ]

    service = OpenRouterIntegrationService()
    service._settings = {}
    # Seed the last-good cache with a known-healthy snapshot for one model.
    service._health_cache = {
        "anthropic/claude-haiku": {"gated": False, "score": 95.0},
    }

    def all_fail(_url: str) -> _StubResponse:
        return _StubResponse(payload={}, status_code=500)

    _patch_async_client(monkeypatch, all_fail)

    await service._enrich_health(cfgs)

    # Failure ratio is above the threshold, so the cached values win.
    cached_hit = next(
        entry for entry in cfgs if entry["model_name"] == "anthropic/claude-haiku"
    )
    assert cached_hit["quality_score_health"] == 95.0
    assert cached_hit["health_gated"] is False
    # Sanity check on the threshold constant itself.
    assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0
async def test_enrich_health_keeps_static_only_with_no_cache_and_failures(
    monkeypatch,
):
    """A failed fetch with no last-good cache leaves the cfg on its
    static-only ``quality_score`` and does *not* gate it."""
    only_cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70)
    cfgs = [only_cfg]

    def fail(_url: str) -> _StubResponse:
        return _StubResponse(payload={}, status_code=500)

    _patch_async_client(monkeypatch, fail)
    service = OpenRouterIntegrationService()
    service._settings = {}

    await service._enrich_health(cfgs)

    result = cfgs[0]
    assert result["health_gated"] is False
    assert result["quality_score"] == result["quality_score_static"]
    assert result["quality_score_health"] is None
# ---------------------------------------------------------------------------
# Last-good cache: success populates, next failure reuses
# ---------------------------------------------------------------------------
async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure(
    monkeypatch,
):
    """A successful cycle caches the health score; a later failed fetch for
    the same model reuses that cached score."""
    cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70)
    service = OpenRouterIntegrationService()
    service._settings = {}

    # Cycle 1: everything is healthy, so the cache gets populated.
    def healthy(_url: str) -> _StubResponse:
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, healthy)
    await service._enrich_health([cfg])

    assert "anthropic/claude-haiku" in service._health_cache
    cached_score = service._health_cache["anthropic/claude-haiku"]["score"]
    assert cached_score is not None

    # Cycle 2: add enough healthy cfgs that one individual failure keeps the
    # overall failure ratio below the 25% threshold.
    other_cfgs = []
    for i in range(10):
        other_cfgs.append(
            _or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60)
        )
    # Reset the cfg to its pre-enrichment state.
    cfg["quality_score_health"] = None
    cfg["quality_score"] = cfg["quality_score_static"]

    def mixed(url: str) -> _StubResponse:
        if "anthropic" not in url:
            return _StubResponse(payload=_healthy_payload())
        return _StubResponse(payload={}, status_code=500)

    _patch_async_client(monkeypatch, mixed)
    await service._enrich_health([cfg, *other_cfgs])

    assert cfg["quality_score_health"] == cached_score
    assert cfg["health_gated"] is False
# ---------------------------------------------------------------------------
# Bounded fan-out: respects top-N caps
# ---------------------------------------------------------------------------
async def test_enrich_health_bounds_premium_fanout(monkeypatch):
    """Top-N premium cap is honoured even when many cfgs are present."""
    from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM

    # Build more premium cfgs than the cap, with strictly decreasing scores.
    cfgs = []
    for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20):
        cfgs.append(
            _or_cfg(
                cid=-i,
                model_name=f"openai/m-{i}",
                tier="premium",
                static_score=100 - i,
            )
        )

    seen: list[str] = []

    def responder(url: str) -> _StubResponse:
        seen.append(url)
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, responder)
    service = OpenRouterIntegrationService()
    service._settings = {}

    await service._enrich_health(cfgs)

    assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM
async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
    """When the catalogue has no OR cfgs at all, no HTTP calls fire."""
    yaml_cfg: dict[str, Any] = {
        "id": -1,
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "billing_tier": "premium",
    }
    requests: list[str] = []

    def responder(url: str) -> _StubResponse:
        requests.append(url)
        return _StubResponse(payload=_healthy_payload())

    _patch_async_client(monkeypatch, responder)
    service = OpenRouterIntegrationService()
    service._settings = {}

    await service._enrich_health([yaml_cfg])

    assert requests == []

View file

@ -0,0 +1,345 @@
"""Unit tests for the Auto (Fastest) quality scoring module."""
from __future__ import annotations
import time
import pytest
from app.services.quality_score import (
_HEALTH_GATE_UPTIME_PCT,
_OPERATOR_TRUST_BONUS,
aggregate_health,
capabilities_signal,
context_signal,
created_recency_signal,
pricing_band,
slug_penalty,
static_score_or,
static_score_yaml,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# created_recency_signal
# ---------------------------------------------------------------------------
def test_created_recency_signal_recent_model_scores_high():
    """A model created 30 days before ``now`` is expected to score 20."""
    now = 1_750_000_000  # ~mid-2025
    created = now - (30 * 86_400)
    score = created_recency_signal(created, now)
    assert score == 20
def test_created_recency_signal_old_model_scores_zero():
    """A model created five years before ``now`` is expected to score 0."""
    now = 1_750_000_000
    created = now - (5 * 365 * 86_400)
    score = created_recency_signal(created, now)
    assert score == 0
def test_created_recency_signal_missing_timestamp_is_neutral():
    """``None`` and ``0`` creation timestamps are both expected to score 0."""
    now = 1_750_000_000
    for missing in (None, 0):
        assert created_recency_signal(missing, now) == 0
def test_created_recency_signal_monotonic_decay():
    """Scores must never increase as the model age grows."""
    now = 1_750_000_000
    ages_in_days = (30, 120, 300, 500, 700, 1000, 1500)
    scores = []
    for days in ages_in_days:
        scores.append(created_recency_signal(now - days * 86_400, now))
    assert scores == sorted(scores, reverse=True)
# ---------------------------------------------------------------------------
# pricing_band
# ---------------------------------------------------------------------------
def test_pricing_band_free_returns_zero():
    """Zero prices (as strings or floats) and missing prices map to band 0."""
    free_inputs = (("0", "0"), (0.0, 0.0), (None, None))
    for prompt_price, completion_price in free_inputs:
        assert pricing_band(prompt_price, completion_price) == 0
def test_pricing_band_handles_unparseable():
    """Non-numeric strings and non-scalar inputs are expected to yield 0."""
    non_numeric = pricing_band("not-a-number", "0")
    assert non_numeric == 0
    wrong_types = pricing_band({}, [])  # type: ignore[arg-type]
    assert wrong_types == 0
def test_pricing_band_premium_tiers_increase_with_price():
    """Bands must be positive and strictly increasing across three price
    points."""
    cheap = pricing_band("0.0000003", "0.0000005")
    mid = pricing_band("0.000003", "0.000015")
    flagship = pricing_band("0.00001", "0.00005")
    assert 0 < cheap
    assert cheap < mid
    assert mid < flagship
# ---------------------------------------------------------------------------
# context_signal
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
    ("ctx", "expected"),
    [
        (1_500_000, 10),
        (1_000_000, 10),
        (500_000, 8),
        (200_000, 6),
        (128_000, 4),
        (100_000, 2),
        (50_000, 0),
        (0, 0),
        (None, 0),
    ],
)
def test_context_signal_bands(ctx, expected):
    """Each context length maps to its expected band score."""
    actual = context_signal(ctx)
    assert actual == expected
# ---------------------------------------------------------------------------
# capabilities_signal
# ---------------------------------------------------------------------------
def test_capabilities_signal_caps_at_five():
    """Four supported parameters together must not score above 5."""
    supported = ["tools", "structured_outputs", "reasoning", "include_reasoning"]
    score = capabilities_signal(supported)
    assert score <= 5
def test_capabilities_signal_tools_only():
    """``tools`` on its own is expected to score 2."""
    score = capabilities_signal(["tools"])
    assert score == 2
def test_capabilities_signal_empty():
    """``None`` and an empty list are both expected to score 0."""
    for nothing in (None, []):
        assert capabilities_signal(nothing) == 0
# ---------------------------------------------------------------------------
# slug_penalty
# ---------------------------------------------------------------------------
def test_slug_penalty_demotes_tiny_models():
    """Each of these small-model slugs must receive a negative penalty."""
    tiny_slugs = (
        "meta-llama/llama-3.2-1b-instruct",
        "liquid/lfm-7b",
        "google/gemma-3n-e4b-it",
    )
    for slug in tiny_slugs:
        assert slug_penalty(slug) < 0
def test_slug_penalty_skips_capable_mini_nano_lite_models():
    """Critical Option C+ regression: modern frontier models named
    ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.) must not be
    penalised."""
    capable_slugs = (
        "openai/gpt-5-mini",
        "openai/gpt-5-nano",
        "google/gemini-2.5-flash-lite",
        "anthropic/claude-haiku-4.5",
    )
    for slug in capable_slugs:
        assert slug_penalty(slug) == 0
def test_slug_penalty_demotes_legacy_variants():
    """``-preview``, ``-base`` and ``-distill`` slugs must receive a negative
    penalty."""
    legacy_slugs = ("openai/o1-preview", "foo/bar-base", "foo/bar-distill")
    for slug in legacy_slugs:
        assert slug_penalty(slug) < 0
def test_slug_penalty_empty_input():
    """An empty slug is expected to carry no penalty."""
    penalty = slug_penalty("")
    assert penalty == 0
# ---------------------------------------------------------------------------
# static_score_or
# ---------------------------------------------------------------------------
def _or_model(
*,
model_id: str,
created: int | None = None,
prompt: str = "0.000003",
completion: str = "0.000015",
context: int = 200_000,
params: list[str] | None = None,
) -> dict:
return {
"id": model_id,
"created": created,
"pricing": {"prompt": prompt, "completion": completion},
"context_length": context,
"supported_parameters": params if params is not None else ["tools"],
}
def test_static_score_or_frontier_premium_beats_free_tiny():
    """A recent, paid, large-context model must outscore an old, free,
    1B-parameter one."""
    now = 1_750_000_000
    day = 86_400

    frontier = _or_model(
        model_id="openai/gpt-5",
        created=now - (60 * day),
        prompt="0.000005",
        completion="0.000020",
        context=400_000,
        params=["tools", "structured_outputs", "reasoning"],
    )
    tiny_free = _or_model(
        model_id="meta-llama/llama-3.2-1b-instruct:free",
        created=now - (5 * 365 * day),
        prompt="0",
        completion="0",
        context=128_000,
        params=["tools"],
    )

    frontier_score = static_score_or(frontier, now_ts=now)
    tiny_score = static_score_or(tiny_free, now_ts=now)
    assert frontier_score > tiny_score
def test_static_score_or_score_is_clamped_0_to_100():
    """The static score for a default model stays within ``[0, 100]``."""
    now = int(time.time())
    model = _or_model(model_id="openai/gpt-4o")
    score = static_score_or(model, now_ts=now)
    assert 0 <= score <= 100
def test_static_score_or_unknown_provider_is_neutral_not_zero():
    """A model from an unrecognised lab must still score above zero."""
    now = int(time.time())
    unknown_lab_model = _or_model(model_id="some-new-lab/some-model")
    score = static_score_or(unknown_lab_model, now_ts=now)
    assert score > 0
def test_static_score_or_recent_release_beats_year_old_same_provider():
    """Within one provider, a 60-day-old model outscores a 700-day-old one."""
    now = 1_750_000_000
    fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400))
    old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400))

    fresh_score = static_score_or(fresh, now_ts=now)
    old_score = static_score_or(old, now_ts=now)
    assert fresh_score > old_score
# ---------------------------------------------------------------------------
# static_score_yaml
# ---------------------------------------------------------------------------
def test_static_score_yaml_includes_operator_bonus():
    """A YAML cfg scores at least ``_OPERATOR_TRUST_BONUS``."""
    azure_cfg = {
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "litellm_params": {"base_model": "azure/gpt-5"},
    }
    assert static_score_yaml(azure_cfg) >= _OPERATOR_TRUST_BONUS
def test_static_score_yaml_unknown_provider_still_carries_bonus():
    """Even an unrecognised provider scores at least
    ``_OPERATOR_TRUST_BONUS``."""
    unknown_cfg = {
        "provider": "SOME_NEW_PROVIDER",
        "model_name": "weird-model",
    }
    assert static_score_yaml(unknown_cfg) >= _OPERATOR_TRUST_BONUS
def test_static_score_yaml_clamped_0_to_100():
    """The YAML static score stays within ``[0, 100]``."""
    azure_cfg = {
        "provider": "AZURE_OPENAI",
        "model_name": "gpt-5",
        "litellm_params": {"base_model": "azure/gpt-5"},
    }
    score = static_score_yaml(azure_cfg)
    assert 0 <= score <= 100
# ---------------------------------------------------------------------------
# aggregate_health
# ---------------------------------------------------------------------------
def test_aggregate_health_gates_when_uptime_below_threshold():
    """Live data showed Venice-routed cfgs at 53-68%; this guards that the
    90% gate excludes them."""
    first_endpoint = {
        "status": 0,
        "uptime_last_30m": 0.55,
        "uptime_last_1d": 0.60,
        "uptime_last_5m": 0.50,
    }
    second_endpoint = {
        "status": 0,
        "uptime_last_30m": 0.65,
        "uptime_last_1d": 0.68,
        "uptime_last_5m": 0.62,
    }

    gated, score = aggregate_health([first_endpoint, second_endpoint])

    assert gated is True
    assert score is None
def test_aggregate_health_passes_for_healthy_provider():
    """One high-uptime, status-0 endpoint is not gated and scores at least
    ``_HEALTH_GATE_UPTIME_PCT``."""
    endpoint = {
        "status": 0,
        "uptime_last_30m": 0.99,
        "uptime_last_1d": 0.995,
        "uptime_last_5m": 0.99,
    }

    gated, score = aggregate_health([endpoint])

    assert gated is False
    assert score is not None
    assert score >= _HEALTH_GATE_UPTIME_PCT
def test_aggregate_health_picks_best_endpoint_across_multiple():
    """Multi-endpoint aggregation should reward the best non-null uptime."""
    below_gate = {"status": 0, "uptime_last_30m": 0.55}
    above_gate = {"status": 0, "uptime_last_30m": 0.97}  # passes the gate

    gated, score = aggregate_health([below_gate, above_gate])

    assert gated is False
    assert score is not None
def test_aggregate_health_empty_endpoints_gated():
    """An empty endpoint list is gated and has no score."""
    result = aggregate_health([])
    gated, score = result
    assert gated is True
    assert score is None
def test_aggregate_health_no_status_zero_gated():
    """Even with high uptime, no OK status means the cfg is broken upstream."""
    non_ok_endpoints = [
        {"status": code, "uptime_last_30m": uptime}
        for code, uptime in ((1, 0.99), (2, 0.98))
    ]

    gated, score = aggregate_health(non_ok_endpoints)

    assert gated is True
    assert score is None
def test_aggregate_health_all_uptime_null_gated():
    """A status-0 endpoint whose uptime fields are all ``None`` is gated."""
    null_uptime_endpoint = {
        "status": 0,
        "uptime_last_30m": None,
        "uptime_last_1d": None,
    }

    gated, score = aggregate_health([null_uptime_endpoint])

    assert gated is True
    assert score is None
def test_aggregate_health_pct_normalisation():
    """OpenRouter returns 0-1 fractions; some endpoints surface 0-100%
    percentages. Both should reach the same gate decision."""
    fraction_form = [{"status": 0, "uptime_last_30m": 0.95}]
    pct_form = [{"status": 0, "uptime_last_30m": 95.0}]

    g1, s1 = aggregate_health(fraction_form)
    g2, s2 = aggregate_health(pct_form)

    # Identity checks (as in the sibling tests) instead of ``== False``: they
    # need no lint suppression and reject falsy non-bool values such as 0.
    assert g1 is False
    assert g2 is False
    assert s1 is not None and s2 is not None
    assert abs(s1 - s2) < 0.5

View file

@ -0,0 +1,370 @@
"""Unit tests for the filesystem-tool branches of ``revert_service``.
Covers:
* Exact-name dispatch: ``rmdir`` does NOT mis-route to the document
branch (``"rmdir".startswith("rm")`` would mis-route under the legacy
prefix-based dispatch).
* ``rm`` revert re-INSERTs a fresh document from the snapshot, including
re-creating chunks. Falls back to ``(folder_id_before, title_before)``
when ``metadata_before["virtual_path"]`` is missing.
* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the
document.
* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot.
* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable``
when the folder gained children.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import numpy as np
import pytest
from app.services import revert_service
pytestmark = pytest.mark.unit
@pytest.fixture(autouse=True)
def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None:
    """Replace ``revert_service.embed_texts`` with a stub returning one
    8-dimensional float32 zero vector per input text."""

    def _zero_vectors(texts):
        return [np.zeros(8, dtype=np.float32) for _ in texts]

    monkeypatch.setattr(revert_service, "embed_texts", _zero_vectors)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeResult:
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
self._rows = rows or []
self._scalar = scalar
def all(self) -> list[Any]:
return list(self._rows)
def scalar_one_or_none(self) -> Any:
return self._scalar
def scalars(self) -> Any:
return _FakeScalarsProxy(self._rows)
class _FakeScalarsProxy:
def __init__(self, rows: list[Any]) -> None:
self._rows = rows
def first(self) -> Any:
return self._rows[0] if self._rows else None
class _FakeSession:
def __init__(self) -> None:
self.execute = AsyncMock()
self.added: list[Any] = []
self.deleted: list[Any] = []
self.flush = AsyncMock()
# session.get(Model, pk) lookup
self.get = AsyncMock(return_value=None)
async def _flush_assigning_ids() -> None:
for obj in self.added:
if getattr(obj, "id", None) is None:
obj.id = 999
self.flush.side_effect = _flush_assigning_ids
def add(self, obj: Any) -> None:
self.added.append(obj)
def add_all(self, objs: list[Any]) -> None:
self.added.extend(objs)
def _action(*, tool_name: str, action_id: int = 7):
return MagicMock(
id=action_id,
tool_name=tool_name,
thread_id=1,
search_space_id=2,
user_id="user-1",
reverse_descriptor=None,
)
def _doc_revision(
*,
document_id: int | None = None,
content_before: str | None = "old content",
title_before: str | None = "notes.md",
folder_id_before: int | None = 5,
chunks_before: list[dict[str, str]] | None = None,
metadata_before: dict[str, str] | None = None,
):
revision = MagicMock()
revision.id = 100
revision.document_id = document_id
revision.search_space_id = 2
revision.content_before = content_before
revision.title_before = title_before
revision.folder_id_before = folder_id_before
revision.chunks_before = chunks_before or []
revision.metadata_before = metadata_before
return revision
def _folder_revision(
*,
folder_id: int | None = None,
name_before: str | None = "team",
parent_id_before: int | None = None,
position_before: str | None = "a0",
):
revision = MagicMock()
revision.id = 200
revision.folder_id = folder_id
revision.search_space_id = 2
revision.name_before = name_before
revision.parent_id_before = parent_id_before
revision.position_before = position_before
return revision
# ---------------------------------------------------------------------------
# Exact-name dispatch regression guards
# ---------------------------------------------------------------------------
class TestExactDispatch:
    """Regression: ``rmdir`` MUST NOT route to the document branch."""

    @pytest.mark.asyncio
    async def test_rmdir_does_not_misroute_to_document(self) -> None:
        """Reverting an ``rmdir`` action with no folder revisions reports
        ``not_reversible`` with a message that mentions ``folder_revisions``."""
        # A `startswith("rm")` dispatch would land in the document branch;
        # exact-name lookup puts `rmdir` in `_FOLDER_TOOLS`.
        session = _FakeSession()
        rmdir_action = _action(tool_name="rmdir")
        # The revision lookup for this action comes back empty.
        session.execute.return_value = _FakeResult(rows=[])

        outcome = await revert_service.revert_action(
            session,  # type: ignore[arg-type]
            action=rmdir_action,
            requester_user_id="user-1",
        )

        assert outcome.status == "not_reversible"
        assert "folder_revisions" in outcome.message

    def test_dispatch_sets_split_doc_and_folder(self) -> None:
        """Static guard on the dispatch tables so a refactor cannot quietly
        reintroduce the prefix bug."""
        doc_tools = revert_service._DOC_TOOLS
        folder_tools = revert_service._FOLDER_TOOLS

        assert "rm" in doc_tools
        assert "rmdir" in folder_tools
        assert "rmdir" not in doc_tools
        assert "rm" not in folder_tools
        # ``move_file`` lives only in document tools (it's a doc rename).
        assert "move_file" in doc_tools
        assert "move_file" not in folder_tools
# ---------------------------------------------------------------------------
# rm revert (re-INSERT)
# ---------------------------------------------------------------------------
class TestRmRevert:
    """``rm`` revert: re-create a deleted document from its revision
    snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_document_with_chunks(self) -> None:
        """A snapshot with two chunks produces one new ``Document`` plus two
        ``Chunk`` rows, and the revision is repointed at the new document."""
        session = _FakeSession()
        snapshot = _doc_revision(
            document_id=None,  # row was hard-deleted
            content_before="hello world",
            title_before="x.md",
            folder_id_before=None,
            chunks_before=[{"content": "alpha"}, {"content": "beta"}],
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # The collision lookup finds nothing.
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert outcome.status == "ok"

        from app.db import Chunk, Document

        new_docs = [item for item in session.added if isinstance(item, Document)]
        new_chunks = [item for item in session.added if isinstance(item, Chunk)]
        assert len(new_docs) == 1
        assert new_docs[0].title == "x.md"
        assert len(new_chunks) == 2
        # The snapshot now points at the new doc id so a follow-up revert works.
        assert snapshot.document_id == new_docs[0].id

    @pytest.mark.asyncio
    async def test_falls_back_to_folder_id_and_title_for_virtual_path(
        self,
    ) -> None:
        """Without ``metadata_before``, the revert still succeeds using the
        folder id and title stored on the snapshot."""
        session = _FakeSession()
        # No metadata_before on the snapshot: the fallback path must kick in.
        snapshot = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="cap.md",
            folder_id_before=42,
            chunks_before=[],
            metadata_before=None,
        )
        # session.get(Folder, 42) resolves to a named root-level folder.
        parent_folder = MagicMock()
        parent_folder.name = "team"
        parent_folder.parent_id = None
        session.get = AsyncMock(return_value=parent_folder)
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_falls_back_to_root_path_when_no_folder(
        self,
    ) -> None:
        """metadata_before is None and folder_id_before is None still
        resolves: title fallback yields ``/documents/<title>`` so revert
        proceeds at the root of the documents tree."""
        session = _FakeSession()
        snapshot = _doc_revision(
            document_id=None,
            content_before="hello",
            title_before="x.md",
            folder_id_before=None,
            metadata_before=None,
        )
        # Nothing already lives at /documents/x.md.
        session.execute.return_value = _FakeResult(scalar=None)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert outcome.status == "ok"

    @pytest.mark.asyncio
    async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None:
        """When the collision lookup finds an existing row, the revert
        reports ``tool_unavailable`` with a message mentioning a collision."""
        session = _FakeSession()
        snapshot = _doc_revision(
            document_id=None,
            content_before="hi",
            title_before="x.md",
            folder_id_before=None,
            metadata_before={"virtual_path": "/documents/x.md"},
        )
        # The unique_identifier_hash lookup hits an existing row.
        session.execute.return_value = _FakeResult(scalar=42)

        outcome = await revert_service._reinsert_document_from_revision(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert outcome.status == "tool_unavailable"
        assert "collide" in outcome.message
# ---------------------------------------------------------------------------
# write_file create revert (DELETE)
# ---------------------------------------------------------------------------
class TestWriteFileCreateRevert:
    """``write_file`` create-revert: the document created by the action is
    deleted."""

    @pytest.mark.asyncio
    async def test_deletes_created_doc(self) -> None:
        """Reverting a create issues exactly one statement and reports
        ``ok``."""
        session = _FakeSession()
        # ``content_before=None`` marks "created in this action".
        created_snapshot = _doc_revision(
            document_id=99,
            content_before=None,
            title_before=None,
        )

        outcome = await revert_service._delete_created_document(
            session,  # type: ignore[arg-type]
            revision=created_snapshot,
        )

        assert outcome.status == "ok"
        # Exactly one DELETE was issued.
        assert session.execute.await_count == 1
# ---------------------------------------------------------------------------
# rmdir revert (re-INSERT folder)
# ---------------------------------------------------------------------------
class TestRmdirRevert:
    """``rmdir`` revert: re-create the folder from its revision snapshot."""

    @pytest.mark.asyncio
    async def test_re_inserts_folder_from_snapshot(self) -> None:
        """One ``Folder`` named after the snapshot is added and the revision
        is repointed at its id."""
        session = _FakeSession()
        snapshot = _folder_revision(
            folder_id=None,
            name_before="team",
            parent_id_before=None,
            position_before="a0",
        )

        outcome = await revert_service._reinsert_folder_from_revision(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        from app.db import Folder

        assert outcome.status == "ok"
        new_folders = [item for item in session.added if isinstance(item, Folder)]
        assert len(new_folders) == 1
        assert new_folders[0].name == "team"
        assert snapshot.folder_id == new_folders[0].id
# ---------------------------------------------------------------------------
# mkdir revert (DELETE folder)
# ---------------------------------------------------------------------------
class TestMkdirRevert:
    """``mkdir`` revert: delete the created folder, but only while it is
    still empty."""

    @pytest.mark.asyncio
    async def test_deletes_empty_folder(self) -> None:
        """With both emptiness checks returning nothing, the folder is
        deleted after exactly three ``execute`` calls."""
        session = _FakeSession()
        snapshot = _folder_revision(folder_id=42)
        # Results are consumed in call order.
        docs_check = _FakeResult(scalar=None)
        children_check = _FakeResult(scalar=None)
        delete_result = _FakeResult(scalar=None)  # delete has no return value
        session.execute.side_effect = [docs_check, children_check, delete_result]

        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert outcome.status == "ok"
        # 3 executes: docs check, children check, delete.
        assert session.execute.await_count == 3

    @pytest.mark.asyncio
    async def test_reports_tool_unavailable_when_folder_has_children(self) -> None:
        """When the first emptiness check finds a row, the revert reports
        ``tool_unavailable`` and says the folder is no longer empty."""
        session = _FakeSession()
        snapshot = _folder_revision(folder_id=42)
        # First check (docs) returns "row found".
        session.execute.return_value = _FakeResult(scalar=1)

        outcome = await revert_service._delete_created_folder(
            session,  # type: ignore[arg-type]
            revision=snapshot,
        )

        assert outcome.status == "tool_unavailable"
        assert "no longer empty" in outcome.message

View file

@ -0,0 +1,228 @@
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
Earlier versions only handled ``isinstance(chunk.content, str)`` and
silently dropped every other shape (Anthropic typed-block lists,
Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from
a few providers). These regression tests pin those four shapes plus the
defensive cases (``None`` chunk, mixed types, missing fields).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import pytest
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
@dataclass
class _FakeChunk:
"""Minimal stand-in for ``AIMessageChunk`` used in unit tests."""
content: Any = ""
additional_kwargs: dict[str, Any] = field(default_factory=dict)
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
class TestStringContent:
    """Chunks whose ``content`` is a plain string."""

    def test_plain_string_content_extracts_as_text(self) -> None:
        """String content comes back as ``text``, with empty ``reasoning``
        and no tool-call chunks."""
        parts = _extract_chunk_parts(_FakeChunk(content="hello world"))

        assert parts["text"] == "hello world"
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []

    def test_empty_string_content_yields_empty_text(self) -> None:
        """Empty string content yields empty ``text`` and ``reasoning`` and
        no tool-call chunks."""
        parts = _extract_chunk_parts(_FakeChunk(content=""))

        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
class TestListContent:
    """``content`` supplied as a list of typed blocks."""

    def test_list_of_text_blocks_concatenates(self) -> None:
        blocks = [
            {"type": "text", "text": "Hello "},
            {"type": "text", "text": "world"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))

        assert parts["text"] == "Hello world"
        assert parts["reasoning"] == ""

    def test_mixed_text_and_reasoning_blocks(self) -> None:
        # Reasoning blocks appear with either a ``reasoning`` or a ``text`` key.
        blocks = [
            {"type": "reasoning", "reasoning": "Let me think... "},
            {"type": "reasoning", "text": "still thinking."},
            {"type": "text", "text": "The answer is 42."},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))

        assert parts["text"] == "The answer is 42."
        assert parts["reasoning"] == "Let me think... still thinking."

    def test_tool_call_chunks_in_content_list_extracted(self) -> None:
        tool_block = {
            "type": "tool_call_chunk",
            "id": "call_123",
            "name": "make_widget",
            "args": '{"color":"red"}',
        }
        blocks = [{"type": "text", "text": "Calling tool..."}, tool_block]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))

        assert parts["text"] == "Calling tool..."
        assert parts["reasoning"] == ""
        assert len(parts["tool_call_chunks"]) == 1
        extracted = parts["tool_call_chunks"][0]
        assert extracted["id"] == "call_123"
        assert extracted["name"] == "make_widget"

    def test_tool_use_blocks_also_extracted(self) -> None:
        """Some providers (Anthropic) emit ``type='tool_use'`` instead."""
        blocks = [
            {
                "type": "tool_use",
                "id": "call_xyz",
                "name": "search",
            },
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))

        assert parts["tool_call_chunks"] == [
            {"type": "tool_use", "id": "call_xyz", "name": "search"}
        ]

    def test_unknown_block_types_are_ignored(self) -> None:
        blocks = [
            {"type": "image_url", "url": "https://example.com/x.png"},
            {"type": "text", "text": "ok"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))

        assert parts["text"] == "ok"

    def test_blocks_without_text_field_are_ignored(self) -> None:
        blocks = [
            {"type": "text"},  # no text/content key
            {"type": "text", "text": "kept"},
        ]
        parts = _extract_chunk_parts(_FakeChunk(content=blocks))

        assert parts["text"] == "kept"
class TestAdditionalKwargsReasoning:
    """Reasoning delivered through ``additional_kwargs``."""

    def test_reasoning_content_in_additional_kwargs(self) -> None:
        """Some providers stash reasoning in ``additional_kwargs.reasoning_content``."""
        parts = _extract_chunk_parts(
            _FakeChunk(
                content="visible answer",
                additional_kwargs={"reasoning_content": "internal monologue"},
            )
        )

        assert parts["text"] == "visible answer"
        assert parts["reasoning"] == "internal monologue"

    def test_reasoning_appended_to_typed_block_reasoning(self) -> None:
        # Block-sourced reasoning is expected first, kwargs-sourced second.
        parts = _extract_chunk_parts(
            _FakeChunk(
                content=[{"type": "reasoning", "text": "from blocks. "}],
                additional_kwargs={"reasoning_content": "from kwargs."},
            )
        )

        assert parts["reasoning"] == "from blocks. from kwargs."
class TestToolCallChunksAttribute:
    """Tool-call chunks delivered on the ``tool_call_chunks`` attribute."""

    def test_tool_call_chunks_attribute_extracted_alongside_string_content(
        self,
    ) -> None:
        attr_chunks = [
            {"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"}
        ]
        parts = _extract_chunk_parts(
            _FakeChunk(content="streaming text", tool_call_chunks=attr_chunks)
        )

        assert parts["text"] == "streaming text"
        assert len(parts["tool_call_chunks"]) == 1
        assert parts["tool_call_chunks"][0]["id"] == "tc-9"

    def test_attribute_and_typed_block_chunks_both_collected(self) -> None:
        block_chunk = {
            "type": "tool_call_chunk",
            "id": "from-block",
            "name": "x",
        }
        parts = _extract_chunk_parts(
            _FakeChunk(
                content=[block_chunk],
                tool_call_chunks=[{"id": "from-attr", "name": "y"}],
            )
        )

        # Content-block chunks are expected ahead of attribute chunks.
        collected_ids = []
        for entry in parts["tool_call_chunks"]:
            collected_ids.append(entry.get("id"))
        assert collected_ids == ["from-block", "from-attr"]
class TestDefensive:
    """Inputs with no usable content must produce empty parts."""

    @pytest.mark.parametrize(
        "chunk_value",
        [None, _FakeChunk(content=None), _FakeChunk(content=42)],
    )
    def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None:
        parts = _extract_chunk_parts(chunk_value)

        assert parts["text"] == ""
        assert parts["reasoning"] == ""
        assert parts["tool_call_chunks"] == []
class TestIdlessContinuationChunks:
    """Per LangChain ``ToolCallChunk`` semantics, the FIRST chunk for a
    tool call carries id+name; later chunks for the same call have
    ``id=None, name=None`` and only ``args`` + ``index``. Live tool-call
    argument streaming relies on those idless continuation chunks
    flowing through ``_extract_chunk_parts`` UNTOUCHED so the upstream
    chunk-emission loop can still route them by ``index``.
    """

    def test_idless_continuation_chunk_preserved_verbatim(self) -> None:
        continuation = {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
        parts = _extract_chunk_parts(_FakeChunk(tool_call_chunks=[continuation]))

        assert len(parts["tool_call_chunks"]) == 1
        passed_through = parts["tool_call_chunks"][0]
        assert passed_through.get("id") is None
        assert passed_through.get("name") is None
        assert passed_through.get("args") == '_path":"/x"}'
        assert passed_through.get("index") == 0

    def test_first_then_idless_sequence_preserves_index(self) -> None:
        """Both chunks for the same call share an ``index`` key — the
        index-routing loop in ``stream_new_chat`` depends on it."""
        opening = _FakeChunk(
            tool_call_chunks=[
                {"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
            ]
        )
        continuation = _FakeChunk(
            tool_call_chunks=[
                {"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
            ]
        )

        parts_opening = _extract_chunk_parts(opening)
        parts_continuation = _extract_chunk_parts(continuation)

        assert parts_opening["tool_call_chunks"][0]["index"] == 0
        assert parts_continuation["tool_call_chunks"][0]["index"] == 0
        assert parts_continuation["tool_call_chunks"][0].get("id") is None

View file

@ -0,0 +1,527 @@
"""Unit tests for live tool-call argument streaming.
Pins the wire format that ``_stream_agent_events`` emits when
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start``, then
``tool-input-delta``..., then ``tool-input-available`` and finally
``tool-output-available``, all keyed by the same LangChain ``tool_call.id``.
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
``_stream_agent_events`` so we exercise them via the public wire output.
These tests also lock in the legacy / parity_v2-OFF behaviour so the
synthetic ``call_<run_id>`` shape stays stable for older clients.
"""
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from dataclasses import dataclass, field
from typing import Any
import pytest
import app.tasks.chat.stream_new_chat as stream_module
from app.agents.new_chat.feature_flags import AgentFeatureFlags
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_new_chat import (
StreamResult,
_legacy_match_lc_id,
_stream_agent_events,
)
pytestmark = pytest.mark.unit
@dataclass
class _FakeChunk:
"""Minimal stand-in for ``AIMessageChunk``."""
content: Any = ""
additional_kwargs: dict[str, Any] = field(default_factory=dict)
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class _FakeToolMessage:
"""Stand-in for ``ToolMessage`` returned by ``on_tool_end``."""
content: Any
tool_call_id: str | None = None
class _FakeAgentState:
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
def __init__(self) -> None:
# Empty values keeps the cloud-fallback safety-net branch a no-op,
# and an empty ``tasks`` list keeps the post-stream interrupt
# check a no-op too.
self.values: dict[str, Any] = {}
self.tasks: list[Any] = []
class _FakeAgent:
"""Replays a list of ``astream_events`` events."""
def __init__(self, events: list[dict[str, Any]]) -> None:
self._events = events
async def astream_events( # type: ignore[no-untyped-def]
self, _input_data: Any, *, config: dict[str, Any], version: str
) -> AsyncGenerator[dict[str, Any], None]:
del config, version # unused, contract-compatible
for ev in self._events:
yield ev
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
# Called once after astream_events drains so the cloud-fallback
# safety net can inspect staged filesystem work. The fake stays
# empty so the safety net is a no-op.
return _FakeAgentState()
def _model_stream(
    *,
    text: str = "",
    reasoning: str = "",
    tool_call_chunks: list[dict[str, Any]] | None = None,
    tags: list[str] | None = None,
) -> dict[str, Any]:
    """Build one ``on_chat_model_stream`` event wrapping a ``_FakeChunk``.

    ``text`` becomes the chunk's ``content`` and ``tool_call_chunks`` is
    copied into a fresh list. A non-empty ``reasoning`` travels via
    ``additional_kwargs["reasoning_content"]``; when it is empty the
    chunk keeps its default (empty) ``additional_kwargs``. Most tests
    only check ``tool_call_chunks`` routing, so typed-block reasoning is
    not modelled here.
    """
    chunk_fields: dict[str, Any] = {
        "content": text,
        "tool_call_chunks": list(tool_call_chunks or []),
    }
    if reasoning:
        chunk_fields["additional_kwargs"] = {"reasoning_content": reasoning}

    return {
        "event": "on_chat_model_stream",
        "tags": tags or [],
        "data": {"chunk": _FakeChunk(**chunk_fields)},
    }
def _tool_start(
*,
name: str,
run_id: str,
input_payload: dict[str, Any] | None = None,
) -> dict[str, Any]:
return {
"event": "on_tool_start",
"name": name,
"run_id": run_id,
"data": {"input": input_payload or {}},
}
def _tool_end(
    *,
    name: str,
    run_id: str,
    tool_call_id: str | None = None,
    output: Any = "ok",
) -> dict[str, Any]:
    """Build an ``on_tool_end`` event whose output is a ``_FakeToolMessage``.

    String ``output`` is used verbatim; anything else is JSON-encoded
    first.
    """
    message_body = output if isinstance(output, str) else json.dumps(output)
    tool_message = _FakeToolMessage(content=message_body, tool_call_id=tool_call_id)
    return {
        "event": "on_tool_end",
        "name": name,
        "run_id": run_id,
        "data": {"output": tool_message},
    }
@pytest.fixture
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``stream_module.get_flags`` to report stream parity v2 enabled."""

    def _flags_enabled() -> AgentFeatureFlags:
        return AgentFeatureFlags(enable_stream_parity_v2=True)

    monkeypatch.setattr(stream_module, "get_flags", _flags_enabled)
@pytest.fixture
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch ``stream_module.get_flags`` to report stream parity v2 disabled."""

    def _flags_disabled() -> AgentFeatureFlags:
        return AgentFeatureFlags(enable_stream_parity_v2=False)

    monkeypatch.setattr(stream_module, "get_flags", _flags_disabled)
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Feed ``events`` through ``_stream_agent_events`` via a fake agent.

    Returns the JSON-decoded body of every ``data: `` SSE line yielded.
    Lines without that prefix, empty bodies, the ``[DONE]`` sentinel and
    bodies that are not valid JSON are all skipped.
    """
    fake_agent = _FakeAgent(events)
    streaming_service = VercelStreamingService()
    stream_result = StreamResult()
    run_config = {"configurable": {"thread_id": "test-thread"}}

    raw_lines: list[str] = [
        sse
        async for sse in _stream_agent_events(
            fake_agent,
            run_config,
            {},
            streaming_service,
            stream_result,
            step_prefix="thinking",
        )
    ]

    data_prefix = "data: "
    decoded: list[dict[str, Any]] = []
    for raw in raw_lines:
        if not raw.startswith(data_prefix):
            continue
        payload_text = raw[len(data_prefix) :].rstrip("\n")
        if payload_text in ("", "[DONE]"):
            continue
        try:
            decoded.append(json.loads(payload_text))
        except json.JSONDecodeError:
            # Non-JSON bodies are not part of the payload list.
            continue
    return decoded
def _types(payloads: list[dict[str, Any]]) -> list[str]:
return [p.get("type", "?") for p in payloads]
def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]:
return [p for p in payloads if p.get("type") == type_name]
# ---------------------------------------------------------------------------
# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour.
# ---------------------------------------------------------------------------
class TestLegacyMatch:
    """Behaviour of ``_legacy_match_lc_id`` on a pending-chunk list."""

    def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None:
        pending: list[dict[str, Any]] = [
            {"id": "x1", "name": "ls"},
            {"id": "y1", "name": "write_file"},
        ]
        id_by_run: dict[str, str] = {}

        matched = _legacy_match_lc_id(pending, "write_file", "run-1", id_by_run)

        assert matched == "y1"
        # The matched chunk is removed; the run is recorded against its id.
        assert pending == [{"id": "x1", "name": "ls"}]
        assert id_by_run == {"run-1": "y1"}

    def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None:
        pending: list[dict[str, Any]] = [{"id": "anon", "name": None}]
        id_by_run: dict[str, str] = {}

        matched = _legacy_match_lc_id(pending, "ls", "run-2", id_by_run)

        assert matched == "anon"
        assert pending == []

    def test_returns_none_when_no_id_bearing_chunk(self) -> None:
        pending: list[dict[str, Any]] = [{"id": None, "name": None}]
        id_by_run: dict[str, str] = {}

        matched = _legacy_match_lc_id(pending, "ls", "run-3", id_by_run)

        assert matched is None
        # Nothing is consumed or recorded when no chunk carries an id.
        assert pending == [{"id": None, "name": None}]
        assert id_by_run == {}
# ---------------------------------------------------------------------------
# parity_v2 wire format tests.
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
    """First chunk carries id+name; later idless chunks at the same
    ``index`` merge into the SAME ``tool-input-start`` ui id and emit
    one ``tool-input-delta`` per chunk."""
    opening_args = '{"file'
    continuation_args = '_path":"/x"}'
    payloads = await _drain(
        [
            _model_stream(
                tool_call_chunks=[
                    {"id": "lc-1", "name": "write_file", "args": opening_args, "index": 0}
                ],
            ),
            _model_stream(
                tool_call_chunks=[
                    {"id": None, "name": None, "args": continuation_args, "index": 0}
                ],
            ),
            _tool_start(
                name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
            ),
            _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
        ]
    )

    input_starts = _of_type(payloads, "tool-input-start")
    input_deltas = _of_type(payloads, "tool-input-delta")
    input_available = _of_type(payloads, "tool-input-available")
    tool_outputs = _of_type(payloads, "tool-output-available")

    assert len(input_starts) == 1
    only_start = input_starts[0]
    assert only_start["toolCallId"] == "lc-1"
    assert only_start["toolName"] == "write_file"
    assert only_start["langchainToolCallId"] == "lc-1"

    assert [delta["inputTextDelta"] for delta in input_deltas] == [
        opening_args,
        continuation_args,
    ]
    assert all(delta["toolCallId"] == "lc-1" for delta in input_deltas)

    assert len(input_available) == 1
    assert input_available[0]["toolCallId"] == "lc-1"
    assert len(tool_outputs) == 1
    assert tool_outputs[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_two_interleaved_tool_calls_route_by_index(
    parity_v2_on: None,
) -> None:
    """Two same-name calls with distinct indices keep their deltas
    routed to the right card."""
    opening_chunks = [
        {"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0},
        {"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1},
    ]
    closing_chunks = [
        {"id": None, "name": None, "args": "}", "index": 0},
        {"id": None, "name": None, "args": "}", "index": 1},
    ]
    payloads = await _drain(
        [
            _model_stream(tool_call_chunks=opening_chunks),
            _model_stream(tool_call_chunks=closing_chunks),
            _tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}),
            _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"),
            _tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}),
            _tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"),
        ]
    )

    input_starts = _of_type(payloads, "tool-input-start")
    input_deltas = _of_type(payloads, "tool-input-delta")
    tool_outputs = _of_type(payloads, "tool-output-available")

    assert {start["toolCallId"] for start in input_starts} == {"lc-A", "lc-B"}

    # Group delta text by the card it was routed to; an unexpected id
    # surfaces as a KeyError.
    deltas_by_card: dict[str, list[str]] = {"lc-A": [], "lc-B": []}
    for delta in input_deltas:
        deltas_by_card[delta["toolCallId"]].append(delta["inputTextDelta"])
    assert deltas_by_card["lc-A"] == ['{"a":1', "}"]
    assert deltas_by_card["lc-B"] == ['{"b":2', "}"]

    assert {out["toolCallId"] for out in tool_outputs} == {"lc-A", "lc-B"}
@pytest.mark.asyncio
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
    """Whatever id ``tool-input-start`` chose must be the SAME id used
    on ``tool-input-available`` AND ``tool-output-available``."""
    payloads = await _drain(
        [
            _model_stream(
                tool_call_chunks=[
                    {"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0}
                ]
            ),
            _tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}),
            _tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"),
        ]
    )

    lifecycle_types = {
        "tool-input-start",
        "tool-input-available",
        "tool-output-available",
    }
    ids_seen = set()
    for payload in payloads:
        if payload.get("type") in lifecycle_types:
            ids_seen.add(payload["toolCallId"])
    assert ids_seen == {"lc-9"}
@pytest.mark.asyncio
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
    """When the chunk-emission loop already fired ``tool-input-start``
    for this run, ``on_tool_start`` MUST NOT emit a second one."""
    payloads = await _drain(
        [
            _model_stream(
                tool_call_chunks=[
                    {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
                ]
            ),
            _tool_start(name="write_file", run_id="run-A", input_payload={}),
            _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
        ]
    )

    input_starts = _of_type(payloads, "tool-input-start")
    assert len(input_starts) == 1
    assert input_starts[0]["toolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_active_text_closes_before_early_tool_input_start(
    parity_v2_on: None,
) -> None:
    """Streaming a text-delta then a tool-call chunk in subsequent
    chunks: the wire MUST contain ``text-end`` before the FIRST
    ``tool-input-start`` (clean part boundary on the frontend)."""
    payloads = await _drain(
        [
            _model_stream(text="Working on it"),
            _model_stream(
                tool_call_chunks=[
                    {"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
                ]
            ),
            _tool_start(name="write_file", run_id="run-A", input_payload={}),
            _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
        ]
    )
    wire_types = _types(payloads)

    # ``list.index`` returns the FIRST occurrence of each type.
    position_text_end = wire_types.index("text-end")
    position_tool_start = wire_types.index("tool-input-start")
    assert position_text_end < position_tool_start
@pytest.mark.asyncio
async def test_mixed_text_and_tool_chunk_preserve_order(
    parity_v2_on: None,
) -> None:
    """One AIMessageChunk that carries BOTH ``text`` content AND
    ``tool_call_chunks`` should emit the text delta FIRST, then close
    text, then ``tool-input-start``+``tool-input-delta``."""
    combined_chunk = _model_stream(
        text="I'll update it",
        tool_call_chunks=[
            {
                "id": "lc-1",
                "name": "write_file",
                "args": '{"file_path":"/x"}',
                "index": 0,
            }
        ],
    )
    payloads = await _drain(
        [
            combined_chunk,
            _tool_start(
                name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
            ),
            _tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
        ]
    )
    wire_types = _types(payloads)

    # text-start … text-delta … text-end … tool-input-start … tool-input-delta
    expected_order = [
        "text-start",
        "text-delta",
        "text-end",
        "tool-input-start",
        "tool-input-delta",
    ]
    for earlier, later in zip(expected_order, expected_order[1:]):
        assert wire_types.index(earlier) < wire_types.index(later)
@pytest.mark.asyncio
async def test_parity_v2_off_preserves_legacy_shape(
    parity_v2_off: None,
) -> None:
    """When the flag is OFF, no deltas are emitted and the ``toolCallId``
    is ``call_<run_id>`` (NOT the lc id)."""
    payloads = await _drain(
        [
            _model_stream(
                tool_call_chunks=[
                    {"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
                ]
            ),
            _tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
            _tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
        ]
    )

    assert _of_type(payloads, "tool-input-delta") == []

    input_starts = _of_type(payloads, "tool-input-start")
    assert len(input_starts) == 1
    legacy_start = input_starts[0]
    assert legacy_start["toolCallId"].startswith("call_run-A")
    # No ``langchainToolCallId`` propagation on ``tool-input-start`` in
    # legacy mode (the start event fires before the ToolMessage is
    # available, so we can't extract the authoritative LangChain id yet).
    assert "langchainToolCallId" not in legacy_start

    first_output = _of_type(payloads, "tool-output-available")[0]
    assert first_output["toolCallId"].startswith("call_run-A")
    # ``tool-output-available`` MUST carry ``langchainToolCallId`` even
    # in legacy mode: the chat tool card uses it to backfill the
    # LangChain id and join against the ``data-action-log`` SSE event
    # (keyed by ``lc_tool_call_id``) so the inline Revert button can
    # light up. Sourced from the returned ``ToolMessage.tool_call_id``,
    # which is populated regardless of feature-flag state.
    assert first_output["langchainToolCallId"] == "lc-1"
@pytest.mark.asyncio
async def test_skip_append_prevents_stale_id_reuse(
    parity_v2_on: None,
) -> None:
    """Two same-name tools: the SECOND tool's ``langchainToolCallId``
    must NOT come from the first tool's chunk (``pending_tool_call_chunks``
    must stay empty for indexed-registered chunks)."""
    both_calls = [
        {"id": "lc-A", "name": "write_file", "args": "{}", "index": 0},
        {"id": "lc-B", "name": "write_file", "args": "{}", "index": 1},
    ]
    payloads = await _drain(
        [
            _model_stream(tool_call_chunks=both_calls),
            _tool_start(name="write_file", run_id="run-1", input_payload={}),
            _tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"),
            _tool_start(name="write_file", run_id="run-2", input_payload={}),
            _tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"),
        ]
    )
    expected_ids = {"lc-A", "lc-B"}

    # Two distinct lc ids, each its own card.
    start_ids = {p["toolCallId"] for p in _of_type(payloads, "tool-input-start")}
    assert start_ids == expected_ids

    # Each tool-output-available landed on its respective card.
    output_ids = {p["toolCallId"] for p in _of_type(payloads, "tool-output-available")}
    assert output_ids == expected_ids
@pytest.mark.asyncio
async def test_registration_waits_for_both_id_and_name(
    parity_v2_on: None,
) -> None:
    """An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
    nameless_chunk = {"id": "lc-1", "name": None, "args": "", "index": 0}
    payloads = await _drain([_model_stream(tool_call_chunks=[nameless_chunk])])

    premature_starts = _of_type(payloads, "tool-input-start")
    assert premature_starts == []
@pytest.mark.asyncio
async def test_unmatched_fallback_still_attaches_lc_id(
    parity_v2_on: None,
) -> None:
    """parity_v2 ON, but the provider didn't include an ``index``: the
    legacy fallback path must still emit ``tool-input-start`` with the
    matching ``langchainToolCallId``."""
    # No index on the chunk → not registered into index_to_meta;
    # falls through to ``pending_tool_call_chunks`` so the legacy
    # match path can pop it at on_tool_start.
    indexless_chunk = {"id": "lc-orphan", "name": "ls", "args": ""}
    payloads = await _drain(
        [
            _model_stream(tool_call_chunks=[indexless_chunk]),
            _tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}),
            _tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"),
        ]
    )

    input_starts = _of_type(payloads, "tool-input-start")
    assert len(input_starts) == 1
    fallback_start = input_starts[0]
    assert fallback_start["toolCallId"].startswith("call_run-1")
    assert fallback_start["langchainToolCallId"] == "lc-orphan"

View file

@ -1,9 +1,21 @@
import inspect
import json
import logging
import re
from pathlib import Path
import pytest
import app.tasks.chat.stream_new_chat as stream_new_chat_module
from app.agents.new_chat.errors import BusyError
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
from app.tasks.chat.stream_new_chat import (
StreamResult,
_classify_stream_exception,
_contract_enforcement_active,
_evaluate_file_contract_outcome,
_extract_resolved_file_path,
_log_chat_stream_error,
_tool_output_has_error,
)
@ -17,6 +29,39 @@ def test_tool_output_error_detection():
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
def test_extract_resolved_file_path_prefers_structured_path():
    """A ``path`` key in the tool output is returned as the resolved path."""
    resolved = _extract_resolved_file_path(
        tool_name="write_file",
        tool_output={"status": "completed", "path": "/docs/note.md"},
        tool_input=None,
    )
    assert resolved == "/docs/note.md"
def test_extract_resolved_file_path_falls_back_to_tool_input():
    """Without a ``path`` in the output, ``tool_input['file_path']`` is used."""
    resolved = _extract_resolved_file_path(
        tool_name="edit_file",
        tool_output={"status": "completed", "result": "updated"},
        tool_input={"file_path": "/docs/edited.md"},
    )
    assert resolved == "/docs/edited.md"
def test_extract_resolved_file_path_does_not_parse_result_text():
    """A path that only appears inside free-form ``result`` text is ignored."""
    resolved = _extract_resolved_file_path(
        tool_name="write_file",
        tool_output={"result": "Updated file /docs/from-text.md"},
        tool_input=None,
    )
    assert resolved is None
def test_file_write_contract_outcome_reasons():
result = StreamResult(intent_detected="file_write")
passed, reason = _evaluate_file_contract_outcome(result)
@ -45,3 +90,433 @@ def test_contract_enforcement_local_only():
result.filesystem_mode = "cloud"
assert not _contract_enforcement_active(result)
def _extract_chat_stream_payload(record_message: str) -> dict:
prefix = "[chat_stream_error] "
assert record_message.startswith(prefix)
return json.loads(record_message[len(prefix) :])
def test_unified_chat_stream_error_log_schema(caplog):
    """The structured error log line carries every required schema key."""
    log_fields = {
        "flow": "new",
        "error_kind": "server_error",
        "error_code": "SERVER_ERROR",
        "severity": "warn",
        "is_expected": False,
        "request_id": "req-123",
        "thread_id": 101,
        "search_space_id": 202,
        "user_id": "user-1",
        "message": "Error during chat: boom",
    }
    with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
        _log_chat_stream_error(**log_fields)

    matching_records = (
        rec for rec in caplog.records if "[chat_stream_error]" in rec.message
    )
    payload = _extract_chat_stream_payload(next(matching_records).message)

    required_keys = {
        "event",
        "flow",
        "error_kind",
        "error_code",
        "severity",
        "is_expected",
        "request_id",
        "thread_id",
        "search_space_id",
        "user_id",
        "message",
    }
    assert not required_keys.difference(payload.keys())
    assert payload["event"] == "chat_stream_error"
    assert payload["flow"] == "new"
    assert payload["error_code"] == "SERVER_ERROR"
def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
    """Premium-quota errors use the same log shape, with ``extra`` keys merged in."""
    with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
        _log_chat_stream_error(
            flow="resume",
            error_kind="premium_quota_exhausted",
            error_code="PREMIUM_QUOTA_EXHAUSTED",
            severity="info",
            is_expected=True,
            request_id="req-premium",
            thread_id=303,
            search_space_id=404,
            user_id="user-2",
            message="Buy more tokens to continue with this model, or switch to a free model",
            extra={"auto_fallback": False},
        )

    matching_records = (
        rec for rec in caplog.records if "[chat_stream_error]" in rec.message
    )
    payload = _extract_chat_stream_payload(next(matching_records).message)

    expected_values = {
        "event": "chat_stream_error",
        "error_kind": "premium_quota_exhausted",
        "error_code": "PREMIUM_QUOTA_EXHAUSTED",
        "flow": "resume",
    }
    for key, expected in expected_values.items():
        assert payload[key] == expected
    assert payload["is_expected"] is True
    # ``extra`` entries surface as top-level payload keys.
    assert payload["auto_fallback"] is False
def test_stream_error_emission_keeps_machine_error_codes():
    """Source-level contract check on ``stream_new_chat``'s error emission."""
    module_source = inspect.getsource(stream_new_chat_module)
    emitter_call_count = len(re.findall(r"format_error\(", module_source))
    codes_in_source = set(re.findall(r'error_code="([A-Z_]+)"', module_source))

    # All stream paths should route through one shared terminal error emitter.
    assert emitter_call_count == 1
    assert codes_in_source >= {
        "PREMIUM_QUOTA_EXHAUSTED",
        "SERVER_ERROR",
    }
    for required_snippet in (
        'flow: Literal["new", "regenerate"] = "new"',
        "_emit_stream_terminal_error",
        "flow=flow",
        'flow="resume"',
    ):
        assert required_snippet in module_source
def test_stream_exception_classifies_rate_limited():
    """A ``rate_limit_error`` JSON body classifies as ``RATE_LIMITED``."""
    provider_error = Exception(
        '{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
    )
    classification = _classify_stream_exception(provider_error, flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert (kind, code, severity) == ("rate_limited", "RATE_LIMITED", "warn")
    assert is_expected is True
    assert "temporarily rate-limited" in user_message
    assert extra is None
def test_stream_exception_classifies_openrouter_429_payload():
    """An OpenRouter-wrapped error with ``code: 429`` classifies as ``RATE_LIMITED``."""
    provider_error = Exception(
        'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
        '"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
    )
    classification = _classify_stream_exception(provider_error, flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert (kind, code, severity) == ("rate_limited", "RATE_LIMITED", "warn")
    assert is_expected is True
    assert "temporarily rate-limited" in user_message
    assert extra is None
@pytest.mark.asyncio
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
    """``_preflight_llm`` is best-effort.

    - On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
      caller can drive the cooldown/repin branch.
    - On any other transient failure it MUST swallow the error so the normal
      stream path continues without surfacing preflight noise to the user.
    """
    from types import SimpleNamespace

    import litellm  # type: ignore[import-not-found]

    from app.tasks.chat.stream_new_chat import _preflight_llm

    class _RateLimitedError(Exception):
        """Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""

    calls_when_rate_limited: list[dict] = []
    calls_when_other_failure: list[dict] = []

    async def _acompletion_raising_429(**kwargs):
        calls_when_rate_limited.append(kwargs)
        raise _RateLimitedError("simulated 429")

    async def _acompletion_raising_other(**kwargs):
        calls_when_other_failure.append(kwargs)
        raise RuntimeError("some unrelated transient failure")

    probe_target = SimpleNamespace(
        model="openrouter/google/gemma-4-31b-it:free",
        api_key="test",
        api_base=None,
    )

    # Phase 1: a rate-limit shaped failure propagates to the caller.
    monkeypatch.setattr(litellm, "acompletion", _acompletion_raising_429)
    with pytest.raises(_RateLimitedError):
        await _preflight_llm(probe_target)
    assert len(calls_when_rate_limited) == 1
    probe_kwargs = calls_when_rate_limited[0]
    assert probe_kwargs["max_tokens"] == 1
    assert probe_kwargs["stream"] is False

    # Phase 2: MUST NOT raise: non-rate-limit failures are swallowed.
    monkeypatch.setattr(litellm, "acompletion", _acompletion_raising_other)
    await _preflight_llm(probe_target)
    assert len(calls_when_other_failure) == 1
@pytest.mark.asyncio
async def test_preflight_skipped_for_auto_router_model(monkeypatch):
    """Router-mode ``model='auto'`` has no single deployment to ping; the
    LiteLLM router itself owns per-deployment rate-limit accounting, so the
    preflight helper must short-circuit instead of issuing a probe.

    ``litellm.acompletion`` is replaced with a recorder so the "no probe"
    half of the contract is actually asserted. Without it this test would
    pass even if a probe were issued, because non-rate-limit probe
    failures are swallowed by ``_preflight_llm``.
    """
    from types import SimpleNamespace

    import litellm  # type: ignore[import-not-found]

    from app.tasks.chat.stream_new_chat import _preflight_llm

    probe_calls: list[dict] = []

    async def _recording_acompletion(*args, **kwargs):
        # Any invocation at all is a contract violation for model="auto".
        probe_calls.append(kwargs)

    monkeypatch.setattr(litellm, "acompletion", _recording_acompletion)

    fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
    # Should return without raising or making any LiteLLM call.
    await _preflight_llm(fake_llm)
    assert probe_calls == []
def test_stream_exception_classifies_thread_busy():
    """A ``BusyError`` instance classifies as ``THREAD_BUSY``."""
    busy_error = BusyError(request_id="thread-123")
    classification = _classify_stream_exception(busy_error, flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert (kind, code, severity) == ("thread_busy", "THREAD_BUSY", "warn")
    assert is_expected is True
    assert "still finishing for this thread" in user_message
    assert extra is None
def test_stream_exception_classifies_thread_busy_from_message():
    """A plain exception with a "thread is busy" message also classifies as ``THREAD_BUSY``."""
    generic_error = Exception("Thread is busy with another request")
    classification = _classify_stream_exception(generic_error, flow_label="chat")
    kind, code, severity, is_expected, user_message, extra = classification

    assert (kind, code, severity) == ("thread_busy", "THREAD_BUSY", "warn")
    assert is_expected is True
    assert "still finishing for this thread" in user_message
    assert extra is None
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
    """A ``BusyError`` for a thread with a pending cancel request is
    classified as ``TURN_CANCELLING`` and carries a ``retry_after_ms`` hint.
    """
    thread_id = "thread-cancelling-1"
    # Start from a known state, then flag the thread as cancelling.
    reset_cancel(thread_id)
    request_cancel(thread_id)
    try:
        exc = BusyError(request_id=thread_id)
        kind, code, severity, is_expected, user_message, extra = (
            _classify_stream_exception(exc, flow_label="chat")
        )
    finally:
        # Clear the cancel flag even if classification raises, so it
        # cannot leak into other tests in the same process.
        # NOTE(review): relies on ``reset_cancel`` clearing the flag set
        # by ``request_cancel``, as its use at the top of this test implies.
        reset_cancel(thread_id)

    assert kind == "thread_busy"
    assert code == "TURN_CANCELLING"
    assert severity == "info"
    assert is_expected is True
    assert "stopping" in user_message
    assert isinstance(extra, dict)
    assert "retry_after_ms" in extra
def test_premium_classification_is_error_code_driven():
    """The frontend classifier must branch on ``errorCode``, not keyword matching."""
    repo_root = Path(__file__).resolve().parents[3]
    classifier_source = (
        repo_root / "surfsense_web/lib/chat/chat-error-classifier.ts"
    ).read_text(encoding="utf-8")

    for forbidden in (
        "PREMIUM_KEYWORDS",
        "RATE_LIMIT_KEYWORDS",
        "normalized.includes(",
    ):
        assert forbidden not in classifier_source
    assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in classifier_source
def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
    """The new-chat page source must contain the pre-accept rollback hooks."""
    repo_root = Path(__file__).resolve().parents[3]
    page_source = (
        repo_root
        / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    ).read_text(encoding="utf-8")

    for required_snippet in (
        "onPreAcceptFailure?: () => Promise<void>;",
        "if (!accepted) {",
        "await onPreAcceptFailure?.();",
        "await onAcceptedStreamError?.();",
        "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));",
        "setMessageDocumentsMap((prev) => {",
    ):
        assert required_snippet in page_source
def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
    """The user-message component must not render an inline "failed" marker."""
    repo_root = Path(__file__).resolve().parents[3]
    component_source = (
        repo_root / "surfsense_web/components/assistant-ui/user-message.tsx"
    ).read_text(encoding="utf-8")

    for forbidden in ("Not sent. Edit and retry.", "failed_pre_accept"):
        assert forbidden not in component_source
def test_network_send_failures_use_unified_retry_toast_message():
    """Pin the send-failure wording and the pass-through error-code set.

    Two web sources are inspected as plain text:

    * ``chat-error-classifier.ts`` must contain the pre-accept / cancelling
      branches and the two user-facing messages, and must not contain the
      "Failed to start chat" copy.
    * ``chat-request-errors.ts`` must declare ``passthroughCodes`` with each
      listed code, map to ``SEND_FAILED_PRE_ACCEPT``, and must not contain an
      ``errorCode: "NETWORK_ERROR"`` literal.
    """
    base_dir = Path(__file__).resolve().parents[3]
    classifier_source = (
        base_dir / "surfsense_web/lib/chat/chat-error-classifier.ts"
    ).read_text(encoding="utf-8")
    request_errors_source = (
        base_dir / "surfsense_web/lib/chat/chat-request-errors.ts"
    ).read_text(encoding="utf-8")

    classifier_required = (
        '"send_failed_pre_accept"',
        'errorCode === "SEND_FAILED_PRE_ACCEPT"',
        'errorCode === "TURN_CANCELLING"',
        "if (withCode.code) return withCode.code;",
        'userMessage: "Message not sent. Please retry."',
        'userMessage: "Connection issue. Please try again."',
    )
    for snippet in classifier_required:
        assert snippet in classifier_source

    request_errors_required = (
        "const passthroughCodes = new Set([",
        '"PREMIUM_QUOTA_EXHAUSTED"',
        '"THREAD_BUSY"',
        '"TURN_CANCELLING"',
        '"AUTH_EXPIRED"',
        '"UNAUTHORIZED"',
        '"RATE_LIMITED"',
        '"NETWORK_ERROR"',
        '"STREAM_PARSE_ERROR"',
        '"TOOL_EXECUTION_ERROR"',
        '"PERSIST_MESSAGE_FAILED"',
        '"SERVER_ERROR"',
        "passthroughCodes.has(existingCode)",
        'errorCode: "SEND_FAILED_PRE_ACCEPT"',
    )
    for snippet in request_errors_required:
        assert snippet in request_errors_source

    # Negative checks come last, in the same order as before.
    assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
    assert "Failed to start chat. Please try again." not in classifier_source
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
    """The new-chat page must keep the accepted-boundary contract for all flows.

    Source-text check against ``page.tsx`` covering the new, resume and
    regenerate flows plus the turn-cancelling retry wrapper.
    """
    base_dir = Path(__file__).resolve().parents[3]
    page_file = (
        base_dir
        / "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
    )
    source = page_file.read_text(encoding="utf-8")

    required_snippets = (
        # Each flow tracks accepted boundary and passes it into shared terminal handling.
        "let newAccepted = false;",
        "let resumeAccepted = false;",
        "let regenerateAccepted = false;",
        "accepted: newAccepted,",
        "accepted: resumeAccepted,",
        "accepted: regenerateAccepted,",
        # Pre-accept abort in resume/regenerate exits without persistence.
        "if (!resumeAccepted) return;",
        "if (!regenerateAccepted) return;",
        # New flow persists only when accepted and not already persisted.
        "if (newAccepted && !userPersisted) {",
        # Retry wrapper and the error codes it inspects.
        "const fetchWithTurnCancellingRetry = useCallback(",
        "computeFallbackTurnCancellingRetryDelay",
        'withMeta.errorCode === "TURN_CANCELLING"',
        'withMeta.errorCode === "THREAD_BUSY"',
        "await fetchWithTurnCancellingRetry(() =>",
    )
    for snippet in required_snippets:
        assert snippet in source
def test_cancel_active_turn_route_contract_exists():
    """The backend routes module must declare the cancel-active-turn endpoint.

    Source-text check against ``new_chat_routes.py``: the POST decorator, its
    response model, and the ``cancelling`` / ``idle`` response fields must all
    appear verbatim.
    """
    base_dir = Path(__file__).resolve().parents[3]
    source = (
        base_dir / "surfsense_backend/app/routes/new_chat_routes.py"
    ).read_text(encoding="utf-8")

    required_snippets = (
        # NOTE(review): the whitespace after "\n" has to match the decorator's
        # formatting in the routes file exactly — confirm.
        '@router.post(\n "/threads/{thread_id}/cancel-active-turn",',
        "response_model=CancelActiveTurnResponse",
        'status="cancelling",',
        'error_code="TURN_CANCELLING",',
        "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,",
        "retry_after_at=",
        'status="idle",',
        'error_code="NO_ACTIVE_TURN",',
    )
    for snippet in required_snippets:
        assert snippet in source
def test_turn_status_route_contract_exists():
    """The backend routes module must declare the turn-status endpoint.

    Source-text check against ``new_chat_routes.py``: the GET decorator, its
    response model, the payload builder call, the read-permission reference
    and the busy-guard call must all appear verbatim.
    """
    base_dir = Path(__file__).resolve().parents[3]
    source = (
        base_dir / "surfsense_backend/app/routes/new_chat_routes.py"
    ).read_text(encoding="utf-8")

    required_snippets = (
        # NOTE(review): the whitespace after "\n" has to match the decorator's
        # formatting in the routes file exactly — confirm.
        '@router.get(\n "/threads/{thread_id}/turn-status",',
        "response_model=TurnStatusResponse",
        "_build_turn_status_payload(thread_id)",
        "Permission.CHATS_READ.value",
        "_raise_if_thread_busy_for_start(",
    )
    for snippet in required_snippets:
        assert snippet in source
def test_turn_cancelling_retry_policy_contract_exists():
    """The backend routes module must define the turn-cancelling retry policy.

    Source-text check against ``new_chat_routes.py``: the three backoff
    constants with their exact values, the delay helper, the retry headers
    and the ``TURN_CANCELLING`` error-code payload key must all be present.
    """
    base_dir = Path(__file__).resolve().parents[3]
    source = (
        base_dir / "surfsense_backend/app/routes/new_chat_routes.py"
    ).read_text(encoding="utf-8")

    required_snippets = (
        "TURN_CANCELLING_INITIAL_DELAY_MS = 200",
        "TURN_CANCELLING_BACKOFF_FACTOR = 2",
        "TURN_CANCELLING_MAX_DELAY_MS = 1500",
        "def _compute_turn_cancelling_retry_delay(",
        "retry-after-ms",
        '"Retry-After"',
        '"errorCode": "TURN_CANCELLING"',
    )
    for snippet in required_snippets:
        assert snippet in source
def test_turn_status_sse_contract_exists():
    """The turn-status event must be wired through the backend stream and web client.

    Three files are read as text: the backend ``stream_new_chat.py`` and the
    web ``streaming-state.ts`` and ``stream-pipeline.ts``. Each
    (text, snippet) pair below must hold; pairs are checked in order.
    """
    base_dir = Path(__file__).resolve().parents[3]
    stream_source = (
        base_dir / "surfsense_backend/app/tasks/chat/stream_new_chat.py"
    ).read_text(encoding="utf-8")
    state_source = (
        base_dir / "surfsense_web/lib/chat/streaming-state.ts"
    ).read_text(encoding="utf-8")
    pipeline_source = (
        base_dir / "surfsense_web/lib/chat/stream-pipeline.ts"
    ).read_text(encoding="utf-8")

    expectations = (
        (stream_source, '"turn-status"'),
        (stream_source, '"status": "busy"'),
        (stream_source, '"status": "idle"'),
        (state_source, 'type: "data-turn-status"'),
        (pipeline_source, 'case "data-turn-status":'),
        (stream_source, "end_turn(str(chat_id))"),
    )
    for haystack, snippet in expectations:
        assert snippet in haystack
def test_chat_deepagent_forwards_resolved_model_name_to_both_builders():
    """Regression guard for model-name plumbing in ``chat_deepagent.py``.

    Both system-prompt builder call sites must pass
    ``model_name=_resolve_prompt_model_name(...)`` so the provider-variant
    dispatch can render the right ``<provider_hints>`` block. Without it the
    prompt silently falls back to the empty ``"default"`` variant, which is
    the original bug being fixed.

    The check inspects the module's source text with a regex to enforce the
    call-site shape rather than exercising the wrapper layer: the wrappers
    already forward ``model_name`` correctly, so testing them would not catch
    the missed plumbing.
    """
    import app.agents.new_chat.chat_deepagent as chat_deepagent_module

    module_text = inspect.getsource(chat_deepagent_module)

    # The resolver helper has to exist in the module at all.
    assert "def _resolve_prompt_model_name(" in module_text

    # ``[^)]*`` spans newlines and whitespace, so kwargs split over several
    # lines between the opening paren and ``model_name=`` still match.
    call_site_hits = re.findall(
        r"build_(?:surfsense|configurable)_system_prompt\([^)]*"
        r"model_name=_resolve_prompt_model_name\(",
        module_text,
        re.DOTALL,
    )
    assert len(call_site_hits) == 2, (
        "Expected both system-prompt builder call sites to forward "
        "`model_name=_resolve_prompt_model_name(...)`, found "
        f"{len(call_site_hits)}"
    )

View file

@ -62,7 +62,7 @@ wheels = [
[[package]]
name = "aiohttp"
version = "3.13.5"
version = "3.13.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohappyeyeballs" },
@ -73,76 +73,76 @@ dependencies = [
{ name = "propcache" },
{ name = "yarl" },
]
sdist = { url = "https://files.pythonhosted.org/packages/77/9a/152096d4808df8e4268befa55fba462f440f14beab85e8ad9bf990516918/aiohttp-3.13.5.tar.gz", hash = "sha256:9d98cc980ecc96be6eb4c1994ce35d28d8b1f5e5208a23b421187d1209dbb7d1", size = 7858271 }
sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/be/6f/353954c29e7dcce7cf00280a02c75f30e133c00793c7a2ed3776d7b2f426/aiohttp-3.13.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:023ecba036ddd840b0b19bf195bfae970083fd7024ce1ac22e9bba90464620e9", size = 748876 },
{ url = "https://files.pythonhosted.org/packages/f5/1b/428a7c64687b3b2e9cd293186695affc0e1e54a445d0361743b231f11066/aiohttp-3.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15c933ad7920b7d9a20de151efcd05a6e38302cbf0e10c9b2acb9a42210a2416", size = 499557 },
{ url = "https://files.pythonhosted.org/packages/29/47/7be41556bfbb6917069d6a6634bb7dd5e163ba445b783a90d40f5ac7e3a7/aiohttp-3.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ab2899f9fa2f9f741896ebb6fa07c4c883bfa5c7f2ddd8cf2aafa86fa981b2d2", size = 500258 },
{ url = "https://files.pythonhosted.org/packages/67/84/c9ecc5828cb0b3695856c07c0a6817a99d51e2473400f705275a2b3d9239/aiohttp-3.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60eaa2d440cd4707696b52e40ed3e2b0f73f65be07fd0ef23b6b539c9c0b0b4", size = 1749199 },
{ url = "https://files.pythonhosted.org/packages/f0/d3/3c6d610e66b495657622edb6ae7c7fd31b2e9086b4ec50b47897ad6042a9/aiohttp-3.13.5-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55b3bdd3292283295774ab585160c4004f4f2f203946997f49aac032c84649e9", size = 1721013 },
{ url = "https://files.pythonhosted.org/packages/49/a0/24409c12217456df0bae7babe3b014e460b0b38a8e60753d6cb339f6556d/aiohttp-3.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2b2355dc094e5f7d45a7bb262fe7207aa0460b37a0d87027dcf21b5d890e7d5", size = 1781501 },
{ url = "https://files.pythonhosted.org/packages/98/9d/b65ec649adc5bccc008b0957a9a9c691070aeac4e41cea18559fef49958b/aiohttp-3.13.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b38765950832f7d728297689ad78f5f2cf79ff82487131c4d26fe6ceecdc5f8e", size = 1878981 },
{ url = "https://files.pythonhosted.org/packages/57/d8/8d44036d7eb7b6a8ec4c5494ea0c8c8b94fbc0ed3991c1a7adf230df03bf/aiohttp-3.13.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b18f31b80d5a33661e08c89e202edabf1986e9b49c42b4504371daeaa11b47c1", size = 1767934 },
{ url = "https://files.pythonhosted.org/packages/31/04/d3f8211f273356f158e3464e9e45484d3fb8c4ce5eb2f6fe9405c3273983/aiohttp-3.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:33add2463dde55c4f2d9635c6ab33ce154e5ecf322bd26d09af95c5f81cfa286", size = 1566671 },
{ url = "https://files.pythonhosted.org/packages/41/db/073e4ebe00b78e2dfcacff734291651729a62953b48933d765dc513bf798/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:327cc432fdf1356fb4fbc6fe833ad4e9f6aacb71a8acaa5f1855e4b25910e4a9", size = 1705219 },
{ url = "https://files.pythonhosted.org/packages/48/45/7dfba71a2f9fd97b15c95c06819de7eb38113d2cdb6319669195a7d64270/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7c35b0bf0b48a70b4cb4fc5d7bed9b932532728e124874355de1a0af8ec4bc88", size = 1743049 },
{ url = "https://files.pythonhosted.org/packages/18/71/901db0061e0f717d226386a7f471bb59b19566f2cae5f0d93874b017271f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:df23d57718f24badef8656c49743e11a89fd6f5358fa8a7b96e728fda2abf7d3", size = 1749557 },
{ url = "https://files.pythonhosted.org/packages/08/d5/41eebd16066e59cd43728fe74bce953d7402f2b4ddfdfef2c0e9f17ca274/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:02e048037a6501a5ec1f6fc9736135aec6eb8a004ce48838cb951c515f32c80b", size = 1558931 },
{ url = "https://files.pythonhosted.org/packages/30/e6/4a799798bf05740e66c3a1161079bda7a3dd8e22ca392481d7a7f9af82a6/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:31cebae8b26f8a615d2b546fee45d5ffb76852ae6450e2a03f42c9102260d6fe", size = 1774125 },
{ url = "https://files.pythonhosted.org/packages/84/63/7749337c90f92bc2cb18f9560d67aa6258c7060d1397d21529b8004fcf6f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:888e78eb5ca55a615d285c3c09a7a91b42e9dd6fc699b166ebd5dee87c9ccf14", size = 1732427 },
{ url = "https://files.pythonhosted.org/packages/98/de/cf2f44ff98d307e72fb97d5f5bbae3bfcb442f0ea9790c0bf5c5c2331404/aiohttp-3.13.5-cp312-cp312-win32.whl", hash = "sha256:8bd3ec6376e68a41f9f95f5ed170e2fcf22d4eb27a1f8cb361d0508f6e0557f3", size = 433534 },
{ url = "https://files.pythonhosted.org/packages/aa/ca/eadf6f9c8fa5e31d40993e3db153fb5ed0b11008ad5d9de98a95045bed84/aiohttp-3.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:110e448e02c729bcebb18c60b9214a87ba33bac4a9fa5e9a5f139938b56c6cb1", size = 460446 },
{ url = "https://files.pythonhosted.org/packages/78/e9/d76bf503005709e390122d34e15256b88f7008e246c4bdbe915cd4f1adce/aiohttp-3.13.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5029cc80718bbd545123cd8fe5d15025eccaaaace5d0eeec6bd556ad6163d61", size = 742930 },
{ url = "https://files.pythonhosted.org/packages/57/00/4b7b70223deaebd9bb85984d01a764b0d7bd6526fcdc73cca83bcbe7243e/aiohttp-3.13.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4bb6bf5811620003614076bdc807ef3b5e38244f9d25ca5fe888eaccea2a9832", size = 496927 },
{ url = "https://files.pythonhosted.org/packages/9c/f5/0fb20fb49f8efdcdce6cd8127604ad2c503e754a8f139f5e02b01626523f/aiohttp-3.13.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a84792f8631bf5a94e52d9cc881c0b824ab42717165a5579c760b830d9392ac9", size = 497141 },
{ url = "https://files.pythonhosted.org/packages/3b/86/b7c870053e36a94e8951b803cb5b909bfbc9b90ca941527f5fcafbf6b0fa/aiohttp-3.13.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:57653eac22c6a4c13eb22ecf4d673d64a12f266e72785ab1c8b8e5940d0e8090", size = 1732476 },
{ url = "https://files.pythonhosted.org/packages/b5/e5/4e161f84f98d80c03a238671b4136e6530453d65262867d989bbe78244d0/aiohttp-3.13.5-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5e5f7debc7a57af53fdf5c5009f9391d9f4c12867049d509bf7bb164a6e295b", size = 1706507 },
{ url = "https://files.pythonhosted.org/packages/d4/56/ea11a9f01518bd5a2a2fcee869d248c4b8a0cfa0bb13401574fa31adf4d4/aiohttp-3.13.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c719f65bebcdf6716f10e9eff80d27567f7892d8988c06de12bbbd39307c6e3a", size = 1773465 },
{ url = "https://files.pythonhosted.org/packages/eb/40/333ca27fb74b0383f17c90570c748f7582501507307350a79d9f9f3c6eb1/aiohttp-3.13.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d97f93fdae594d886c5a866636397e2bcab146fd7a132fd6bb9ce182224452f8", size = 1873523 },
{ url = "https://files.pythonhosted.org/packages/f0/d2/e2f77eef1acb7111405433c707dc735e63f67a56e176e72e9e7a2cd3f493/aiohttp-3.13.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3df334e39d4c2f899a914f1dba283c1aadc311790733f705182998c6f7cae665", size = 1754113 },
{ url = "https://files.pythonhosted.org/packages/fb/56/3f653d7f53c89669301ec9e42c95233e2a0c0a6dd051269e6e678db4fdb0/aiohttp-3.13.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe6970addfea9e5e081401bcbadf865d2b6da045472f58af08427e108d618540", size = 1562351 },
{ url = "https://files.pythonhosted.org/packages/ec/a6/9b3e91eb8ae791cce4ee736da02211c85c6f835f1bdfac0594a8a3b7018c/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7becdf835feff2f4f335d7477f121af787e3504b48b449ff737afb35869ba7bb", size = 1693205 },
{ url = "https://files.pythonhosted.org/packages/98/fc/bfb437a99a2fcebd6b6eaec609571954de2ed424f01c352f4b5504371dd3/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:676e5651705ad5d8a70aeb8eb6936c436d8ebbd56e63436cb7dd9bb36d2a9a46", size = 1730618 },
{ url = "https://files.pythonhosted.org/packages/e4/b6/c8534862126191a034f68153194c389addc285a0f1347d85096d349bbc15/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:9b16c653d38eb1a611cc898c41e76859ca27f119d25b53c12875fd0474ae31a8", size = 1745185 },
{ url = "https://files.pythonhosted.org/packages/0b/93/4ca8ee2ef5236e2707e0fd5fecb10ce214aee1ff4ab307af9c558bda3b37/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:999802d5fa0389f58decd24b537c54aa63c01c3219ce17d1214cbda3c2b22d2d", size = 1557311 },
{ url = "https://files.pythonhosted.org/packages/57/ae/76177b15f18c5f5d094f19901d284025db28eccc5ae374d1d254181d33f4/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:ec707059ee75732b1ba130ed5f9580fe10ff75180c812bc267ded039db5128c6", size = 1773147 },
{ url = "https://files.pythonhosted.org/packages/01/a4/62f05a0a98d88af59d93b7fcac564e5f18f513cb7471696ac286db970d6a/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d6d44a5b48132053c2f6cd5c8cb14bc67e99a63594e336b0f2af81e94d5530c", size = 1730356 },
{ url = "https://files.pythonhosted.org/packages/e4/85/fc8601f59dfa8c9523808281f2da571f8b4699685f9809a228adcc90838d/aiohttp-3.13.5-cp313-cp313-win32.whl", hash = "sha256:329f292ed14d38a6c4c435e465f48bebb47479fd676a0411936cc371643225cc", size = 432637 },
{ url = "https://files.pythonhosted.org/packages/c0/1b/ac685a8882896acf0f6b31d689e3792199cfe7aba37969fa91da63a7fa27/aiohttp-3.13.5-cp313-cp313-win_amd64.whl", hash = "sha256:69f571de7500e0557801c0b51f4780482c0ec5fe2ac851af5a92cfce1af1cb83", size = 458896 },
{ url = "https://files.pythonhosted.org/packages/5d/ce/46572759afc859e867a5bc8ec3487315869013f59281ce61764f76d879de/aiohttp-3.13.5-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:eb4639f32fd4a9904ab8fb45bf3383ba71137f3d9d4ba25b3b3f3109977c5b8c", size = 745721 },
{ url = "https://files.pythonhosted.org/packages/13/fe/8a2efd7626dbe6049b2ef8ace18ffda8a4dfcbe1bcff3ac30c0c7575c20b/aiohttp-3.13.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:7e5dc4311bd5ac493886c63cbf76ab579dbe4641268e7c74e48e774c74b6f2be", size = 497663 },
{ url = "https://files.pythonhosted.org/packages/9b/91/cc8cc78a111826c54743d88651e1687008133c37e5ee615fee9b57990fac/aiohttp-3.13.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:756c3c304d394977519824449600adaf2be0ccee76d206ee339c5e76b70ded25", size = 499094 },
{ url = "https://files.pythonhosted.org/packages/0a/33/a8362cb15cf16a3af7e86ed11962d5cd7d59b449202dc576cdc731310bde/aiohttp-3.13.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecc26751323224cf8186efcf7fbcbc30f4e1d8c7970659daf25ad995e4032a56", size = 1726701 },
{ url = "https://files.pythonhosted.org/packages/45/0c/c091ac5c3a17114bd76cbf85d674650969ddf93387876cf67f754204bd77/aiohttp-3.13.5-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10a75acfcf794edf9d8db50e5a7ec5fc818b2a8d3f591ce93bc7b1210df016d2", size = 1683360 },
{ url = "https://files.pythonhosted.org/packages/23/73/bcee1c2b79bc275e964d1446c55c54441a461938e70267c86afaae6fba27/aiohttp-3.13.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0f7a18f258d124cd678c5fe072fe4432a4d5232b0657fca7c1847f599233c83a", size = 1773023 },
{ url = "https://files.pythonhosted.org/packages/c7/ef/720e639df03004fee2d869f771799d8c23046dec47d5b81e396c7cda583a/aiohttp-3.13.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:df6104c009713d3a89621096f3e3e88cc323fd269dbd7c20afe18535094320be", size = 1853795 },
{ url = "https://files.pythonhosted.org/packages/bd/c9/989f4034fb46841208de7aeeac2c6d8300745ab4f28c42f629ba77c2d916/aiohttp-3.13.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:241a94f7de7c0c3b616627aaad530fe2cb620084a8b144d3be7b6ecfe95bae3b", size = 1730405 },
{ url = "https://files.pythonhosted.org/packages/ce/75/ee1fd286ca7dc599d824b5651dad7b3be7ff8d9a7e7b3fe9820d9180f7db/aiohttp-3.13.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c974fb66180e58709b6fc402846f13791240d180b74de81d23913abe48e96d94", size = 1558082 },
{ url = "https://files.pythonhosted.org/packages/c3/20/1e9e6650dfc436340116b7aa89ff8cb2bbdf0abc11dfaceaad8f74273a10/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:6e27ea05d184afac78aabbac667450c75e54e35f62238d44463131bd3f96753d", size = 1692346 },
{ url = "https://files.pythonhosted.org/packages/d8/40/8ebc6658d48ea630ac7903912fe0dd4e262f0e16825aa4c833c56c9f1f56/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:a79a6d399cef33a11b6f004c67bb07741d91f2be01b8d712d52c75711b1e07c7", size = 1698891 },
{ url = "https://files.pythonhosted.org/packages/d8/78/ea0ae5ec8ba7a5c10bdd6e318f1ba5e76fcde17db8275188772afc7917a4/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c632ce9c0b534fbe25b52c974515ed674937c5b99f549a92127c85f771a78772", size = 1742113 },
{ url = "https://files.pythonhosted.org/packages/8a/66/9d308ed71e3f2491be1acb8769d96c6f0c47d92099f3bc9119cada27b357/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:fceedde51fbd67ee2bcc8c0b33d0126cc8b51ef3bbde2f86662bd6d5a6f10ec5", size = 1553088 },
{ url = "https://files.pythonhosted.org/packages/da/a6/6cc25ed8dfc6e00c90f5c6d126a98e2cf28957ad06fa1036bd34b6f24a2c/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:f92995dfec9420bb69ae629abf422e516923ba79ba4403bc750d94fb4a6c68c1", size = 1757976 },
{ url = "https://files.pythonhosted.org/packages/c1/2b/cce5b0ffe0de99c83e5e36d8f828e4161e415660a9f3e58339d07cce3006/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20ae0ff08b1f2c8788d6fb85afcb798654ae6ba0b747575f8562de738078457b", size = 1712444 },
{ url = "https://files.pythonhosted.org/packages/6c/cf/9e1795b4160c58d29421eafd1a69c6ce351e2f7c8d3c6b7e4ca44aea1a5b/aiohttp-3.13.5-cp314-cp314-win32.whl", hash = "sha256:b20df693de16f42b2472a9c485e1c948ee55524786a0a34345511afdd22246f3", size = 438128 },
{ url = "https://files.pythonhosted.org/packages/22/4d/eaedff67fc805aeba4ba746aec891b4b24cebb1a7d078084b6300f79d063/aiohttp-3.13.5-cp314-cp314-win_amd64.whl", hash = "sha256:f85c6f327bf0b8c29da7d93b1cabb6363fb5e4e160a32fa241ed2dce21b73162", size = 464029 },
{ url = "https://files.pythonhosted.org/packages/79/11/c27d9332ee20d68dd164dc12a6ecdef2e2e35ecc97ed6cf0d2442844624b/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:1efb06900858bb618ff5cee184ae2de5828896c448403d51fb633f09e109be0a", size = 778758 },
{ url = "https://files.pythonhosted.org/packages/04/fb/377aead2e0a3ba5f09b7624f702a964bdf4f08b5b6728a9799830c80041e/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:fee86b7c4bd29bdaf0d53d14739b08a106fdda809ca5fe032a15f52fae5fe254", size = 512883 },
{ url = "https://files.pythonhosted.org/packages/bb/a6/aa109a33671f7a5d3bd78b46da9d852797c5e665bfda7d6b373f56bff2ec/aiohttp-3.13.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:20058e23909b9e65f9da62b396b77dfa95965cbe840f8def6e572538b1d32e36", size = 516668 },
{ url = "https://files.pythonhosted.org/packages/79/b3/ca078f9f2fa9563c36fb8ef89053ea2bb146d6f792c5104574d49d8acb63/aiohttp-3.13.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cf20a8d6868cb15a73cab329ffc07291ba8c22b1b88176026106ae39aa6df0f", size = 1883461 },
{ url = "https://files.pythonhosted.org/packages/b7/e3/a7ad633ca1ca497b852233a3cce6906a56c3225fb6d9217b5e5e60b7419d/aiohttp-3.13.5-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:330f5da04c987f1d5bdb8ae189137c77139f36bd1cb23779ca1a354a4b027800", size = 1747661 },
{ url = "https://files.pythonhosted.org/packages/33/b9/cd6fe579bed34a906d3d783fe60f2fa297ef55b27bb4538438ee49d4dc41/aiohttp-3.13.5-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f1cbf0c7926d315c3c26c2da41fd2b5d2fe01ac0e157b78caefc51a782196cf", size = 1863800 },
{ url = "https://files.pythonhosted.org/packages/c0/3f/2c1e2f5144cefa889c8afd5cf431994c32f3b29da9961698ff4e3811b79a/aiohttp-3.13.5-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:53fc049ed6390d05423ba33103ded7281fe897cf97878f369a527070bd95795b", size = 1958382 },
{ url = "https://files.pythonhosted.org/packages/66/1d/f31ec3f1013723b3babe3609e7f119c2c2fb6ef33da90061a705ef3e1bc8/aiohttp-3.13.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:898703aa2667e3c5ca4c54ca36cd73f58b7a38ef87a5606414799ebce4d3fd3a", size = 1803724 },
{ url = "https://files.pythonhosted.org/packages/0e/b4/57712dfc6f1542f067daa81eb61da282fab3e6f1966fca25db06c4fc62d5/aiohttp-3.13.5-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0494a01ca9584eea1e5fbd6d748e61ecff218c51b576ee1999c23db7066417d8", size = 1640027 },
{ url = "https://files.pythonhosted.org/packages/25/3c/734c878fb43ec083d8e31bf029daae1beafeae582d1b35da234739e82ee7/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6cf81fe010b8c17b09495cbd15c1d35afbc8fb405c0c9cf4738e5ae3af1d65be", size = 1806644 },
{ url = "https://files.pythonhosted.org/packages/20/a5/f671e5cbec1c21d044ff3078223f949748f3a7f86b14e34a365d74a5d21f/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:c564dd5f09ddc9d8f2c2d0a301cd30a79a2cc1b46dd1a73bef8f0038863d016b", size = 1791630 },
{ url = "https://files.pythonhosted.org/packages/0b/63/fb8d0ad63a0b8a99be97deac8c04dacf0785721c158bdf23d679a87aa99e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:2994be9f6e51046c4f864598fd9abeb4fba6e88f0b2152422c9666dcd4aea9c6", size = 1809403 },
{ url = "https://files.pythonhosted.org/packages/59/0c/bfed7f30662fcf12206481c2aac57dedee43fe1c49275e85b3a1e1742294/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:157826e2fa245d2ef46c83ea8a5faf77ca19355d278d425c29fda0beb3318037", size = 1634924 },
{ url = "https://files.pythonhosted.org/packages/17/d6/fd518d668a09fd5a3319ae5e984d4d80b9a4b3df4e21c52f02251ef5a32e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:a8aca50daa9493e9e13c0f566201a9006f080e7c50e5e90d0b06f53146a54500", size = 1836119 },
{ url = "https://files.pythonhosted.org/packages/78/b7/15fb7a9d52e112a25b621c67b69c167805cb1f2ab8f1708a5c490d1b52fe/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3b13560160d07e047a93f23aaa30718606493036253d5430887514715b67c9d9", size = 1772072 },
{ url = "https://files.pythonhosted.org/packages/7e/df/57ba7f0c4a553fc2bd8b6321df236870ec6fd64a2a473a8a13d4f733214e/aiohttp-3.13.5-cp314-cp314t-win32.whl", hash = "sha256:9a0f4474b6ea6818b41f82172d799e4b3d29e22c2c520ce4357856fced9af2f8", size = 471819 },
{ url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441 },
{ url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158 },
{ url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037 },
{ url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556 },
{ url = "https://files.pythonhosted.org/packages/d6/10/88ff67cd48a6ec36335b63a640abe86135791544863e0cfe1f065d6cef7a/aiohttp-3.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d738ebab9f71ee652d9dbd0211057690022201b11197f9a7324fd4dba128aa97", size = 1757314 },
{ url = "https://files.pythonhosted.org/packages/8b/15/fdb90a5cf5a1f52845c276e76298c75fbbcc0ac2b4a86551906d54529965/aiohttp-3.13.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0ce692c3468fa831af7dceed52edf51ac348cebfc8d3feb935927b63bd3e8576", size = 1731819 },
{ url = "https://files.pythonhosted.org/packages/ec/df/28146785a007f7820416be05d4f28cc207493efd1e8c6c1068e9bdc29198/aiohttp-3.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8e08abcfe752a454d2cb89ff0c08f2d1ecd057ae3e8cc6d84638de853530ebab", size = 1793279 },
{ url = "https://files.pythonhosted.org/packages/10/47/689c743abf62ea7a77774d5722f220e2c912a77d65d368b884d9779ef41b/aiohttp-3.13.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5977f701b3fff36367a11087f30ea73c212e686d41cd363c50c022d48b011d8d", size = 1891082 },
{ url = "https://files.pythonhosted.org/packages/b0/b6/f7f4f318c7e58c23b761c9b13b9a3c9b394e0f9d5d76fbc6622fa98509f6/aiohttp-3.13.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54203e10405c06f8b6020bd1e076ae0fe6c194adcee12a5a78af3ffa3c57025e", size = 1773938 },
{ url = "https://files.pythonhosted.org/packages/aa/06/f207cb3121852c989586a6fc16ff854c4fcc8651b86c5d3bd1fc83057650/aiohttp-3.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:358a6af0145bc4dda037f13167bef3cce54b132087acc4c295c739d05d16b1c3", size = 1579548 },
{ url = "https://files.pythonhosted.org/packages/6c/58/e1289661a32161e24c1fe479711d783067210d266842523752869cc1d9c2/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:898ea1850656d7d61832ef06aa9846ab3ddb1621b74f46de78fbc5e1a586ba83", size = 1714669 },
{ url = "https://files.pythonhosted.org/packages/96/0a/3e86d039438a74a86e6a948a9119b22540bae037d6ba317a042ae3c22711/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7bc30cceb710cf6a44e9617e43eebb6e3e43ad855a34da7b4b6a73537d8a6763", size = 1754175 },
{ url = "https://files.pythonhosted.org/packages/f4/30/e717fc5df83133ba467a560b6d8ef20197037b4bb5d7075b90037de1018e/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4a31c0c587a8a038f19a4c7e60654a6c899c9de9174593a13e7cc6e15ff271f9", size = 1762049 },
{ url = "https://files.pythonhosted.org/packages/e4/28/8f7a2d4492e336e40005151bdd94baf344880a4707573378579f833a64c1/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:2062f675f3fe6e06d6113eb74a157fb9df58953ffed0cdb4182554b116545758", size = 1570861 },
{ url = "https://files.pythonhosted.org/packages/78/45/12e1a3d0645968b1c38de4b23fdf270b8637735ea057d4f84482ff918ad9/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d1ba8afb847ff80626d5e408c1fdc99f942acc877d0702fe137015903a220a9", size = 1790003 },
{ url = "https://files.pythonhosted.org/packages/eb/0f/60374e18d590de16dcb39d6ff62f39c096c1b958e6f37727b5870026ea30/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b08149419994cdd4d5eecf7fd4bc5986b5a9380285bcd01ab4c0d6bfca47b79d", size = 1737289 },
{ url = "https://files.pythonhosted.org/packages/02/bf/535e58d886cfbc40a8b0013c974afad24ef7632d645bca0b678b70033a60/aiohttp-3.13.4-cp312-cp312-win32.whl", hash = "sha256:fc432f6a2c4f720180959bc19aa37259651c1a4ed8af8afc84dd41c60f15f791", size = 434185 },
{ url = "https://files.pythonhosted.org/packages/1e/1a/d92e3325134ebfff6f4069f270d3aac770d63320bd1fcd0eca023e74d9a8/aiohttp-3.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:6148c9ae97a3e8bff9a1fc9c757fa164116f86c100468339730e717590a3fb77", size = 461285 },
{ url = "https://files.pythonhosted.org/packages/e3/ac/892f4162df9b115b4758d615f32ec63d00f3084c705ff5526630887b9b42/aiohttp-3.13.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:63dd5e5b1e43b8fb1e91b79b7ceba1feba588b317d1edff385084fcc7a0a4538", size = 745744 },
{ url = "https://files.pythonhosted.org/packages/97/a9/c5b87e4443a2f0ea88cb3000c93a8fdad1ee63bffc9ded8d8c8e0d66efc6/aiohttp-3.13.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:746ac3cc00b5baea424dacddea3ec2c2702f9590de27d837aa67004db1eebc6e", size = 498178 },
{ url = "https://files.pythonhosted.org/packages/94/42/07e1b543a61250783650df13da8ddcdc0d0a5538b2bd15cef6e042aefc61/aiohttp-3.13.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bda8f16ea99d6a6705e5946732e48487a448be874e54a4f73d514660ff7c05d3", size = 498331 },
{ url = "https://files.pythonhosted.org/packages/20/d6/492f46bf0328534124772d0cf58570acae5b286ea25006900650f69dae0e/aiohttp-3.13.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b061e7b5f840391e3f64d0ddf672973e45c4cfff7a0feea425ea24e51530fc2", size = 1744414 },
{ url = "https://files.pythonhosted.org/packages/e2/4d/e02627b2683f68051246215d2d62b2d2f249ff7a285e7a858dc47d6b6a14/aiohttp-3.13.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b252e8d5cd66184b570d0d010de742736e8a4fab22c58299772b0c5a466d4b21", size = 1719226 },
{ url = "https://files.pythonhosted.org/packages/7b/6c/5d0a3394dd2b9f9aeba6e1b6065d0439e4b75d41f1fb09a3ec010b43552b/aiohttp-3.13.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20af8aad61d1803ff11152a26146d8d81c266aa8c5aa9b4504432abb965c36a0", size = 1782110 },
{ url = "https://files.pythonhosted.org/packages/0d/2d/c20791e3437700a7441a7edfb59731150322424f5aadf635602d1d326101/aiohttp-3.13.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:13a5cc924b59859ad2adb1478e31f410a7ed46e92a2a619d6d1dd1a63c1a855e", size = 1884809 },
{ url = "https://files.pythonhosted.org/packages/c8/94/d99dbfbd1924a87ef643833932eb2a3d9e5eee87656efea7d78058539eff/aiohttp-3.13.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:534913dfb0a644d537aebb4123e7d466d94e3be5549205e6a31f72368980a81a", size = 1764938 },
{ url = "https://files.pythonhosted.org/packages/49/61/3ce326a1538781deb89f6cf5e094e2029cd308ed1e21b2ba2278b08426f6/aiohttp-3.13.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:320e40192a2dcc1cf4b5576936e9652981ab596bf81eb309535db7e2f5b5672f", size = 1570697 },
{ url = "https://files.pythonhosted.org/packages/b6/77/4ab5a546857bb3028fbaf34d6eea180267bdab022ee8b1168b1fcde4bfdd/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9e587fcfce2bcf06526a43cb705bdee21ac089096f2e271d75de9c339db3100c", size = 1702258 },
{ url = "https://files.pythonhosted.org/packages/79/63/d8f29021e39bc5af8e5d5e9da1b07976fb9846487a784e11e4f4eeda4666/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:9eb9c2eea7278206b5c6c1441fdd9dc420c278ead3f3b2cc87f9b693698cc500", size = 1740287 },
{ url = "https://files.pythonhosted.org/packages/55/3a/cbc6b3b124859a11bc8055d3682c26999b393531ef926754a3445b99dfef/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:29be00c51972b04bf9d5c8f2d7f7314f48f96070ca40a873a53056e652e805f7", size = 1753011 },
{ url = "https://files.pythonhosted.org/packages/e0/30/836278675205d58c1368b21520eab9572457cf19afd23759216c04483048/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:90c06228a6c3a7c9f776fe4fc0b7ff647fffd3bed93779a6913c804ae00c1073", size = 1566359 },
{ url = "https://files.pythonhosted.org/packages/50/b4/8032cc9b82d17e4277704ba30509eaccb39329dc18d6a35f05e424439e32/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:a533ec132f05fd9a1d959e7f34184cd7d5e8511584848dab85faefbaac573069", size = 1785537 },
{ url = "https://files.pythonhosted.org/packages/17/7d/5873e98230bde59f493bf1f7c3e327486a4b5653fa401144704df5d00211/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1c946f10f413836f82ea4cfb90200d2a59578c549f00857e03111cf45ad01ca5", size = 1740752 },
{ url = "https://files.pythonhosted.org/packages/7b/f2/13e46e0df051494d7d3c68b7f72d071f48c384c12716fc294f75d5b1a064/aiohttp-3.13.4-cp313-cp313-win32.whl", hash = "sha256:48708e2706106da6967eff5908c78ca3943f005ed6bcb75da2a7e4da94ef8c70", size = 433187 },
{ url = "https://files.pythonhosted.org/packages/ea/c0/649856ee655a843c8f8664592cfccb73ac80ede6a8c8db33a25d810c12db/aiohttp-3.13.4-cp313-cp313-win_amd64.whl", hash = "sha256:74a2eb058da44fa3a877a49e2095b591d4913308bb424c418b77beb160c55ce3", size = 459778 },
{ url = "https://files.pythonhosted.org/packages/6d/29/6657cc37ae04cacc2dbf53fb730a06b6091cc4cbe745028e047c53e6d840/aiohttp-3.13.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:e0a2c961fc92abeff61d6444f2ce6ad35bb982db9fc8ff8a47455beacf454a57", size = 749363 },
{ url = "https://files.pythonhosted.org/packages/90/7f/30ccdf67ca3d24b610067dc63d64dcb91e5d88e27667811640644aa4a85d/aiohttp-3.13.4-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:153274535985a0ff2bff1fb6c104ed547cec898a09213d21b0f791a44b14d933", size = 499317 },
{ url = "https://files.pythonhosted.org/packages/93/13/e372dd4e68ad04ee25dafb050c7f98b0d91ea643f7352757e87231102555/aiohttp-3.13.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:351f3171e2458da3d731ce83f9e6b9619e325c45cbd534c7759750cabf453ad7", size = 500477 },
{ url = "https://files.pythonhosted.org/packages/e5/fe/ee6298e8e586096fb6f5eddd31393d8544f33ae0792c71ecbb4c2bef98ac/aiohttp-3.13.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f989ac8bc5595ff761a5ccd32bdb0768a117f36dd1504b1c2c074ed5d3f4df9c", size = 1737227 },
{ url = "https://files.pythonhosted.org/packages/b0/b9/a7a0463a09e1a3fe35100f74324f23644bfc3383ac5fd5effe0722a5f0b7/aiohttp-3.13.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d36fc1709110ec1e87a229b201dd3ddc32aa01e98e7868083a794609b081c349", size = 1694036 },
{ url = "https://files.pythonhosted.org/packages/57/7c/8972ae3fb7be00a91aee6b644b2a6a909aedb2c425269a3bfd90115e6f8f/aiohttp-3.13.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:42adaeea83cbdf069ab94f5103ce0787c21fb1a0153270da76b59d5578302329", size = 1786814 },
{ url = "https://files.pythonhosted.org/packages/93/01/c81e97e85c774decbaf0d577de7d848934e8166a3a14ad9f8aa5be329d28/aiohttp-3.13.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:92deb95469928cc41fd4b42a95d8012fa6df93f6b1c0a83af0ffbc4a5e218cde", size = 1866676 },
{ url = "https://files.pythonhosted.org/packages/5a/5f/5b46fe8694a639ddea2cd035bf5729e4677ea882cb251396637e2ef1590d/aiohttp-3.13.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0c0c7c07c4257ef3a1df355f840bc62d133bcdef5c1c5ba75add3c08553e2eed", size = 1740842 },
{ url = "https://files.pythonhosted.org/packages/20/a2/0d4b03d011cca6b6b0acba8433193c1e484efa8d705ea58295590fe24203/aiohttp-3.13.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f062c45de8a1098cb137a1898819796a2491aec4e637a06b03f149315dff4d8f", size = 1566508 },
{ url = "https://files.pythonhosted.org/packages/98/17/e689fd500da52488ec5f889effd6404dece6a59de301e380f3c64f167beb/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:76093107c531517001114f0ebdb4f46858ce818590363e3e99a4a2280334454a", size = 1700569 },
{ url = "https://files.pythonhosted.org/packages/d8/0d/66402894dbcf470ef7db99449e436105ea862c24f7ea4c95c683e635af35/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:6f6ec32162d293b82f8b63a16edc80769662fbd5ae6fbd4936d3206a2c2cc63b", size = 1707407 },
{ url = "https://files.pythonhosted.org/packages/2f/eb/af0ab1a3650092cbd8e14ef29e4ab0209e1460e1c299996c3f8288b3f1ff/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:5903e2db3d202a00ad9f0ec35a122c005e85d90c9836ab4cda628f01edf425e2", size = 1752214 },
{ url = "https://files.pythonhosted.org/packages/5a/bf/72326f8a98e4c666f292f03c385545963cc65e358835d2a7375037a97b57/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2d5bea57be7aca98dbbac8da046d99b5557c5cf4e28538c4c786313078aca09e", size = 1562162 },
{ url = "https://files.pythonhosted.org/packages/67/9f/13b72435f99151dd9a5469c96b3b5f86aa29b7e785ca7f35cf5e538f74c0/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:bcf0c9902085976edc0232b75006ef38f89686901249ce14226b6877f88464fb", size = 1768904 },
{ url = "https://files.pythonhosted.org/packages/18/bc/28d4970e7d5452ac7776cdb5431a1164a0d9cf8bd2fffd67b4fb463aa56d/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3295f98bfeed2e867cab588f2a146a9db37a85e3ae9062abf46ba062bd29165", size = 1723378 },
{ url = "https://files.pythonhosted.org/packages/53/74/b32458ca1a7f34d65bdee7aef2036adbe0438123d3d53e2b083c453c24dd/aiohttp-3.13.4-cp314-cp314-win32.whl", hash = "sha256:a598a5c5767e1369d8f5b08695cab1d8160040f796c4416af76fd773d229b3c9", size = 438711 },
{ url = "https://files.pythonhosted.org/packages/40/b2/54b487316c2df3e03a8f3435e9636f8a81a42a69d942164830d193beb56a/aiohttp-3.13.4-cp314-cp314-win_amd64.whl", hash = "sha256:c555db4bc7a264bead5a7d63d92d41a1122fcd39cc62a4db815f45ad46f9c2c8", size = 464977 },
{ url = "https://files.pythonhosted.org/packages/47/fb/e41b63c6ce71b07a59243bb8f3b457ee0c3402a619acb9d2c0d21ef0e647/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:45abbbf09a129825d13c18c7d3182fecd46d9da3cfc383756145394013604ac1", size = 781549 },
{ url = "https://files.pythonhosted.org/packages/97/53/532b8d28df1e17e44c4d9a9368b78dcb6bf0b51037522136eced13afa9e8/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:74c80b2bc2c2adb7b3d1941b2b60701ee2af8296fc8aad8b8bc48bc25767266c", size = 514383 },
{ url = "https://files.pythonhosted.org/packages/1b/1f/62e5d400603e8468cd635812d99cb81cfdc08127a3dc474c647615f31339/aiohttp-3.13.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c97989ae40a9746650fa196894f317dafc12227c808c774929dda0ff873a5954", size = 518304 },
{ url = "https://files.pythonhosted.org/packages/90/57/2326b37b10896447e3c6e0cbef4fe2486d30913639a5cfd1332b5d870f82/aiohttp-3.13.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dae86be9811493f9990ef44fff1685f5c1a3192e9061a71a109d527944eed551", size = 1893433 },
{ url = "https://files.pythonhosted.org/packages/d2/b4/a24d82112c304afdb650167ef2fe190957d81cbddac7460bedd245f765aa/aiohttp-3.13.4-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1db491abe852ca2fa6cc48a3341985b0174b3741838e1341b82ac82c8bd9e871", size = 1755901 },
{ url = "https://files.pythonhosted.org/packages/9e/2d/0883ef9d878d7846287f036c162a951968f22aabeef3ac97b0bea6f76d5d/aiohttp-3.13.4-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0e5d701c0aad02a7dce72eef6b93226cf3734330f1a31d69ebbf69f33b86666e", size = 1876093 },
{ url = "https://files.pythonhosted.org/packages/ad/52/9204bb59c014869b71971addad6778f005daa72a96eed652c496789d7468/aiohttp-3.13.4-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8ac32a189081ae0a10ba18993f10f338ec94341f0d5df8fff348043962f3c6f8", size = 1970815 },
{ url = "https://files.pythonhosted.org/packages/d6/b5/e4eb20275a866dde0f570f411b36c6b48f7b53edfe4f4071aa1b0728098a/aiohttp-3.13.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:98e968cdaba43e45c73c3f306fca418c8009a957733bac85937c9f9cf3f4de27", size = 1816223 },
{ url = "https://files.pythonhosted.org/packages/d8/23/e98075c5bb146aa61a1239ee1ac7714c85e814838d6cebbe37d3fe19214a/aiohttp-3.13.4-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca114790c9144c335d538852612d3e43ea0f075288f4849cf4b05d6cd2238ce7", size = 1649145 },
{ url = "https://files.pythonhosted.org/packages/d6/c1/7bad8be33bb06c2bb224b6468874346026092762cbec388c3bdb65a368ee/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ea2e071661ba9cfe11eabbc81ac5376eaeb3061f6e72ec4cc86d7cdd1ffbdbbb", size = 1816562 },
{ url = "https://files.pythonhosted.org/packages/5c/10/c00323348695e9a5e316825969c88463dcc24c7e9d443244b8a2c9cf2eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:34e89912b6c20e0fd80e07fa401fd218a410aa1ce9f1c2f1dad6db1bd0ce0927", size = 1800333 },
{ url = "https://files.pythonhosted.org/packages/84/43/9b2147a1df3559f49bd723e22905b46a46c068a53adb54abdca32c4de180/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:0e217cf9f6a42908c52b46e42c568bd57adc39c9286ced31aaace614b6087965", size = 1820617 },
{ url = "https://files.pythonhosted.org/packages/a9/7f/b3481a81e7a586d02e99387b18c6dafff41285f6efd3daa2124c01f87eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:0c296f1221e21ba979f5ac1964c3b78cfde15c5c5f855ffd2caab337e9cd9182", size = 1643417 },
{ url = "https://files.pythonhosted.org/packages/8f/72/07181226bc99ce1124e0f89280f5221a82d3ae6a6d9d1973ce429d48e52b/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:d99a9d168ebaffb74f36d011750e490085ac418f4db926cce3989c8fe6cb6b1b", size = 1849286 },
{ url = "https://files.pythonhosted.org/packages/1a/e6/1b3566e103eca6da5be4ae6713e112a053725c584e96574caf117568ffef/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cb19177205d93b881f3f89e6081593676043a6828f59c78c17a0fd6c1fbed2ba", size = 1782635 },
{ url = "https://files.pythonhosted.org/packages/37/58/1b11c71904b8d079eb0c39fe664180dd1e14bebe5608e235d8bfbadc8929/aiohttp-3.13.4-cp314-cp314t-win32.whl", hash = "sha256:c606aa5656dab6552e52ca368e43869c916338346bfaf6304e15c58fb113ea30", size = 472537 },
{ url = "https://files.pythonhosted.org/packages/bc/8f/87c56a1a1977d7dddea5b31e12189665a140fdb48a71e9038ff90bb564ec/aiohttp-3.13.4-cp314-cp314t-win_amd64.whl", hash = "sha256:014dcc10ec8ab8db681f0d68e939d1e9286a5aa2b993cbbdb0db130853e02144", size = 506381 },
]
[[package]]
@ -3723,7 +3723,7 @@ wheels = [
[[package]]
name = "litellm"
version = "1.83.4"
version = "1.83.14"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
@ -3739,9 +3739,9 @@ dependencies = [
{ name = "tiktoken" },
{ name = "tokenizers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/03/c4/30469c06ae7437a4406bc11e3c433cfd380a6771068cca15ea918dcd158f/litellm-1.83.4.tar.gz", hash = "sha256:6458d2030a41229460b321adee00517a91dbd8e63213cc953d355cb41d16f2d4", size = 17733899 }
sdist = { url = "https://files.pythonhosted.org/packages/8d/7c/c095649380adc96c8630273c1768c2ad1e74aa2ee1dd8dd05d218a60569f/litellm-1.83.14.tar.gz", hash = "sha256:24aef9b47cdc424c833e32f3727f411741c690832cd1fe4405e0077144fe09c9", size = 14836599 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b8/bd/df19d3f8f6654535ee343a341fd921f81c411abf601a53e3eaef58129b02/litellm-1.83.4-py3-none-any.whl", hash = "sha256:17d7b4d48d47aca988ea4f762ddda5e7bd72cda3270192b22813d0330869d7b4", size = 16015555 },
{ url = "https://files.pythonhosted.org/packages/7f/5c/1b5691575420135e90578543b2bf219497caa33cfd0af64cb38f30288450/litellm-1.83.14-py3-none-any.whl", hash = "sha256:92b11ba2a32cf80707ddf388d18526696c7999a21b418c5e3b6eda1243d2cfdb", size = 16457054 },
]
[[package]]
@ -5124,7 +5124,7 @@ wheels = [
[[package]]
name = "openai"
version = "2.30.0"
version = "2.24.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@ -5136,9 +5136,9 @@ dependencies = [
{ name = "tqdm" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/88/15/52580c8fbc16d0675d516e8749806eda679b16de1e4434ea06fb6feaa610/openai-2.30.0.tar.gz", hash = "sha256:92f7661c990bda4b22a941806c83eabe4896c3094465030dd882a71abe80c885", size = 676084 }
sdist = { url = "https://files.pythonhosted.org/packages/55/13/17e87641b89b74552ed408a92b231283786523edddc95f3545809fab673c/openai-2.24.0.tar.gz", hash = "sha256:1e5769f540dbd01cb33bc4716a23e67b9d695161a734aff9c5f925e2bf99a673", size = 658717 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/9e/5bfa2270f902d5b92ab7d41ce0475b8630572e71e349b2a4996d14bdda93/openai-2.30.0-py3-none-any.whl", hash = "sha256:9a5ae616888eb2748ec5e0c5b955a51592e0b201a11f4262db920f2a78c5231d", size = 1146656 },
{ url = "https://files.pythonhosted.org/packages/c9/30/844dc675ee6902579b8eef01ed23917cc9319a1c9c0c14ec6e39340c96d0/openai-2.24.0-py3-none-any.whl", hash = "sha256:fed30480d7d6c884303287bde864980a4b137b60553ffbcf9ab4a233b7a73d94", size = 1120122 },
]
[[package]]
@ -6780,11 +6780,11 @@ wheels = [
[[package]]
name = "python-dotenv"
version = "1.0.1"
version = "1.2.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 }
sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 },
{ url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101 },
]
[[package]]
@ -8070,7 +8070,7 @@ requires-dist = [
{ name = "langgraph", specifier = ">=1.1.3" },
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
{ name = "linkup-sdk", specifier = ">=0.2.4" },
{ name = "litellm", specifier = ">=1.83.4" },
{ name = "litellm", specifier = ">=1.83.7" },
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
{ name = "markdown", specifier = ">=3.7" },
{ name = "markdownify", specifier = ">=0.14.1" },

View file

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<!-- Required for Electron's V8 JIT under hardened runtime -->
<key>com.apple.security.cs.allow-jit</key>
<true/>
<key>com.apple.security.cs.allow-unsigned-executable-memory</key>
<true/>
<!-- node-mac-permissions and other native deps load dylibs at runtime -->
<key>com.apple.security.cs.allow-dyld-environment-variables</key>
<true/>
<key>com.apple.security.cs.disable-library-validation</key>
<true/>
<!-- Networking (OAuth, API calls, auto-updater, deep links) -->
<key>com.apple.security.network.client</key>
<true/>
<key>com.apple.security.network.server</key>
<true/>
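<!-- Camera device access under hardened runtime. NOTE(review): screen recording for Screenshot Assist is granted through the system privacy prompt (NSScreenCaptureUsageDescription), not through an entitlement; confirm this key is actually required -->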
<key>com.apple.security.device.camera</key>
<true/>
<!-- Accessibility / Apple Events used by general-assist -->
<key>com.apple.security.automation.apple-events</key>
<true/>
<!-- File access for folder watcher / agent filesystem features -->
<key>com.apple.security.files.user-selected.read-write</key>
<true/>
</dict>
</plist>

View file

@ -46,8 +46,11 @@ mac:
icon: assets/icon.icns
category: public.app-category.productivity
artifactName: "${productName}-${version}-${arch}.${ext}"
hardenedRuntime: false
hardenedRuntime: true
gatekeeperAssess: false
entitlements: build/entitlements.mac.plist
entitlementsInherit: build/entitlements.mac.plist
notarize: true
extendInfo:
NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists."
NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer."

View file

@ -1,11 +1,8 @@
"use client";
import { useQueryClient } from "@tanstack/react-query";
import { CheckCircle2 } from "lucide-react";
import Link from "next/link";
import { useParams } from "next/navigation";
import { useEffect } from "react";
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
import { Button } from "@/components/ui/button";
import {
Card,
@ -18,14 +15,8 @@ import {
export default function PurchaseSuccessPage() {
const params = useParams();
const queryClient = useQueryClient();
const searchSpaceId = String(params.search_space_id ?? "");
useEffect(() => {
void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY });
void queryClient.invalidateQueries({ queryKey: ["token-status"] });
}, [queryClient]);
return (
<div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8">
<Card className="w-full max-w-lg">

View file

@ -132,8 +132,8 @@ export default function DesktopPermissionsPage() {
<div className="space-y-1">
<h1 className="text-2xl font-semibold tracking-tight">System Permissions</h1>
<p className="text-sm text-muted-foreground">
SurfSense needs two macOS permissions for Screenshot Assist and for desktop features that
require focusing the app or the active application.
SurfSense needs two macOS permissions for Screenshot Assist and for desktop features
that require focusing the app or the active application.
</p>
</div>
</div>

View file

@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat
export const resetCurrentThreadAtom = atom(null, (_, set) => {
set(currentThreadAtom, initialState);
set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null });
set(reportPanelAtom, {
isOpen: false,
reportId: null,
title: null,
wordCount: null,
shareToken: null,
contentType: "markdown",
});
});
/** Target comment ID to scroll to (from URL navigation or inbox click) */

View file

@ -0,0 +1,45 @@
import { atom } from "jotai";
/** Alert payload kept per chat thread in `premiumAlertByThreadAtom`. */
export type PremiumAlertState = {
/** Text recorded for the thread by `setPremiumAlertForThreadAtom`. */
message: string;
};
/** Map from thread id to that thread's premium alert; starts empty. */
export const premiumAlertByThreadAtom = atom<Record<number, PremiumAlertState>>({});
/**
 * Write-only atom that records a premium alert for a thread, at most once
 * per user per browser.
 *
 * The "already seen" marker is kept in `localStorage` under a key derived
 * from `payload.userId` (falling back to `"anonymous"`). When the marker is
 * present the write is a no-op; otherwise the alert is stored in
 * `premiumAlertByThreadAtom` under `payload.threadId` and the marker is set.
 *
 * Storage access is best-effort: reading or writing `localStorage` can throw
 * (for example `SecurityError` when storage is blocked, or
 * `QuotaExceededError` on write). Such failures are swallowed so the alert is
 * still recorded instead of the whole atom write throwing.
 *
 * @param payload.threadId - Thread the alert belongs to.
 * @param payload.message - Text stored for the alert.
 * @param payload.userId - Optional user id used to scope the "seen" marker.
 */
export const setPremiumAlertForThreadAtom = atom(
	null,
	(
		get,
		set,
		payload: {
			threadId: number;
			message: string;
			userId?: string | null;
		}
	) => {
		const storageKey = `surfsense-premium-alert-seen-v1:${payload.userId ?? "anonymous"}`;
		// No `window` means a non-browser render; skip the storage bookkeeping entirely.
		const canUseStorage = typeof window !== "undefined";

		if (canUseStorage) {
			let hasSeen = false;
			try {
				hasSeen = localStorage.getItem(storageKey) === "true";
			} catch {
				// Storage unavailable: treat as "not seen" so the alert is still shown.
			}
			if (hasSeen) return;
		}

		set(premiumAlertByThreadAtom, {
			...get(premiumAlertByThreadAtom),
			[payload.threadId]: { message: payload.message },
		});

		if (canUseStorage) {
			try {
				localStorage.setItem(storageKey, "true");
			} catch {
				// Could not persist the marker; the alert may reappear in a later session.
			}
		}
	}
);
/**
 * Write-only atom that drops the stored premium alert for `threadId`.
 * Leaves the map untouched (no `set` call) when the thread has no entry.
 */
export const clearPremiumAlertForThreadAtom = atom(null, (get, set, threadId: number) => {
	const alertsByThread = get(premiumAlertByThreadAtom);
	if (threadId in alertsByThread) {
		// Copy every entry except the one being cleared.
		const { [threadId]: _dismissed, ...remaining } = alertsByThread;
		set(premiumAlertByThreadAtom, remaining);
	}
});

View file

@ -8,7 +8,10 @@ const userQueryFn = () => userApiService.getMe();
export const currentUserAtom = atomWithQuery(() => {
return {
queryKey: USER_QUERY_KEY,
staleTime: 5 * 60 * 1000,
// Live-changing numeric fields (pages_*, premium_tokens_*) are now
// pushed via Zero (queries.user.me()), so /users/me only needs to
// fire once per session for the static profile fields.
staleTime: Infinity,
enabled: !!getBearerToken(),
queryFn: userQueryFn,
};

View file

@ -17,16 +17,12 @@ import {
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Separator } from "@/components/ui/separator";
import { getToolIcon } from "@/contracts/enums/toolIcons";
import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons";
import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
import { AppError } from "@/lib/error";
import { formatRelativeDate } from "@/lib/format-date";
import { cn } from "@/lib/utils";
/**
 * Turn a snake_case tool identifier into a display label: underscores become
 * spaces and the first character of every word is upper-cased.
 *
 * @param name - Raw tool name, e.g. `"web_search"`.
 * @returns The label, e.g. `"Web Search"`.
 */
function formatToolName(name: string): string {
	const spaced = name.replace(/_/g, " ");
	const upperCaseWordStart = (firstChar: string): string => firstChar.toUpperCase();
	return spaced.replace(/\b\w/g, upperCaseWordStart);
}
interface ActionLogItemProps {
action: AgentAction;
threadId: number;
@ -43,7 +39,7 @@ export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogIt
const hasError = action.error !== null && action.error !== undefined;
const Icon = getToolIcon(action.tool_name);
const displayName = formatToolName(action.tool_name);
const displayName = getToolDisplayName(action.tool_name);
const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null;
const truncatedArgs =

View file

@ -1,9 +1,9 @@
"use client";
import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useQueryClient } from "@tanstack/react-query";
import { useAtom, useAtomValue } from "jotai";
import { Activity, RefreshCcw } from "lucide-react";
import { useCallback, useMemo } from "react";
import { useCallback } from "react";
import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom";
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
import { Badge } from "@/components/ui/badge";
@ -17,15 +17,9 @@ import {
SheetTitle,
} from "@/components/ui/sheet";
import { Skeleton } from "@/components/ui/skeleton";
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
import { agentActionsQueryKey, useAgentActionsQuery } from "@/hooks/use-agent-actions-query";
import { ActionLogItem } from "./action-log-item";
/** Page size passed to `agentActionsApiService.listForThread` when loading the action log. */
const ACTION_LOG_PAGE_SIZE = 50;
/**
 * Build the query key under which a thread's agent actions are cached.
 *
 * @param threadId - Thread whose action log is being queried.
 * @returns The readonly tuple `["agent-actions", threadId]`.
 */
function actionLogQueryKey(threadId: number) {
	const key = ["agent-actions", threadId] as const;
	return key;
}
function EmptyState() {
return (
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center">
@ -85,25 +79,17 @@ export function ActionLogSheet() {
const threadId = state.threadId;
const { data, isLoading, isFetching, isError, error, refetch } = useQuery({
queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"],
queryFn: () =>
agentActionsApiService.listForThread(threadId as number, {
page: 0,
pageSize: ACTION_LOG_PAGE_SIZE,
}),
enabled: state.open && threadId !== null && actionLogEnabled,
staleTime: 15 * 1000,
});
const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery(
threadId,
{ enabled: state.open && actionLogEnabled }
);
const handleRevertSuccess = useCallback(() => {
if (threadId !== null) {
queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) });
queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) });
}
}, [queryClient, threadId]);
const items = useMemo(() => data?.items ?? [], [data]);
return (
<Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}>
<SheetContent

View file

@ -33,6 +33,8 @@ import {
useAllCitationMetadata,
} from "@/components/assistant-ui/citation-metadata-context";
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part";
import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button";
import { useTokenUsage } from "@/components/assistant-ui/token-usage-context";
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
@ -491,6 +493,7 @@ const AssistantMessageInner: FC = () => {
<MessagePrimitive.Parts
components={{
Text: MarkdownText,
Reasoning: ReasoningMessagePart,
tools: {
by_name: {
generate_report: GenerateReportToolUI,
@ -545,8 +548,10 @@ const AssistantMessageInner: FC = () => {
</div>
)}
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 flex items-center gap-2">
<AssistantActionBar />
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 h-6">
<div className="h-full opacity-100 transition-opacity">
<AssistantActionBar />
</div>
</div>
</CitationMetadataProvider>
);
@ -639,35 +644,41 @@ export const AssistantMessage: FC = () => {
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
data-role="assistant"
>
{/* Comment trigger — right-aligned, just below user query on all screen sizes */}
{showCommentTrigger && (
<div className="mr-2 mb-1 flex justify-end">
<button
ref={isDesktop ? commentTriggerRef : undefined}
type="button"
onClick={
isDesktop ? () => setIsInlineOpen((prev) => !prev) : () => setIsSheetOpen(true)
}
className={cn(
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
isDesktop && isInlineOpen
? "bg-primary/10 text-primary"
: hasComments
? "text-primary hover:bg-primary/10"
: "text-muted-foreground hover:text-foreground hover:bg-muted"
)}
>
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
{hasComments ? (
<span>
{commentCount} {commentCount === 1 ? "comment" : "comments"}
</span>
) : (
<span>Add comment</span>
)}
</button>
</div>
)}
{/* Fixed trigger slot prevents any vertical reflow when visibility changes */}
<div className="mr-2 mb-1 flex h-7 justify-end">
<button
ref={isDesktop ? commentTriggerRef : undefined}
type="button"
onClick={
showCommentTrigger
? isDesktop
? () => setIsInlineOpen((prev) => !prev)
: () => setIsSheetOpen(true)
: undefined
}
aria-hidden={!showCommentTrigger}
tabIndex={showCommentTrigger ? 0 : -1}
className={cn(
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
"opacity-0 pointer-events-none",
showCommentTrigger && "opacity-100 pointer-events-auto",
isDesktop && isInlineOpen
? "bg-primary/10 text-primary"
: hasComments
? "text-primary hover:bg-primary/10"
: "text-muted-foreground hover:text-foreground hover:bg-muted"
)}
>
<MessageCircleReply className={cn("size-3.5", hasComments && "fill-current")} />
{hasComments ? (
<span>
{commentCount} {commentCount === 1 ? "comment" : "comments"}
</span>
) : (
<span>Add comment</span>
)}
</button>
</div>
{/* Desktop floating comment panel — overlays on top of chat content */}
{showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && (
@ -699,6 +710,13 @@ const AssistantActionBar: FC = () => {
const isLast = useAuiState((s) => s.message.isLast);
const aui = useAui();
const api = useElectronAPI();
// Surface the persisted ``chat_turn_id`` so the per-turn revert
// affordance can scope to just this message's actions. Streamed
// turns get their id once the assistant message is hydrated/finalised.
const chatTurnId = useAuiState(({ message }) => {
const meta = message?.metadata as { custom?: { chatTurnId?: string | null } } | undefined;
return meta?.custom?.chatTurnId ?? null;
});
const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW;
@ -743,6 +761,9 @@ const AssistantActionBar: FC = () => {
</TooltipIconButton>
)}
<MessageInfoDropdown />
<div className="ml-auto">
<RevertTurnButton chatTurnId={chatTurnId} />
</div>
</ActionBarPrimitive.Root>
);
};

View file

@ -0,0 +1,52 @@
"use client";
import { ThreadPrimitive } from "@assistant-ui/react";
import { ArrowDownIcon } from "lucide-react";
import type { FC, ReactNode } from "react";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
/**
 * Round "Scroll to bottom" button wired to the thread's scroll-to-bottom
 * primitive. It is absolutely positioned above its container and styled to
 * be invisible while disabled.
 */
const ChatScrollToBottom: FC = () => {
	const buttonClassName =
		"aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent";

	return (
		<ThreadPrimitive.ScrollToBottom asChild>
			<TooltipIconButton tooltip="Scroll to bottom" variant="outline" className={buttonClassName}>
				<ArrowDownIcon />
			</TooltipIconButton>
		</ThreadPrimitive.ScrollToBottom>
	);
};
/** Props accepted by `ChatViewport`. */
export interface ChatViewportProps {
/** Content rendered inside the scrollable viewport, below the top fade strip. */
children: ReactNode;
/**
 * Optional content pinned in a sticky footer at the bottom of the viewport,
 * rendered together with the scroll-to-bottom button. Nothing is rendered
 * for the footer when this is omitted or falsy.
 */
footer?: ReactNode;
}
/**
 * Scrollable chat thread viewport.
 *
 * Renders, in order: a sticky decorative fade strip at the top, the thread
 * `children`, and — only when `footer` is provided — a sticky bottom footer
 * that hosts the scroll-to-bottom button followed by the footer content.
 */
export const ChatViewport: FC<ChatViewportProps> = (props) => {
	const { children, footer } = props;

	// Built up-front so the viewport JSX below stays flat; `null` when no footer is given.
	const footerSection = footer ? (
		<ThreadPrimitive.ViewportFooter
			className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 mt-auto flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6"
			style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }}
		>
			<div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible">
				<ChatScrollToBottom />
				{footer}
			</div>
		</ThreadPrimitive.ViewportFooter>
	) : null;

	return (
		<ThreadPrimitive.Viewport
			turnAnchor="top"
			autoScroll
			scrollToBottomOnRunStart
			scrollToBottomOnInitialize
			scrollToBottomOnThreadSwitch
			className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
			style={{ scrollbarGutter: "stable" }}
		>
			{/* Decorative gradient that fades content scrolling under the top edge. */}
			<div
				aria-hidden
				className="aui-chat-viewport-top-fade pointer-events-none sticky top-0 z-10 -mx-4 h-2 shrink-0 bg-gradient-to-b from-main-panel from-20% to-transparent"
			/>
			{children}
			{footerSection}
		</ThreadPrimitive.Viewport>
	);
};

View file

@ -0,0 +1,106 @@
"use client";
/**
* Confirmation dialog shown when the user edits a message that has
* reversible downstream actions. Three buttons:
*
* "Revert N action(s) & resubmit" -> POST regenerate with revert_actions=true
* "Continue without reverting"    -> POST regenerate with revert_actions=false
* "Cancel"                        -> abort the edit entirely
*
* The dialog is auto-skipped when zero reversible downstream actions
* exist (the caller checks first via ``downstreamReversibleCount``).
*/
import { useEffect, useRef, useState } from "react";
import {
AlertDialog,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Button } from "@/components/ui/button";
/** The option picked in the dialog; one value per button ("revert", "continue", "cancel"). */
export type EditMessageDialogChoice = "revert" | "continue" | "cancel";
/** Props accepted by `EditMessageDialog`. */
export interface EditMessageDialogProps {
/** Controlled open state, forwarded to the underlying `AlertDialog`. */
open: boolean;
/** Open-state change handler, forwarded to the underlying `AlertDialog`. */
onOpenChange: (open: boolean) => void;
/** Number of downstream actions that can be rolled back; shown in the description and the revert button label. */
downstreamReversibleCount: number;
/** Number of downstream messages the edit drops; shown in the description. */
downstreamTotalCount: number;
/**
 * Called with the chosen option. The result is awaited, and every button
 * stays disabled until it settles.
 */
onChoose: (choice: EditMessageDialogChoice) => void | Promise<void>;
}
/**
 * Confirmation dialog for editing a message that has downstream actions.
 *
 * Offers "revert & resubmit", "continue without reverting" and "cancel".
 * Whichever option is clicked is passed to `onChoose`; while that call is
 * pending all three buttons are disabled and the clicked primary button
 * shows a progress label.
 */
export function EditMessageDialog({
	open,
	onOpenChange,
	downstreamReversibleCount,
	downstreamTotalCount,
	onChoose,
}: EditMessageDialogProps) {
	// Which choice is currently in flight, or `null` when idle.
	const [pendingChoice, setPendingChoice] = useState<EditMessageDialogChoice | null>(null);

	// The parent may close (and thereby unmount) this dialog before the
	// promise returned by `onChoose` settles, e.g. while a long-running
	// regenerate stream is still in progress. Tracking mount state lets the
	// `finally` branch below skip the state reset once the component is gone.
	const isMountedRef = useRef(true);
	useEffect(() => {
		isMountedRef.current = true;
		return () => {
			isMountedRef.current = false;
		};
	}, []);

	const runChoice = async (choice: EditMessageDialogChoice) => {
		setPendingChoice(choice);
		try {
			await onChoose(choice);
		} finally {
			if (isMountedRef.current) {
				setPendingChoice(null);
			}
		}
	};

	const isBusy = pendingChoice !== null;
	const messageSuffix = downstreamTotalCount === 1 ? "" : "s";
	const actionSuffix = downstreamReversibleCount === 1 ? "" : "s";

	const revertLabel =
		pendingChoice === "revert"
			? "Reverting & resubmitting…"
			: `Revert ${downstreamReversibleCount} action${actionSuffix} & resubmit`;
	const continueLabel =
		pendingChoice === "continue" ? "Resubmitting…" : "Continue without reverting";

	return (
		<AlertDialog open={open} onOpenChange={onOpenChange}>
			<AlertDialogContent>
				<AlertDialogHeader>
					<AlertDialogTitle>Edit this message?</AlertDialogTitle>
					<AlertDialogDescription>
						This edit drops {downstreamTotalCount} downstream message{messageSuffix} from the
						thread. {downstreamReversibleCount} action{actionSuffix} (e.g. file writes, connector
						changes) can be rolled back. Pick how to handle them before regenerating.
					</AlertDialogDescription>
				</AlertDialogHeader>
				<div className="grid gap-2">
					<Button variant="default" disabled={isBusy} onClick={() => runChoice("revert")}>
						{revertLabel}
					</Button>
					<Button variant="outline" disabled={isBusy} onClick={() => runChoice("continue")}>
						{continueLabel}
					</Button>
				</div>
				<AlertDialogFooter className="sm:justify-start">
					<AlertDialogCancel disabled={isBusy} onClick={() => runChoice("cancel")}>
						Cancel
					</AlertDialogCancel>
				</AlertDialogFooter>
			</AlertDialogContent>
		</AlertDialog>
	);
}

View file

@ -3,11 +3,11 @@
import { useQuery } from "@tanstack/react-query";
import { useSetAtom } from "jotai";
import { ExternalLink, FileText } from "lucide-react";
import dynamic from "next/dynamic";
import type { FC } from "react";
import { useCallback, useEffect, useRef, useState } from "react";
import { openCitationPanelAtom } from "@/atoms/citation/citation-panel.atom";
import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context";
import { MarkdownViewer } from "@/components/markdown-viewer";
import { Citation } from "@/components/tool-ui/citation";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { Spinner } from "@/components/ui/spinner";
@ -15,6 +15,16 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
// Lazily load MarkdownViewer here to break the static import cycle:
// `markdown-viewer.tsx` → `citation-renderer.tsx` → `inline-citation.tsx`
// would otherwise pull `markdown-viewer.tsx` back in at module-init time.
// Only `SurfsenseDocCitation` (popover body) ever renders this viewer, so
// the lazy boundary is invisible to most call paths.
//
// The loader resolves the named `MarkdownViewer` export of the module.
// Server-side rendering is turned off for it (`ssr: false`), and an
// extra-small spinner is shown while the module is loading.
const MarkdownViewer = dynamic(
() => import("@/components/markdown-viewer").then((m) => m.MarkdownViewer),
{ ssr: false, loading: () => <Spinner size="xs" /> }
);
interface InlineCitationProps {
chunkId: number;
isDocsChunk?: boolean;
@ -172,7 +182,7 @@ const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => {
</p>
)}
{!isLoading && !error && citedChunk?.content && (
<MarkdownViewer content={citedChunk.content} maxLength={1500} />
<MarkdownViewer content={citedChunk.content} maxLength={1500} enableCitations />
)}
{!isLoading && !error && !citedChunk?.content && (
<p className="py-4 text-xs text-muted-foreground">No content available.</p>

File diff suppressed because it is too large Load diff

View file

@ -12,14 +12,15 @@ import { ExternalLinkIcon } from "lucide-react";
import dynamic from "next/dynamic";
import { useParams } from "next/navigation";
import { useTheme } from "next-themes";
import { memo, type ReactNode } from "react";
import { createContext, memo, type ReactNode, useCallback, useContext, useRef } from "react";
import rehypeKatex from "rehype-katex";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image";
import "katex/dist/katex.min.css";
import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation";
import { toast } from "sonner";
import { processChildrenWithCitations } from "@/components/citations/citation-renderer";
import { Skeleton } from "@/components/ui/skeleton";
import {
Table,
@ -30,6 +31,8 @@ import {
TableRow,
} from "@/components/ui/table";
import { useElectronAPI } from "@/hooks/use-platform";
import { documentsApiService } from "@/lib/apis/documents-api.service";
import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser";
import { cn } from "@/lib/utils";
function MarkdownCodeBlockSkeleton() {
@ -59,31 +62,30 @@ const LazyMarkdownCodeBlock = dynamic(
}
);
// Storage for URL citations replaced during preprocess to avoid GFM autolink interference.
// Populated in preprocessMarkdown, consumed in parseTextWithCitations.
let _pendingUrlCitations = new Map<string, string>();
let _urlCiteIdx = 0;
// Per-render URL placeholder map propagated to component overrides via
// React Context. Replaces the previous module-level `_pendingUrlCitations`
// state, which was unsafe under concurrent renders / SSR.
// Ref-shaped holder: the context carries this object and readers
// dereference `.current` to get the map itself.
type CitationUrlMapRef = { current: CitationUrlMap };
// Shared empty map used as the initial / fallback value.
const EMPTY_URL_MAP: CitationUrlMap = new Map();
// Default value (used when no Provider is mounted above) points at the empty map.
const CitationUrlMapContext = createContext<CitationUrlMapRef>({ current: EMPTY_URL_MAP });
/**
 * Reads the citation URL map currently held by the nearest
 * `CitationUrlMapContext` value.
 *
 * @returns The map stored in the context holder's `current` field.
 */
function useCitationUrlMap(): CitationUrlMap {
	const holder = useContext(CitationUrlMapContext);
	return holder.current;
}
/**
* Preprocess raw markdown before it reaches the remark/rehype pipeline.
* - Replaces URL-based citations with safe placeholders (prevents GFM autolinks)
* - Normalises LaTeX delimiters to dollar-sign syntax for remark-math
*/
function preprocessMarkdown(content: string): string {
function preprocessMarkdown(content: string, urlMapRef: CitationUrlMapRef): string {
// Replace URL-based citations with safe placeholders BEFORE markdown parsing.
// GFM autolinks would otherwise convert the https://... inside [citation:URL]
// into an <a> element, splitting the text and preventing our citation regex
// from matching the full pattern.
_pendingUrlCitations = new Map();
_urlCiteIdx = 0;
content = content.replace(
/[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g,
(_, url) => {
const key = `urlcite${_urlCiteIdx++}`;
_pendingUrlCitations.set(key, url.trim());
return `[citation:${key}]`;
}
);
const { content: rewritten, urlMap } = preprocessCitationMarkdown(content);
urlMapRef.current = urlMap;
content = rewritten;
// All math forms are normalised to $$...$$ so we can disable single-dollar
// inline math in remark-math (otherwise currency like "$3,120.00 and $0.00"
@ -116,113 +118,25 @@ function preprocessMarkdown(content: string): string {
return content;
}
// Matches [citation:...] with numeric IDs (incl. negative, doc- prefix, comma-separated),
// URL-based IDs from live web search, or urlciteN placeholders from preprocess.
// Also matches Chinese brackets 【】 and handles zero-width spaces that LLM sometimes inserts.
const CITATION_REGEX =
/[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g;
/**
 * Splits `text` on `[citation:...]` markers and swaps each marker for a
 * citation component, keeping the surrounding plain text as string entries.
 *
 * Marker payloads handled:
 * - literal URLs (`http://` / `https://`) -> one `UrlCitation`
 * - `urlciteN` placeholders -> one `UrlCitation` for the URL looked up in
 *   `_pendingUrlCitations`; nothing is rendered when the lookup misses
 * - numeric / `doc-` prefixed ids, optionally comma-separated -> one
 *   `InlineCitation` per id
 *
 * @param text Raw text that may contain citation markers.
 * @returns The ordered strings and citation elements, or `[text]` when the
 *   result would otherwise be empty.
 */
function parseTextWithCitations(text: string): ReactNode[] {
	const nodes: ReactNode[] = [];
	let cursor = 0;
	// Running counter that keeps generated React keys unique within this text.
	let seq = 0;

	// The regex is global and module-level, so reset its scan position first.
	CITATION_REGEX.lastIndex = 0;
	for (
		let hit = CITATION_REGEX.exec(text);
		hit !== null;
		hit = CITATION_REGEX.exec(text)
	) {
		// Plain text between the previous marker and this one.
		if (hit.index > cursor) {
			nodes.push(text.substring(cursor, hit.index));
		}

		const token = hit[1];
		const isLiteralUrl = token.startsWith("http://") || token.startsWith("https://");

		if (isLiteralUrl) {
			nodes.push(<UrlCitation key={`citation-url-${seq}`} url={token.trim()} />);
			seq++;
		} else if (token.startsWith("urlcite")) {
			const resolvedUrl = _pendingUrlCitations.get(token);
			if (resolvedUrl) {
				nodes.push(<UrlCitation key={`citation-url-${seq}`} url={resolvedUrl} />);
			}
			// The counter advances even when the placeholder could not be resolved.
			seq++;
		} else {
			for (const rawId of token.split(",").map((piece) => piece.trim())) {
				const isDocsChunk = rawId.startsWith("doc-");
				const chunkId = Number.parseInt(isDocsChunk ? rawId.slice(4) : rawId, 10);
				nodes.push(
					<InlineCitation
						key={`citation-${isDocsChunk ? "doc-" : ""}${chunkId}-${seq}`}
						chunkId={chunkId}
						isDocsChunk={isDocsChunk}
					/>
				);
				seq++;
			}
		}

		cursor = hit.index + hit[0].length;
	}

	// Trailing text after the last marker.
	if (cursor < text.length) {
		nodes.push(text.substring(cursor));
	}

	return nodes.length > 0 ? nodes : [text];
}
const MarkdownTextImpl = () => {
const urlMapRef = useRef<CitationUrlMap>(EMPTY_URL_MAP);
const preprocess = useCallback((content: string) => preprocessMarkdown(content, urlMapRef), []);
return (
<MarkdownTextPrimitive
smooth={false}
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
rehypePlugins={[rehypeKatex]}
className="aui-md"
components={defaultComponents}
preprocess={preprocessMarkdown}
/>
<CitationUrlMapContext.Provider value={urlMapRef}>
<MarkdownTextPrimitive
smooth={false}
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
rehypePlugins={[rehypeKatex]}
className="aui-md"
components={defaultComponents}
preprocess={preprocess}
/>
</CitationUrlMapContext.Provider>
);
};
export const MarkdownText = memo(MarkdownTextImpl);
/**
 * Replaces citation markers found in string children with citation
 * components.
 *
 * - A string child is returned unchanged when it contains no citation;
 *   otherwise the parsed node list is returned.
 * - For an array, each string element is handled the same way, with parsed
 *   results wrapped in a keyed `<span>`; non-string elements pass through.
 * - Any other kind of child is returned as-is.
 *
 * @param children The React children of a markdown element.
 * @returns The children with citation markers expanded.
 */
function processChildrenWithCitations(children: ReactNode): ReactNode {
	// Returns the parsed nodes, or `null` when the text held no citation
	// (i.e. parsing produced just the single original string).
	const expand = (text: string): ReactNode[] | null => {
		const parsed = parseTextWithCitations(text);
		return parsed.length === 1 && typeof parsed[0] === "string" ? null : parsed;
	};

	if (typeof children === "string") {
		return expand(children) ?? children;
	}

	if (Array.isArray(children)) {
		return children.map((child, index) => {
			if (typeof child !== "string") return child;
			const parsed = expand(child);
			if (parsed === null) return child;
			// Key by position rather than by text content: two identical text
			// fragments under the same parent would otherwise share a key.
			return <span key={`citation-text-${index}`}>{parsed}</span>;
		});
	}

	return children;
}
function extractDomain(url: string): string {
try {
const parsed = new URL(url);
@ -282,6 +196,85 @@ function isVirtualFilePathToken(value: string): boolean {
return segments.length >= 2;
}
/**
 * Detects whether `node` is a bare `/documents/...` file path and nothing else.
 *
 * @param node A React child; anything other than a string yields `null`.
 * @returns The trimmed path when the text starts with `/documents/`, contains
 *   no whitespace, and its last segment looks like a file name (contains a
 *   dot); otherwise `null`.
 */
function isStandaloneDocumentsPathText(node: ReactNode): string | null {
	if (typeof node !== "string") return null;

	const value = node.trim();
	if (!value.startsWith("/documents/")) return null;

	// Reject any interior whitespace, not just a literal space: text such as
	// "/documents/a.md\nmore text" (newline or tab) is prose, not a single path.
	if (/\s/.test(value)) return null;

	// Ignore trailing slashes when picking the last path segment.
	const segments = value.replace(/\/+$/, "").split("/").filter(Boolean);
	const leaf = segments[segments.length - 1] ?? "";
	if (!leaf.includes(".")) return null;

	return value;
}
/**
 * Inline button that opens a file path in the editor panel.
 *
 * - When the Electron API is available, the path is opened as a local file.
 *   If the API exposes `getAgentFilesystemMounts`, the path is first
 *   normalised against the returned mounts; a failed lookup falls back to
 *   the raw path.
 * - Otherwise, `/documents/...` paths are resolved through the documents API
 *   and opened as a knowledge-base document. A failure shows an error toast.
 *   Other paths, or a missing search-space id, do nothing.
 *
 * @param path Path shown as the button label and opened on click.
 * @param className Extra classes merged onto the button.
 */
function FilePathLink({ path, className }: { path: string; className?: string }) {
	const openEditorPanel = useSetAtom(openEditorPanelAtom);
	const params = useParams();
	const electronAPI = useElectronAPI();

	// The route param may be a string or a string array; take the first entry
	// and treat anything non-numeric as "no search space".
	const rawSpaceParam = params?.search_space_id;
	const numericSpaceId = Number(Array.isArray(rawSpaceParam) ? rawSpaceParam[0] : rawSpaceParam);
	const resolvedSearchSpaceId = Number.isFinite(numericSpaceId) ? numericSpaceId : undefined;

	// Desktop: open the (possibly mount-normalised) path as a local file.
	const openLocalFile = async (api: NonNullable<typeof electronAPI>) => {
		let resolvedLocalPath = path;
		if (api.getAgentFilesystemMounts) {
			try {
				const mounts = (await api.getAgentFilesystemMounts(
					resolvedSearchSpaceId
				)) as AgentFilesystemMount[];
				resolvedLocalPath = normalizeLocalVirtualPathForEditor(path, mounts);
			} catch {
				// Fall back to the raw path if mount lookup fails.
			}
		}
		openEditorPanel({
			kind: "local_file",
			localFilePath: resolvedLocalPath,
			title: resolvedLocalPath.split("/").pop() || resolvedLocalPath,
			searchSpaceId: resolvedSearchSpaceId,
		});
	};

	// Web: look the virtual path up in the knowledge base and open that document.
	const openKnowledgeBaseDocument = async () => {
		if (!resolvedSearchSpaceId || !path.startsWith("/documents/")) return;
		try {
			const doc = await documentsApiService.getDocumentByVirtualPath({
				search_space_id: resolvedSearchSpaceId,
				virtual_path: path,
			});
			openEditorPanel({
				kind: "document",
				documentId: doc.id,
				searchSpaceId: resolvedSearchSpaceId,
				title: doc.title,
			});
		} catch {
			toast.error("Document not found in knowledge base.");
		}
	};

	return (
		<button
			type="button"
			className={cn(
				"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80",
				className
			)}
			onClick={(event) => {
				event.preventDefault();
				event.stopPropagation();
				void (electronAPI ? openLocalFile(electronAPI) : openKnowledgeBaseDocument());
			}}
			title="Open in editor panel"
		>
			{path}
		</button>
	);
}
function MarkdownImage({ src, alt }: { src?: string; alt?: string }) {
if (!src) return null;
@ -322,92 +315,127 @@ function MarkdownImage({ src, alt }: { src?: string; alt?: string }) {
}
const defaultComponents = memoizeMarkdownComponents({
h1: ({ className, children, ...props }) => (
<h1
className={cn(
"aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0",
className
)}
{...props}
>
{processChildrenWithCitations(children)}
</h1>
),
h2: ({ className, children, ...props }) => (
<h2
className={cn(
"aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0",
className
)}
{...props}
>
{processChildrenWithCitations(children)}
</h2>
),
h3: ({ className, children, ...props }) => (
<h3
className={cn(
"aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0",
className
)}
{...props}
>
{processChildrenWithCitations(children)}
</h3>
),
h4: ({ className, children, ...props }) => (
<h4
className={cn(
"aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0",
className
)}
{...props}
>
{processChildrenWithCitations(children)}
</h4>
),
h5: ({ className, children, ...props }) => (
<h5
className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)}
{...props}
>
{processChildrenWithCitations(children)}
</h5>
),
h6: ({ className, children, ...props }) => (
<h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}>
{processChildrenWithCitations(children)}
</h6>
),
p: ({ className, children, ...props }) => (
<p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}>
{processChildrenWithCitations(children)}
</p>
),
a: ({ className, children, ...props }) => (
<a
className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)}
{...props}
>
{processChildrenWithCitations(children)}
</a>
),
blockquote: ({ className, children, ...props }) => (
<blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}>
{processChildrenWithCitations(children)}
</blockquote>
),
h1: function H1({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<h1
className={cn(
"aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0",
className
)}
{...props}
>
{processChildrenWithCitations(children, urlMap)}
</h1>
);
},
h2: function H2({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<h2
className={cn(
"aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0",
className
)}
{...props}
>
{processChildrenWithCitations(children, urlMap)}
</h2>
);
},
h3: function H3({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<h3
className={cn(
"aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0",
className
)}
{...props}
>
{processChildrenWithCitations(children, urlMap)}
</h3>
);
},
h4: function H4({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<h4
className={cn(
"aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0",
className
)}
{...props}
>
{processChildrenWithCitations(children, urlMap)}
</h4>
);
},
h5: function H5({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<h5
className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)}
{...props}
>
{processChildrenWithCitations(children, urlMap)}
</h5>
);
},
h6: function H6({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}>
{processChildrenWithCitations(children, urlMap)}
</h6>
);
},
p: function P({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
const standalonePath = isStandaloneDocumentsPathText(children);
return (
<p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}>
{standalonePath ? (
<FilePathLink path={standalonePath} />
) : (
processChildrenWithCitations(children, urlMap)
)}
</p>
);
},
a: function A({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<a
className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)}
{...props}
>
{processChildrenWithCitations(children, urlMap)}
</a>
);
},
blockquote: function Blockquote({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}>
{processChildrenWithCitations(children, urlMap)}
</blockquote>
);
},
ul: ({ className, ...props }) => (
<ul className={cn("aui-md-ul my-5 ml-6 list-disc [&>li]:mt-2", className)} {...props} />
),
ol: ({ className, ...props }) => (
<ol className={cn("aui-md-ol my-5 ml-6 list-decimal [&>li]:mt-2", className)} {...props} />
),
li: ({ className, children, ...props }) => (
<li className={cn("aui-md-li", className)} {...props}>
{processChildrenWithCitations(children)}
</li>
),
li: function Li({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<li className={cn("aui-md-li", className)} {...props}>
{processChildrenWithCitations(children, urlMap)}
</li>
);
},
hr: ({ className, ...props }) => (
<hr className={cn("aui-md-hr my-5 border-b", className)} {...props} />
),
@ -422,28 +450,34 @@ const defaultComponents = memoizeMarkdownComponents({
tbody: ({ className, ...props }) => (
<TableBody className={cn("aui-md-tbody", className)} {...props} />
),
th: ({ className, children, ...props }) => (
<TableHead
className={cn(
"aui-md-th bg-muted/50 whitespace-normal [[align=center]]:text-center [[align=right]]:text-right",
className
)}
{...props}
>
{processChildrenWithCitations(children)}
</TableHead>
),
td: ({ className, children, ...props }) => (
<TableCell
className={cn(
"aui-md-td whitespace-normal [[align=center]]:text-center [[align=right]]:text-right",
className
)}
{...props}
>
{processChildrenWithCitations(children)}
</TableCell>
),
th: function Th({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<TableHead
className={cn(
"aui-md-th bg-muted/50 whitespace-normal [[align=center]]:text-center [[align=right]]:text-right",
className
)}
{...props}
>
{processChildrenWithCitations(children, urlMap)}
</TableHead>
);
},
td: function Td({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<TableCell
className={cn(
"aui-md-td whitespace-normal [[align=center]]:text-center [[align=right]]:text-right",
className
)}
{...props}
>
{processChildrenWithCitations(children, urlMap)}
</TableCell>
);
},
tr: ({ className, ...props }) => <TableRow className={cn("aui-md-tr", className)} {...props} />,
sup: ({ className, ...props }) => (
<sup className={cn("aui-md-sup [&>a]:text-xs [&>a]:no-underline", className)} {...props} />
@ -452,8 +486,6 @@ const defaultComponents = memoizeMarkdownComponents({
code: function Code({ className, children, ...props }) {
const isCodeBlock = useIsMarkdownCodeBlock();
const { resolvedTheme } = useTheme();
const openEditorPanel = useSetAtom(openEditorPanelAtom);
const params = useParams();
const electronAPI = useElectronAPI();
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
const codeString = String(children).replace(/\n$/, "");
@ -470,53 +502,17 @@ const defaultComponents = memoizeMarkdownComponents({
const isLikelyFolder =
inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes(".");
const isLocalPath =
!!electronAPI &&
isVirtualFilePathToken(inlineValue) &&
!inlineValue.startsWith("//") &&
!isLikelyFolder;
const displayLocalPath = inlineValue.replace(/^\/+/, "");
const searchSpaceIdParam = params?.search_space_id;
const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam)
? Number(searchSpaceIdParam[0])
: Number(searchSpaceIdParam);
(isVirtualFilePathToken(inlineValue) &&
!inlineValue.startsWith("//") &&
!isLikelyFolder &&
!!electronAPI) ||
(isVirtualFilePathToken(inlineValue) &&
!inlineValue.startsWith("//") &&
!isLikelyFolder &&
!electronAPI &&
inlineValue.startsWith("/documents/"));
if (isLocalPath) {
return (
<button
type="button"
className={cn(
"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80"
)}
onClick={(event) => {
event.preventDefault();
event.stopPropagation();
void (async () => {
let resolvedLocalPath = inlineValue;
const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId)
? parsedSearchSpaceId
: undefined;
if (electronAPI?.getAgentFilesystemMounts) {
try {
const mounts = (await electronAPI.getAgentFilesystemMounts(
resolvedSearchSpaceId
)) as AgentFilesystemMount[];
resolvedLocalPath = normalizeLocalVirtualPathForEditor(inlineValue, mounts);
} catch {
// Fall back to the raw inline path if mount lookup fails.
}
}
openEditorPanel({
kind: "local_file",
localFilePath: resolvedLocalPath,
title: resolvedLocalPath.split("/").pop() || resolvedLocalPath,
searchSpaceId: resolvedSearchSpaceId,
});
})();
}}
title="Open in editor panel"
>
{displayLocalPath}
</button>
);
return <FilePathLink path={inlineValue} className="text-[0.9em]" />;
}
return (
<code
@ -552,16 +548,22 @@ const defaultComponents = memoizeMarkdownComponents({
/>
);
},
strong: ({ className, children, ...props }) => (
<strong className={cn("aui-md-strong font-semibold", className)} {...props}>
{processChildrenWithCitations(children)}
</strong>
),
em: ({ className, children, ...props }) => (
<em className={cn("aui-md-em", className)} {...props}>
{processChildrenWithCitations(children)}
</em>
),
strong: function Strong({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<strong className={cn("aui-md-strong font-semibold", className)} {...props}>
{processChildrenWithCitations(children, urlMap)}
</strong>
);
},
em: function Em({ className, children, ...props }) {
const urlMap = useCitationUrlMap();
return (
<em className={cn("aui-md-em", className)} {...props}>
{processChildrenWithCitations(children, urlMap)}
</em>
);
},
img: ({ src, alt }) => (
<MarkdownImage src={typeof src === "string" ? src : undefined} alt={alt} />
),

View file

@ -0,0 +1,24 @@
"use client";
import { type ComponentPropsWithoutRef, forwardRef, type WheelEvent } from "react";
/** Props for `NestedScroll`: the standard `<div>` props (without `ref`). */
export type NestedScrollProps = ComponentPropsWithoutRef<"div">;
/**
 * `<div>` wrapper for a scrollable area nested inside another scroller.
 *
 * On wheel events it stops propagation while this element can still scroll
 * in the direction of the gesture, then always forwards the event to the
 * caller's own `onWheel` handler, if one was passed.
 */
export const NestedScroll = forwardRef<HTMLDivElement, NestedScrollProps>(
	({ onWheel, ...rest }, ref) => {
		const handleWheel = (event: WheelEvent<HTMLDivElement>) => {
			const { scrollTop, scrollHeight, clientHeight } = event.currentTarget;
			const { deltaY } = event;

			const hasRoomAbove = scrollTop > 0;
			// 1px tolerance so a position within a pixel of the bottom counts as "at the bottom".
			const hasRoomBelow = scrollTop < scrollHeight - clientHeight - 1;

			let scrollsInside = false;
			if (deltaY < 0) {
				scrollsInside = hasRoomAbove;
			} else if (deltaY > 0) {
				scrollsInside = hasRoomBelow;
			}

			if (scrollsInside) {
				event.stopPropagation();
			}
			onWheel?.(event);
		};

		return <div ref={ref} onWheel={handleWheel} {...rest} />;
	}
);
NestedScroll.displayName = "NestedScroll";

View file

@ -0,0 +1,81 @@
"use client";
import type { ReasoningMessagePartComponent } from "@assistant-ui/react";
import { ChevronRightIcon } from "lucide-react";
import { useEffect, useMemo, useState } from "react";
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
import { cn } from "@/lib/utils";
/**
* Renders the structured `reasoning` part emitted by the backend's
* stream-parity v2 path (A1).
*
* Behaviour mirrors the existing `ThinkingStepsDisplay`:
* - collapsed by default;
* - auto-expanded while the part is still `running`;
* - auto-collapsed once status flips to `complete`.
*
* The component is registered via the `Reasoning` slot on
* `MessagePrimitive.Parts` in `assistant-message.tsx` so it lives at the
* exact ordinal position of the reasoning block in the message content
* array (i.e. above the assistant text that follows it).
*/
export const ReasoningMessagePart: ReasoningMessagePartComponent = ({ text, status }) => {
	const statusType = status?.type;
	const isRunning = statusType === "running";

	// Starts expanded only if the part is already running on first render.
	const [isOpen, setIsOpen] = useState(() => isRunning);

	// Follow the status: expand while running, collapse once complete.
	// Other statuses leave the current open/closed state untouched.
	useEffect(() => {
		if (isRunning) {
			setIsOpen(true);
			return;
		}
		if (statusType === "complete") {
			setIsOpen(false);
		}
	}, [isRunning, statusType]);

	const headerLabel = useMemo(
		() =>
			isRunning ? "Thinking" : statusType === "incomplete" ? "Thinking interrupted" : "Thought",
		[isRunning, statusType]
	);

	// Nothing to show for an empty part unless it is still running.
	const hasText = !!text && text.length > 0;
	if (!hasText && !isRunning) {
		return null;
	}

	const toggleOpen = () => setIsOpen((prev) => !prev);

	return (
		<div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2">
			<div className="rounded-lg">
				<button
					type="button"
					onClick={toggleOpen}
					className={cn(
						"flex w-full items-center gap-1.5 text-left text-sm transition-colors",
						"text-muted-foreground hover:text-foreground"
					)}
				>
					{isRunning ? (
						<TextShimmerLoader text={headerLabel} size="sm" />
					) : (
						<span>{headerLabel}</span>
					)}
					<ChevronRightIcon
						className={cn("size-4 transition-transform duration-200", isOpen && "rotate-90")}
					/>
				</button>
				{/* Animate open/close by transitioning the grid row between 0fr and 1fr. */}
				<div
					className={cn(
						"grid transition-[grid-template-rows] duration-300 ease-out",
						isOpen ? "grid-rows-[1fr]" : "grid-rows-[0fr]"
					)}
				>
					<div className="overflow-hidden">
						<div className="mt-2 border-l border-muted-foreground/30 pl-3 text-sm leading-relaxed text-muted-foreground whitespace-pre-wrap wrap-break-word">
							{text}
						</div>
					</div>
				</div>
			</div>
		</div>
	);
};

View file

@ -0,0 +1,213 @@
"use client";
/**
* "Revert turn" button rendered at the bottom of every completed
* assistant turn that has at least one reversible action.
*
* The button reads from the unified ``useAgentActionsQuery`` cache
* (the SAME react-query cache the agent-actions sheet and the inline
* Revert button consume) filtered by ``chat_turn_id``. It shows a
* confirmation dialog summarising "N reversible / M total" and, on
* confirm, calls ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
*
* The route returns a per-action result list and never collapses the
* batch into a 4xx so we render any failed/not_reversible rows inline
* with their messages.
*/
import { useQueryClient } from "@tanstack/react-query";
import { useAtomValue } from "jotai";
import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react";
import { useMemo, useState } from "react";
import { toast } from "sonner";
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
AlertDialogTrigger,
} from "@/components/ui/alert-dialog";
import { Button } from "@/components/ui/button";
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
import {
applyRevertTurnResultsToCache,
useAgentActionsQuery,
} from "@/hooks/use-agent-actions-query";
import {
agentActionsApiService,
type RevertTurnActionResult,
} from "@/lib/apis/agent-actions-api.service";
import { AppError } from "@/lib/error";
import { cn } from "@/lib/utils";
/** Props accepted by `RevertTurnButton`. */
interface RevertTurnButtonProps {
/**
 * Id of the chat turn whose actions the button looks up and reverts.
 * When it is null, undefined or empty, the button renders nothing.
 */
chatTurnId: string | null | undefined;
}
/**
 * "Revert turn" control for a single assistant turn.
 *
 * Looks up the turn's agent actions via `useAgentActionsQuery`, asks for
 * confirmation, calls `agentActionsApiService.revertTurn`, and mirrors the
 * rows the server reports as `reverted` / `already_reverted` into the query
 * cache through `applyRevertTurnResultsToCache`. When the response status is
 * anything other than `"ok"`, a second dialog lists the per-action outcomes.
 *
 * Renders nothing when there is no `chatTurnId`, no active thread, or nothing
 * left to revert. The one exception is an open results dialog: the cache update
 * above can drop `reversibleCount` to zero in the same tick that the results
 * dialog is opened, and the summary must stay visible in that case.
 *
 * @param chatTurnId - Turn whose actions are reverted as a batch.
 */
export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
	const session = useAtomValue(chatSessionStateAtom);
	const threadId = session?.threadId ?? null;
	const queryClient = useQueryClient();
	const { findByChatTurnId } = useAgentActionsQuery(threadId);
	const [isReverting, setIsReverting] = useState(false);
	const [confirmOpen, setConfirmOpen] = useState(false);
	const [resultsOpen, setResultsOpen] = useState(false);
	const [results, setResults] = useState<RevertTurnActionResult[]>([]);
	const actions = useMemo(() => findByChatTurnId(chatTurnId), [findByChatTurnId, chatTurnId]);
	// Rows that can still be undone: reversible, not yet reverted, not
	// themselves revert rows, and recorded without an error.
	const reversibleCount = useMemo(
		() =>
			actions.filter(
				(a) =>
					a.reversible &&
					(a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) &&
					!a.is_revert_action &&
					(a.error === null || a.error === undefined)
			).length,
		[actions]
	);
	// Denominator shown to the user: every action of the turn except revert rows.
	const totalCount = useMemo(() => actions.filter((a) => !a.is_revert_action).length, [actions]);

	if (!chatTurnId) return null;
	if (!threadId) return null;
	const hasReversible = reversibleCount > 0;
	// Returning null keeps this component's state but drops its children, so an
	// open results dialog would silently never render. Stay mounted for it.
	if (!hasReversible && !resultsOpen) return null;

	const handleRevertTurn = async () => {
		setIsReverting(true);
		try {
			const response = await agentActionsApiService.revertTurn(threadId, chatTurnId);
			setResults(response.results);
			const revertedEntries = response.results
				.filter((r) => r.status === "reverted" || r.status === "already_reverted")
				.map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null }));
			if (revertedEntries.length > 0) {
				applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries);
			}
			if (response.status === "ok") {
				toast.success(
					response.reverted === 1 ? "Reverted 1 action." : `Reverted ${response.reverted} actions.`
				);
			} else {
				// Every "not undone" bucket counts as a failure for the
				// user-facing summary. ``skipped`` rows are batch
				// artefacts (revert rows themselves) and intentionally
				// excluded from the failure tally.
				const failureCount =
					response.failed + response.not_reversible + (response.permission_denied ?? 0);
				toast.warning(
					`Reverted ${response.reverted} of ${response.total}. ${failureCount} could not be undone.`
				);
				setResultsOpen(true);
			}
		} catch (err) {
			// A 503 is swallowed without a toast; every other failure is surfaced.
			if (err instanceof AppError && err.status === 503) {
				return;
			}
			const message =
				err instanceof AppError
					? err.message
					: err instanceof Error
						? err.message
						: "Failed to revert turn.";
			toast.error(message);
		} finally {
			// Runs on success, on failure, and on the silent 503 return.
			setIsReverting(false);
			setConfirmOpen(false);
		}
	};

	return (
		<>
			{hasReversible && (
				<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
					<AlertDialogTrigger asChild>
						<Button
							size="sm"
							variant="ghost"
							className="text-muted-foreground hover:text-foreground gap-1.5"
							onClick={(e) => {
								e.stopPropagation();
								setConfirmOpen(true);
							}}
						>
							<RotateCcw className="size-3.5" />
							<span>Revert turn</span>
							<span className="text-xs tabular-nums opacity-70">
								{reversibleCount}/{totalCount}
							</span>
						</Button>
					</AlertDialogTrigger>
					<AlertDialogContent>
						<AlertDialogHeader>
							<AlertDialogTitle>Revert this turn?</AlertDialogTitle>
							<AlertDialogDescription>
								This will undo {reversibleCount} of {totalCount} action
								{totalCount === 1 ? "" : "s"} from this turn in reverse order. The chat history and
								any read-only actions are preserved. Some rows may not be reversible — partial
								success is normal.
							</AlertDialogDescription>
						</AlertDialogHeader>
						<AlertDialogFooter>
							<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
							<AlertDialogAction
								onClick={(e) => {
									// The handler closes the dialog itself in `finally`,
									// after the request settles.
									e.preventDefault();
									void handleRevertTurn();
								}}
								disabled={isReverting}
							>
								{isReverting ? "Reverting…" : "Revert turn"}
							</AlertDialogAction>
						</AlertDialogFooter>
					</AlertDialogContent>
				</AlertDialog>
			)}
			<AlertDialog open={resultsOpen} onOpenChange={setResultsOpen}>
				<AlertDialogContent>
					<AlertDialogHeader>
						<AlertDialogTitle>Revert results</AlertDialogTitle>
						<AlertDialogDescription>
							Some actions could not be reverted. Review per-row outcomes below.
						</AlertDialogDescription>
					</AlertDialogHeader>
					<ul className="max-h-72 overflow-y-auto space-y-2 text-sm">
						{results.map((r) => (
							<RevertResultRow key={r.action_id} result={r} />
						))}
					</ul>
					<AlertDialogFooter>
						<AlertDialogAction onClick={() => setResultsOpen(false)}>Close</AlertDialogAction>
					</AlertDialogFooter>
				</AlertDialogContent>
			</AlertDialog>
		</>
	);
}
/**
 * One row of the "Revert results" dialog: status icon, friendly tool name,
 * the status with underscores replaced by spaces, and an optional detail line.
 *
 * @param result - Per-action outcome returned by the revert-turn endpoint.
 */
function RevertResultRow({ result }: { result: RevertTurnActionResult }) {
	const isOk = result.status === "reverted" || result.status === "already_reverted";
	const Icon = isOk ? CheckIcon : XCircleIcon;
	// Prefer the error text, falling back to the message. `||` (not `??`) so an
	// empty-string error cannot mask a real message or render a blank paragraph;
	// this also keeps the visibility check and the rendered text in agreement.
	const detail = result.error || result.message;
	return (
		<li className="flex items-start gap-2 rounded-md border bg-muted/30 px-3 py-2">
			<Icon
				className={cn("size-4 mt-0.5 shrink-0", isOk ? "text-emerald-500" : "text-destructive")}
			/>
			<div className="min-w-0 flex-1">
				<p className="font-medium truncate">
					{getToolDisplayName(result.tool_name)}{" "}
					<span className="ml-1 text-xs text-muted-foreground">
						{result.status.replace(/_/g, " ")}
					</span>
				</p>
				{detail && <p className="text-xs text-muted-foreground mt-0.5">{detail}</p>}
			</div>
		</li>
	);
}

View file

@ -0,0 +1,27 @@
"use client";
import { makeAssistantDataUI } from "@assistant-ui/react";
/**
 * Thin horizontal rule drawn between model steps inside one assistant turn.
 *
 * The corresponding data part is pushed by `addStepSeparator` in
 * `streaming-state.ts` when a `start-step` SSE event arrives after the
 * message already holds non-step content.
 *
 * The backend currently emits a single `start-step` / `finish-step` pair per
 * turn, so most messages contain no separator. The renderer is wired up now
 * so the planned per-model-step refactor (A2 follow-up) can enable it without
 * touching the persistence path.
 */
const StepSeparatorDataRenderer = () => (
	<div className="mx-auto my-3 w-full max-w-(--thread-max-width) px-2">
		<div className="border-t border-border/60" />
	</div>
);
/**
* Data-part UI built with `makeAssistantDataUI`: binds the `"step-separator"`
* part name to `StepSeparatorDataRenderer`.
*/
export const StepSeparatorDataUI = makeAssistantDataUI({
name: "step-separator",
render: StepSeparatorDataRenderer,
});

View file

@ -1,18 +0,0 @@
import { ThreadPrimitive } from "@assistant-ui/react";
import { ArrowDownIcon } from "lucide-react";
import type { FC } from "react";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
/**
 * Round "Scroll to bottom" icon button for the chat thread.
 *
 * Wraps `TooltipIconButton` in `ThreadPrimitive.ScrollToBottom` (`asChild`),
 * so the primitive supplies the button behaviour. The `disabled:invisible`
 * class hides the button whenever it is rendered disabled.
 */
export const ThreadScrollToBottom: FC = () => (
	<ThreadPrimitive.ScrollToBottom asChild>
		<TooltipIconButton
			tooltip="Scroll to bottom"
			variant="outline"
			className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
		>
			<ArrowDownIcon />
		</TooltipIconButton>
	</ThreadPrimitive.ScrollToBottom>
);

View file

@ -5,12 +5,10 @@ import {
ThreadPrimitive,
useAui,
useAuiState,
useThreadViewportStore,
} from "@assistant-ui/react";
import { useAtom, useAtomValue, useSetAtom } from "jotai";
import {
AlertCircle,
ArrowDownIcon,
ArrowUpIcon,
Camera,
ChevronDown,
@ -37,10 +35,13 @@ import {
toggleToolAtom,
} from "@/atoms/agent-tools/agent-tools.atoms";
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import {
mentionedDocumentsAtom,
} from "@/atoms/chat/mentioned-documents.atom";
import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
import {
clearPremiumAlertForThreadAtom,
premiumAlertByThreadAtom,
} from "@/atoms/chat/premium-alert.atom";
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { membersAtom } from "@/atoms/members/members-query.atoms";
@ -52,6 +53,7 @@ import {
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
import {
@ -82,6 +84,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import {
CONNECTOR_ICON_TO_TYPES,
CONNECTOR_TOOL_ICON_PATHS,
getToolDisplayName,
getToolIcon,
} from "@/contracts/enums/toolIcons";
import type { Document } from "@/contracts/types/document.types";
@ -89,8 +92,8 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments";
import { useCommentsSync } from "@/hooks/use-comments-sync";
import { useMediaQuery } from "@/hooks/use-media-query";
import { useElectronAPI } from "@/hooks/use-platform";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events";
import { cn } from "@/lib/utils";
@ -108,10 +111,13 @@ const ThreadContent: FC = () => {
["--thread-max-width" as string]: "44rem",
}}
>
<ThreadPrimitive.Viewport
turnAnchor="top"
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
style={{ scrollbarGutter: "stable" }}
<ChatViewport
footer={
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<PremiumQuotaPinnedAlert />
<Composer />
</AuiIf>
}
>
<AuiIf condition={({ thread }) => thread.isEmpty}>
<ThreadWelcome />
@ -124,36 +130,39 @@ const ThreadContent: FC = () => {
AssistantMessage,
}}
/>
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<div className="grow" />
</AuiIf>
<ThreadPrimitive.ViewportFooter
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-2xl bg-main-panel pb-4 md:pb-6"
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
>
<ThreadScrollToBottom />
<AuiIf condition={({ thread }) => !thread.isEmpty}>
<Composer />
</AuiIf>
</ThreadPrimitive.ViewportFooter>
</ThreadPrimitive.Viewport>
</ChatViewport>
</ThreadPrimitive.Root>
);
};
const ThreadScrollToBottom: FC = () => {
const PremiumQuotaPinnedAlert: FC = () => {
const currentThreadState = useAtomValue(currentThreadAtom);
const alertsByThread = useAtomValue(premiumAlertByThreadAtom);
const clearPremiumAlertForThread = useSetAtom(clearPremiumAlertForThreadAtom);
const currentThreadId = currentThreadState?.id;
if (!currentThreadId) return null;
const alert = alertsByThread[currentThreadId];
if (!alert) return null;
return (
<ThreadPrimitive.ScrollToBottom asChild>
<TooltipIconButton
tooltip="Scroll to bottom"
variant="outline"
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
>
<ArrowDownIcon />
</TooltipIconButton>
</ThreadPrimitive.ScrollToBottom>
<div className="mx-0 overflow-hidden rounded-2xl border-input bg-muted px-4 py-4 text-foreground select-none">
<div className="flex items-center gap-2">
<AlertCircle className="size-4 shrink-0 text-muted-foreground" />
<div className="min-w-0 flex-1">
<p className="text-sm">{alert.message}</p>
</div>
<button
type="button"
className="inline-flex size-6 items-center justify-center text-muted-foreground transition-colors hover:text-foreground"
aria-label="Dismiss premium quota alert"
onClick={() => clearPremiumAlertForThread(currentThreadId)}
>
<X className="size-4" />
</button>
</div>
</div>
);
};
@ -373,23 +382,9 @@ const Composer: FC = () => {
>(new Map());
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
const promptPickerRef = useRef<PromptPickerRef>(null);
const viewportRef = useRef<Element | null>(null);
const { search_space_id, chat_id } = useParams();
const aui = useAui();
const threadViewportStore = useThreadViewportStore();
const hasAutoFocusedRef = useRef(false);
const submitCleanupRef = useRef<(() => void) | null>(null);
useEffect(() => {
return () => {
submitCleanupRef.current?.();
};
}, []);
// Store viewport element reference on mount
useEffect(() => {
viewportRef.current = document.querySelector(".aui-thread-viewport");
}, []);
const electronAPI = useElectronAPI();
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
@ -588,7 +583,6 @@ const Composer: FC = () => {
[showDocumentPopover, showPromptPicker]
);
// Submit message (blocked during streaming, document picker open, or AI responding to another user)
const handleSubmit = useCallback(() => {
if (isThreadRunning || isBlockedByOtherUser) return;
if (showDocumentPopover || showPromptPicker) return;
@ -600,50 +594,9 @@ const Composer: FC = () => {
setClipboardInitialText(undefined);
}
const viewportEl = viewportRef.current;
const heightBefore = viewportEl?.scrollHeight ?? 0;
aui.composer().send();
editorRef.current?.clear();
setMentionedDocuments([]);
// With turnAnchor="top", ViewportSlack adds min-height to the last
// assistant message so that scrolling-to-bottom actually positions the
// user message at the TOP of the viewport. That slack height is
// calculated asynchronously (ResizeObserver → style → layout).
// Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes.
const scrollToBottom = () =>
threadViewportStore.getState().scrollToBottom({ behavior: "instant" });
let lastHeight = heightBefore;
let frames = 0;
let cancelled = false;
const POLL_FRAMES = 30;
const pollAndScroll = () => {
if (cancelled) return;
const el = viewportRef.current;
if (el) {
const h = el.scrollHeight;
if (h !== lastHeight) {
lastHeight = h;
scrollToBottom();
}
}
if (++frames < POLL_FRAMES) {
requestAnimationFrame(pollAndScroll);
}
};
requestAnimationFrame(pollAndScroll);
const t1 = setTimeout(scrollToBottom, 100);
const t2 = setTimeout(scrollToBottom, 300);
submitCleanupRef.current = () => {
cancelled = true;
clearTimeout(t1);
clearTimeout(t2);
};
}, [
showDocumentPopover,
showPromptPicker,
@ -652,7 +605,6 @@ const Composer: FC = () => {
clipboardInitialText,
aui,
setMentionedDocuments,
threadViewportStore,
]);
const handleDocumentRemove = useCallback(
@ -1317,12 +1269,14 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
);
};
/**
 * Friendly tool name for display in the chat UI. Delegates to the
 * shared map in ``contracts/enums/toolIcons`` so unix-style identifiers
 * (``rm``, ``ls``, ``grep``, …) and snake_cased function names render as
 * plain English (e.g. "Delete file", "List files", "Search in files").
 *
 * @param name - Raw tool identifier as emitted by the agent.
 * @returns The label produced by ``getToolDisplayName``.
 */
function formatToolName(name: string): string {
	// The previous split/capitalise implementation returned first and left
	// this delegation unreachable; the shared map is the single source now.
	return getToolDisplayName(name);
}
interface ToolGroup {

View file

@ -1,30 +1,277 @@
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react";
import { useMemo, useState } from "react";
import { type ToolCallMessagePartComponent, useAuiState } from "@assistant-ui/react";
import { useQueryClient } from "@tanstack/react-query";
import { useAtomValue } from "jotai";
import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react";
import { useEffect, useMemo, useState } from "react";
import { toast } from "sonner";
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
import { NestedScroll } from "@/components/assistant-ui/nested-scroll";
import {
DoomLoopApprovalToolUI,
isDoomLoopInterrupt,
} from "@/components/tool-ui/doom-loop-approval";
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
import { getToolIcon } from "@/contracts/enums/toolIcons";
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
AlertDialogTrigger,
} from "@/components/ui/alert-dialog";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import { Card } from "@/components/ui/card";
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible";
import { Separator } from "@/components/ui/separator";
import { Spinner } from "@/components/ui/spinner";
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
import { markActionRevertedInCache, useAgentActionsQuery } from "@/hooks/use-agent-actions-query";
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
import { AppError } from "@/lib/error";
import { isInterruptResult } from "@/lib/hitl";
import { cn } from "@/lib/utils";
function formatToolName(name: string): string {
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
/**
* Inline Revert button rendered on a tool card when the matching
* ``AgentActionLog`` row is reversible and hasn't been reverted yet.
*
* Reads from the unified ``useAgentActionsQuery`` cache — the SAME
* react-query cache the agent-actions sheet consumes. SSE events
* (``data-action-log`` / ``data-action-log-updated``) and
* ``POST /threads/{id}/revert/{id}`` responses both flow through the
* cache via ``setQueryData`` helpers, so the card and the sheet stay
* in lockstep on every code path: page reload, navigation, live
* stream, post-stream reversibility flip, and explicit revert clicks.
*
* Match key (in priority order):
* 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when
* the model streamed ``tool_call_chunks`` so the card's synthetic
* id IS the LangChain id.
* 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or
* parity_v2 with provider-side chunk emission) where the card's
* synthetic id is ``call_<run_id>`` and the LangChain id is
* backfilled onto the part by ``tool-output-available``.
* 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback
* for cards whose synthetic id is ``call_<run_id>`` AND whose
* ``langchainToolCallId`` never got backfilled (provider emitted
* the tool_call as a single payload with no chunks AND streaming
* pre-dated the ``tool-output-available → langchainToolCallId``
* backfill, e.g. older threads). Reads the parent message's
* ``chatTurnId`` and ``content`` via ``useAuiState`` so we can
* match position-by-tool-name within the turn against the
* action_log rows the server returned in ``created_at`` order.
*/
function ToolCardRevertButton({
toolCallId,
toolName,
langchainToolCallId,
}: {
toolCallId: string;
toolName: string;
langchainToolCallId?: string;
}) {
const session = useAtomValue(chatSessionStateAtom);
const threadId = session?.threadId ?? null;
const queryClient = useQueryClient();
const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId);
// Parent message metadata, read via the narrowest possible
// selectors so this card doesn't re-render on every text-delta of
// every other part in the same message during streaming.
//
// IMPORTANT — ``useAuiState`` re-renders the component whenever the
// returned slice's identity changes. Returning ``message?.content``
// (an array) would re-render on every token because the runtime
// rebuilds the parts array. Returning a PRIMITIVE (the position
// number) lets ``useAuiState``'s ``Object.is`` check short-circuit
// when the position hasn't actually moved — which is the common
// case during text streaming, when only ``text``/``reasoning``
// parts are mutating and the same-toolName tool-call ordering is
// stable. (See Vercel React rule ``rerender-defer-reads``.)
const chatTurnId = useAuiState(({ message }) => {
const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined;
return meta?.custom?.chatTurnId ?? null;
});
// Zero-based index of this card among the message's ``tool-call`` parts
// that share ``toolName``; -1 when the content is not an array or no part
// carries this ``toolCallId``.
const positionInTurn = useAuiState(({ message }) => {
const content = message?.content;
if (!Array.isArray(content)) return -1;
let n = -1;
for (const part of content) {
if (
part &&
typeof part === "object" &&
(part as { type?: string }).type === "tool-call" &&
(part as { toolName?: string }).toolName === toolName
) {
n += 1;
if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n;
}
}
return -1;
});
const action = useMemo(() => {
// Tier 1 + 2: O(1) Map-backed direct id match. Covers
// ~all parity_v2 streams and any legacy stream that backfilled
// ``langchainToolCallId`` via ``tool-output-available``.
const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId);
if (direct) return direct;
// Tier 3: position-within-turn fallback. Only kicks in when the
// card has a synthetic ``call_<run_id>`` id AND no
// ``langchainToolCallId`` was ever backfilled — i.e. the tool
// was emitted as a single non-chunked payload AND streaming
// pre-dated the on_tool_end backfill.
if (!chatTurnId || positionInTurn < 0) return null;
const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName);
return turnSameTool[positionInTurn] ?? null;
}, [
findByToolCallId,
findByChatTurnAndTool,
toolCallId,
langchainToolCallId,
chatTurnId,
toolName,
positionInTurn,
]);
const [isReverting, setIsReverting] = useState(false);
const [confirmOpen, setConfirmOpen] = useState(false);
// Eligibility gates — all hooks are declared above, so these early
// returns never change the hook order. The button only renders for a
// matched row that is reversible, not yet reverted, not itself a
// revert row, recorded without an error, and tied to an active thread.
if (!action) return null;
if (!action.reversible) return null;
if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined)
return null;
if (action.is_revert_action) return null;
if (action.error !== null && action.error !== undefined) return null;
if (!threadId) return null;
const handleRevert = async () => {
setIsReverting(true);
try {
const response = await agentActionsApiService.revert(threadId, action.id);
markActionRevertedInCache(queryClient, threadId, action.id, response.new_action_id ?? null);
toast.success(response.message || "Action reverted.");
} catch (err) {
// 503 means revert is gated off on this deployment — fail
// silently (no toast) rather than nagging the user. Note the
// button itself is NOT hidden by this path; the dialog simply
// closes in ``finally``. Any other error is surfaced as a toast
// so the operator can investigate.
if (err instanceof AppError && err.status === 503) {
return;
}
const message =
err instanceof AppError
? err.message
: err instanceof Error
? err.message
: "Failed to revert action.";
toast.error(message);
} finally {
// Runs on success, on failure, and on the silent 503 return.
setIsReverting(false);
setConfirmOpen(false);
}
};
return (
<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
<AlertDialogTrigger asChild>
<Button
size="sm"
variant="outline"
className="gap-1.5"
onClick={(e) => {
// Keep the click from reaching ancestor handlers.
e.stopPropagation();
setConfirmOpen(true);
}}
disabled={isReverting}
>
{isReverting ? (
// Spinner's typed props don't accept ``data-icon`` and
// it renders an <output>, not an <svg>, so Button's
// auto-sizing rule doesn't apply. Bare spinner +
// Button's gap handle layout.
<Spinner size="xs" />
) : (
<RotateCcw data-icon="inline-start" />
)}
Revert
</Button>
</AlertDialogTrigger>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>Revert this action?</AlertDialogTitle>
<AlertDialogDescription>
This will undo{" "}
<span className="font-medium">{getToolDisplayName(action.tool_name)}</span> and add a
new entry to the history. Your chat is preserved only the changes the agent made to
your knowledge base or connected apps will be rolled back where possible.
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
<AlertDialogAction
onClick={(e) => {
// NOTE(review): presumably suppresses the action button's
// default close-on-click so the dialog stays open until
// ``handleRevert`` settles and closes it — confirm.
e.preventDefault();
handleRevert();
}}
disabled={isReverting}
className="gap-1.5"
>
{isReverting && <Spinner size="xs" />}
Revert
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
);
}
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
toolName,
argsText,
result,
status,
}) => {
const [isExpanded, setIsExpanded] = useState(false);
/**
* Compact tool-call card.
*
* shadcn composition note: we intentionally use ``Card`` as a visual
* frame WITHOUT ``CardHeader / CardContent``. The full composition's
* ``p-6`` padding doesn't fit a compact collapsible header that IS the
* trigger; using ``Card`` alone preserves the rounded border, shadow,
* and ``bg-card`` token (semantic colors) without forcing a layout
* that doesn't fit. All status colors use semantic tokens — no manual
* dark-mode overrides, no raw hex.
*/
const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
const { toolCallId, toolName, argsText, result, status } = props;
// ``langchainToolCallId`` is a SurfSense-specific extension the
// streaming pipeline attaches to the tool-call content part so
// the Revert button can resolve its ``AgentActionLog`` row even
// when only the LC id is known. assistant-ui's
// ``ToolCallMessagePartProps`` doesn't list it, but the runtime
// spreads ``{...part}`` so the prop reaches us at runtime.
const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId;
const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
const isError = status?.type === "incomplete" && status.reason === "error";
const isRunning = status?.type === "running" || status?.type === "requires-action";
/*
Per-card expansion state. Initial value is ``isRunning`` so a
card streaming in mounts already-expanded (no flash of
collapsed → expanded on first paint), while a card loaded from
history (status="complete") mounts collapsed. The useEffect
below keeps this in lockstep with this card's own ``isRunning``
when it transitions: false → true auto-expands (e.g. a tool
that re-runs after edit), true → false auto-collapses once the
tool finishes. Because the dep is per-card ``isRunning`` and
not the chat-level streaming flag, sibling cards on the same
assistant turn each manage their own expansion independently.
Once ``isRunning`` is false the user controls expansion via
``onOpenChange``.
*/
const [isExpanded, setIsExpanded] = useState(isRunning);
useEffect(() => {
setIsExpanded(isRunning);
}, [isRunning]);
const errorData = status?.type === "incomplete" ? status.error : undefined;
const serializedError = useMemo(
() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
@ -50,105 +297,207 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
: serializedError
: null;
const Icon = getToolIcon(toolName);
const displayName = formatToolName(toolName);
const displayName = getToolDisplayName(toolName);
const subtitle = errorReason ?? cancelledReason;
return (
<div
<Card
className={cn(
"my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none",
"my-4 max-w-lg overflow-hidden",
isCancelled && "opacity-60",
isError && "border-destructive/20 bg-destructive/5"
isError && "border-destructive/30"
)}
>
<button
type="button"
onClick={() => setIsExpanded((prev) => !prev)}
className="flex w-full items-center gap-3 px-5 py-4 text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none"
{/*
``group`` lets the chevron (rendered as a sibling of the
main trigger button) read the Collapsible Root's
``data-[state=open]`` for rotation. The Collapsible is
fully controlled via ``isExpanded`` — the useEffect
above syncs it to ``isRunning`` so the card auto-opens
while a tool streams in and auto-collapses once it
finishes. We deliberately DON'T pass ``disabled`` so
both triggers stay clickable; ``onOpenChange`` is wired
to an inline handler that no-ops while ``isRunning``
(see the ``onOpenChange`` prop below), which keeps the card
pinned open mid-stream without losing keyboard / pointer
affordance the moment streaming ends.
*/}
<Collapsible
className="group"
open={isExpanded}
onOpenChange={(next) => {
// Block manual collapse while the tool is still
// streaming — otherwise a stray click on either
// trigger would close the card and hide the live
// ``argsText`` panel mid-run. After streaming the
// user has full control again.
if (isRunning) return;
setIsExpanded(next);
}}
>
<div
className={cn(
"flex size-8 shrink-0 items-center justify-center rounded-lg",
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
)}
>
{isError ? (
<XCircleIcon className="size-4 text-destructive" />
) : isCancelled ? (
<XCircleIcon className="size-4 text-muted-foreground" />
) : isRunning ? (
<Icon className="size-4 text-primary animate-pulse" />
) : (
<CheckIcon className="size-4 text-primary" />
)}
</div>
{/*
Header row: main trigger on the left (icon + title
col), Revert + chevron-trigger on the right as
siblings of the main trigger. The chevron is wrapped
in its OWN ``CollapsibleTrigger`` (Radix supports
multiple triggers per Root) so clicking the chevron
toggles the same state as clicking the title row.
The Revert button stays a separate AlertDialog
trigger and stops propagation in its onClick so it
doesn't toggle the collapsible while opening the
confirm dialog. Keeping these as flat siblings
rather than nesting Revert / chevron inside the
title trigger avoids invalid HTML
(button-in-button) and lets the Revert button
render in BOTH the collapsed and expanded states.
*/}
<div className="flex items-stretch transition-colors hover:bg-muted/50">
<CollapsibleTrigger asChild>
<button
type="button"
className={cn(
"flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left",
// Inset ring — Card's ``overflow-hidden`` would
// clip an ``offset-2`` ring; ``ring-inset``
// paints inside the button box.
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
"disabled:cursor-default"
)}
>
<div
className={cn(
"flex size-8 shrink-0 items-center justify-center rounded-lg",
isError ? "bg-destructive/10" : isCancelled ? "bg-muted" : "bg-primary/10"
)}
>
{isError ? (
<XCircleIcon className="size-4 text-destructive" />
) : isCancelled ? (
<XCircleIcon className="size-4 text-muted-foreground" />
) : isRunning ? (
<Spinner size="sm" className="text-primary" />
) : (
<CheckIcon className="size-4 text-primary" />
)}
</div>
<div className="flex-1 min-w-0">
<p
className={cn(
"text-sm font-semibold",
isError
? "text-destructive"
: isCancelled
? "text-muted-foreground line-through"
: "text-foreground"
)}
>
{isRunning
? displayName
: isCancelled
? `Cancelled: ${displayName}`
: isError
? `Failed: ${displayName}`
: displayName}
</p>
{isRunning && <p className="text-xs text-muted-foreground mt-0.5">Running...</p>}
{cancelledReason && (
<p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p>
)}
{errorReason && (
<p className="text-xs text-destructive/80 mt-0.5 truncate">{errorReason}</p>
)}
</div>
<div className="flex flex-1 min-w-0 flex-col gap-0.5">
<div className="flex items-center gap-2">
<p
className={cn(
"text-sm font-semibold truncate",
isCancelled && "text-muted-foreground line-through",
isError && "text-destructive"
)}
>
{displayName}
</p>
{isRunning && <Badge variant="secondary">Running</Badge>}
{isError && <Badge variant="destructive">Failed</Badge>}
{isCancelled && <Badge variant="outline">Cancelled</Badge>}
</div>
{subtitle && (
<p
className={cn(
"text-xs truncate",
isError ? "text-destructive/80" : "text-muted-foreground"
)}
>
{subtitle}
</p>
)}
</div>
</button>
</CollapsibleTrigger>
{!isRunning && (
<div className="shrink-0 text-muted-foreground">
{isExpanded ? (
<ChevronDownIcon className="size-4" />
) : (
<ChevronUpIcon className="size-4" />
)}
{/*
Right-side controls. The Revert button is
visible whenever the matching action is
reversible — including the collapsed state —
but ``ToolCardRevertButton`` itself returns
``null`` while a tool is still running because
no action-log row exists yet, so it doesn't
need an explicit ``isRunning`` gate here.
*/}
<div className="flex shrink-0 items-center gap-2 pl-2 pr-5">
<ToolCardRevertButton
toolCallId={toolCallId}
toolName={toolName}
langchainToolCallId={langchainToolCallId}
/>
<CollapsibleTrigger asChild>
<button
type="button"
aria-label={isExpanded ? "Collapse details" : "Expand details"}
className={cn(
"flex size-7 shrink-0 items-center justify-center rounded-md",
"text-muted-foreground hover:bg-muted hover:text-foreground",
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
"disabled:cursor-default"
)}
>
<ChevronDownIcon
className={cn(
"size-4 transition-transform duration-200",
"group-data-[state=open]:rotate-180"
)}
/>
</button>
</CollapsibleTrigger>
</div>
)}
</button>
</div>
{isExpanded && !isRunning && (
<>
<div className="mx-5 h-px bg-border/50" />
<div className="px-5 py-3 space-y-3">
{argsText && (
<div>
<p className="text-xs font-medium text-muted-foreground mb-1">Arguments</p>
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
{argsText}
</pre>
{/*
CollapsibleContent body — auto-open while streaming
(see ``open`` prop above) so the live ``argsText``
streams into the Inputs panel directly, no need for
a separate "Live input" panel. Native
``overflow-auto`` instead of ``ScrollArea`` because
Radix's Viewport can let content bleed past
``max-h-*`` in dynamic flex layouts. ``min-w-0`` on
the column wrappers guarantees ``break-all`` wraps
correctly within the bounded ``max-w-lg`` Card.
*/}
<CollapsibleContent>
<Separator />
<div className="flex flex-col gap-3 px-5 py-3">
{(argsText || isRunning) && (
<div className="flex flex-col gap-1 min-w-0">
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
<NestedScroll className="max-h-48 overflow-auto rounded-md bg-muted/40">
{argsText ? (
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
{argsText}
</pre>
) : (
// Bridges the brief gap between
// ``tool-input-start`` (creates the
// card, ``argsText`` undefined) and
// the first ``tool-input-delta``.
<p className="px-3 py-2 text-xs italic text-muted-foreground">
Waiting for input
</p>
)}
</NestedScroll>
</div>
)}
{!isCancelled && result !== undefined && (
<>
<div className="h-px bg-border/30" />
<div>
<p className="text-xs font-medium text-muted-foreground mb-1">Result</p>
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
{typeof result === "string" ? result : serializedResult}
</pre>
<Separator />
<div className="flex flex-col gap-1 min-w-0">
<p className="text-xs font-medium text-muted-foreground">Result</p>
<NestedScroll className="max-h-64 overflow-auto rounded-md bg-muted/40">
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
{typeof result === "string" ? result : serializedResult}
</pre>
</NestedScroll>
</div>
</>
)}
</div>
</>
)}
</div>
</CollapsibleContent>
</Collapsible>
</Card>
);
};

View file

@ -1,4 +1,10 @@
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
import {
ActionBarPrimitive,
AuiIf,
MessagePrimitive,
useAuiState,
useMessagePartText,
} from "@assistant-ui/react";
import { useAtomValue } from "jotai";
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
import Image from "next/image";
@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
import { parseMentionSegments } from "@/lib/chat/parse-mention-segments";
interface AuthorMetadata {
displayName: string | null;
@ -47,23 +55,40 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
);
};
export const UserMessage: FC = () => {
const UserTextPart: FC = () => {
const messageId = useAuiState(({ message }) => message?.id);
const messageText = useAuiState(({ message }) =>
(message?.content ?? [])
.map((part) =>
typeof part === "object" &&
part !== null &&
"type" in part &&
(part as { type?: string }).type === "text" &&
"text" in part
? String((part as { text?: string }).text ?? "")
: ""
)
.join("")
);
const part = useMessagePartText();
const text = (part as { text?: string }).text ?? "";
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? [];
const segments = parseMentionSegments(text, mentionedDocs);
return (
<p style={{ whiteSpace: "pre-line" }} className="break-words">
{segments.map((segment) =>
segment.type === "text" ? (
<span key={`txt-${segment.start}`}>{segment.value}</span>
) : (
<span
key={`mention-${getMentionDocKey(segment.doc)}-${segment.start}`}
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-middle leading-none"
title={segment.doc.title}
>
<span className="flex items-center text-muted-foreground">
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
</span>
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
</span>
)
)}
</p>
);
};
const userMessageParts = { Text: UserTextPart };
export const UserMessage: FC = () => {
const metadata = useAuiState(({ message }) => message?.metadata);
const author = metadata?.custom?.author as AuthorMetadata | undefined;
const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE";
@ -77,12 +102,8 @@ export const UserMessage: FC = () => {
<div className="col-start-2 min-w-0">
<div className="aui-user-message-content-wrapper flex items-end gap-2">
<div className="relative flex-1 min-w-0">
<div className="aui-user-message-content wrap-break-word rounded-xl bg-muted px-4 py-2.5 text-foreground">
{mentionedDocs && mentionedDocs.length > 0 ? (
<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
) : (
<MessagePrimitive.Parts />
)}
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
<MessagePrimitive.Parts components={userMessageParts} />
</div>
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
<UserActionBar />
@ -99,64 +120,6 @@ export const UserMessage: FC = () => {
);
};
/**
 * Renders user message text, swapping every `@<title>` occurrence that
 * matches one of `mentionedDocs` for an inline chip (connector icon plus a
 * truncated title). All remaining text is emitted verbatim in plain spans.
 *
 * Matching is a single left-to-right scan. Candidates are tried longest
 * token first, so when one title is a prefix of another the longer title
 * wins at a given position.
 */
const UserMessageWithMentionChips: FC<{
	text: string;
	mentionedDocs: { id: number; title: string; document_type: string }[];
}> = ({ text, mentionedDocs }) => {
	type MentionDoc = { id: number; title: string; document_type: string };
	type Segment =
		| { type: "text"; value: string; start: number }
		| { type: "mention"; doc: MentionDoc; start: number };

	// Longest token first so overlapping titles resolve to the longest match.
	const candidates = mentionedDocs.map((doc) => ({ doc, token: `@${doc.title}` }));
	candidates.sort((a, b) => b.token.length - a.token.length);

	const segments: Segment[] = [];
	// Offset where the pending (not yet emitted) plain-text run begins.
	let plainStart = 0;
	let cursor = 0;

	// Emit the plain-text run [plainStart, end) if it is non-empty.
	const flushPlain = (end: number) => {
		if (end > plainStart) {
			segments.push({ type: "text", value: text.slice(plainStart, end), start: plainStart });
		}
	};

	while (cursor < text.length) {
		const hit = candidates.find(({ token }) => text.startsWith(token, cursor));
		if (!hit) {
			cursor += 1;
			continue;
		}
		flushPlain(cursor);
		segments.push({ type: "mention", doc: hit.doc, start: cursor });
		cursor += hit.token.length;
		plainStart = cursor;
	}
	flushPlain(text.length);

	const renderSegment = (segment: Segment) => {
		if (segment.type === "text") {
			return <span key={`txt-${segment.start}`}>{segment.value}</span>;
		}
		const { doc, start } = segment;
		return (
			<span
				key={`mention-${doc.document_type}:${doc.id}-${start}`}
				className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline"
				title={doc.title}
			>
				<span className="flex items-center text-muted-foreground">
					{getConnectorIcon(doc.document_type ?? "UNKNOWN", "h-3 w-3")}
				</span>
				<span className="max-w-[120px] truncate">{doc.title}</span>
			</span>
		);
	};

	return <span className="whitespace-pre-wrap break-words">{segments.map(renderSegment)}</span>;
};
const UserActionBar: FC = () => {
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);

View file

@ -169,7 +169,7 @@ export const CitationPanelContent: FC<CitationPanelContentProps> = ({ chunkId, o
)}
</div>
<div className="text-sm">
<MarkdownViewer content={chunk.content} />
<MarkdownViewer content={chunk.content} enableCitations />
</div>
</div>
);

View file

@ -0,0 +1,77 @@
"use client";
import type { ReactNode } from "react";
import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation";
import {
type CitationToken,
type CitationUrlMap,
parseTextWithCitations,
} from "@/lib/citations/citation-parser";
/**
 * Turn one parsed citation token into its badge element.
 *
 * URL tokens become a `UrlCitation`; every other token becomes an
 * `InlineCitation` for the token's chunk.
 *
 * @param token - The parsed citation token to render.
 * @param ordinalKey - A counter that is stable for the current render and
 *   unique per citation within the parent. It is folded into the React `key`
 *   so that two identical citations next to each other do not collide.
 *   (Keying by the citation source text alone, as `markdown-text.tsx` used
 *   to, produced duplicate-key warnings.)
 * @returns The keyed citation element.
 */
export function renderCitationToken(token: CitationToken, ordinalKey: number): ReactNode {
	if (token.kind !== "url") {
		const docPrefix = token.isDocsChunk ? "doc-" : "";
		const chunkKey = `citation-${docPrefix}${token.chunkId}-${ordinalKey}`;
		return (
			<InlineCitation key={chunkKey} chunkId={token.chunkId} isDocsChunk={token.isDocsChunk} />
		);
	}
	return <UrlCitation key={`citation-url-${ordinalKey}`} url={token.url} />;
}
/**
 * Replace `[citation:...]` tokens found in string children with citation
 * badges.
 *
 * Intended for `components` overrides of `Streamdown`/`react-markdown`,
 * where the renderer passes `children` through.
 *
 * - A string child is parsed; if it holds no citations it is returned
 *   as-is, otherwise an array of text pieces and badges is returned.
 * - An array is mapped element by element: strings containing citations
 *   are wrapped in a keyed `<span>`, everything else is passed through.
 * - Any other node is returned untouched, so block/phrasing structure is
 *   preserved.
 *
 * @param children - The node(s) handed over by the markdown renderer.
 * @param urlMap - Lookup forwarded to `parseTextWithCitations`.
 * @returns The children with citation tokens swapped for badges.
 */
export function processChildrenWithCitations(
	children: ReactNode,
	urlMap: CitationUrlMap
): ReactNode {
	// One counter per call, so badge keys stay unique across sibling strings.
	let nextOrdinal = 0;

	// Returns the rendered pieces for `text`, or null when it is plain text only.
	const expand = (text: string): ReactNode[] | null => {
		const pieces = parseTextWithCitations(text, urlMap);
		const isPlainText = pieces.length === 1 && typeof pieces[0] === "string";
		if (isPlainText) {
			return null;
		}
		return pieces.map((piece) =>
			typeof piece === "string" ? piece : renderCitationToken(piece, nextOrdinal++)
		);
	};

	if (typeof children === "string") {
		return expand(children) ?? children;
	}

	if (!Array.isArray(children)) {
		return children;
	}

	return children.map((child, childIndex) => {
		if (typeof child !== "string") {
			return child;
		}
		const rendered = expand(child);
		if (rendered === null) {
			return child;
		}
		return <span key={`citation-seg-${childIndex}`}>{rendered}</span>;
	});
}

View file

@ -32,7 +32,7 @@ export function DocumentViewer({ title, content, trigger }: DocumentViewerProps)
<DialogTitle>{title}</DialogTitle>
</DialogHeader>
<div className="mt-4">
<MarkdownViewer content={content} />
<MarkdownViewer content={content} enableCitations />
</div>
</DialogContent>
</Dialog>

Some files were not shown because too many files have changed in this diff Show more