mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-21 18:55:16 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/ui-revamp
This commit is contained in:
commit
9b1b5a504e
148 changed files with 19460 additions and 2708 deletions
13
.github/workflows/desktop-release.yml
vendored
13
.github/workflows/desktop-release.yml
vendored
|
|
@ -136,6 +136,19 @@ jobs:
|
||||||
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
|
AZURE_CODESIGN_ENDPOINT: ${{ vars.AZURE_CODESIGN_ENDPOINT }}
|
||||||
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}
|
AZURE_CODESIGN_ACCOUNT: ${{ vars.AZURE_CODESIGN_ACCOUNT }}
|
||||||
AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }}
|
AZURE_CODESIGN_PROFILE: ${{ vars.AZURE_CODESIGN_PROFILE }}
|
||||||
|
# macOS Developer ID signing + notarization. Only the macos-latest runner
|
||||||
|
# consumes these; Windows/Linux runners ignore them. CSC_LINK accepts either
|
||||||
|
# a file path or a base64-encoded .p12 blob — electron-builder auto-detects.
|
||||||
|
CSC_LINK: ${{ secrets.MAC_CERT_P12_BASE64 }}
|
||||||
|
CSC_KEY_PASSWORD: ${{ secrets.MAC_CERT_PASSWORD }}
|
||||||
|
APPLE_ID: ${{ secrets.APPLE_ID }}
|
||||||
|
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
|
||||||
|
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
|
||||||
|
# TEMP DEBUG — remove once the codesign hang on macos-latest is diagnosed.
|
||||||
|
# Surfaces the exact codesign / notarize commands electron-builder spawns,
|
||||||
|
# so we can see which subprocess hangs.
|
||||||
|
DEBUG: electron-builder,electron-osx-sign*,@electron/notarize*
|
||||||
|
ELECTRON_BUILDER_ALLOW_UNRESOLVED_DEPENDENCIES: "true"
|
||||||
# Service principal credentials for Azure.Identity EnvironmentCredential used by the
|
# Service principal credentials for Azure.Identity EnvironmentCredential used by the
|
||||||
# TrustedSigning PowerShell module. Only populated when signing is enabled.
|
# TrustedSigning PowerShell module. Only populated when signing is enabled.
|
||||||
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing,
|
# electron-builder 26 does not yet support OIDC federated tokens for Azure signing,
|
||||||
|
|
|
||||||
60
.github/workflows/notary-status.yml
vendored
Normal file
60
.github/workflows/notary-status.yml
vendored
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
name: Notary status check
|
||||||
|
|
||||||
|
# One-off diagnostic workflow. Queries Apple's notary service to see if your
|
||||||
|
# submissions are queued, in progress, accepted, or rejected. Useful when a
|
||||||
|
# notarization seems "hung" — most often the queue itself, especially on a
|
||||||
|
# brand-new Apple Developer account.
|
||||||
|
#
|
||||||
|
# Run via: Actions tab -> "Notary status check" -> Run workflow.
|
||||||
|
# Inputs are optional; if you provide a submission ID, it also fetches that
|
||||||
|
# submission's full Apple log.
|
||||||
|
#
|
||||||
|
# Safe to delete after diagnosis.
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
submission_id:
|
||||||
|
description: 'Optional: submission UUID to fetch full Apple log for'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
status:
|
||||||
|
runs-on: macos-latest
|
||||||
|
steps:
|
||||||
|
- name: List recent notarization submissions
|
||||||
|
env:
|
||||||
|
APPLE_ID: ${{ secrets.APPLE_ID }}
|
||||||
|
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
|
||||||
|
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
echo "::group::Submission history (most recent first)"
|
||||||
|
xcrun notarytool history \
|
||||||
|
--apple-id "$APPLE_ID" \
|
||||||
|
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
|
||||||
|
--team-id "$APPLE_TEAM_ID"
|
||||||
|
echo "::endgroup::"
|
||||||
|
|
||||||
|
- name: Inspect specific submission (if id provided)
|
||||||
|
if: ${{ inputs.submission_id != '' }}
|
||||||
|
env:
|
||||||
|
APPLE_ID: ${{ secrets.APPLE_ID }}
|
||||||
|
APPLE_APP_SPECIFIC_PASSWORD: ${{ secrets.APPLE_APP_SPECIFIC_PASSWORD }}
|
||||||
|
APPLE_TEAM_ID: ${{ secrets.APPLE_TEAM_ID }}
|
||||||
|
SUBMISSION_ID: ${{ inputs.submission_id }}
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
echo "::group::Submission info"
|
||||||
|
xcrun notarytool info "$SUBMISSION_ID" \
|
||||||
|
--apple-id "$APPLE_ID" \
|
||||||
|
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
|
||||||
|
--team-id "$APPLE_TEAM_ID"
|
||||||
|
echo "::endgroup::"
|
||||||
|
echo "::group::Apple's processing log for this submission"
|
||||||
|
xcrun notarytool log "$SUBMISSION_ID" \
|
||||||
|
--apple-id "$APPLE_ID" \
|
||||||
|
--password "$APPLE_APP_SPECIFIC_PASSWORD" \
|
||||||
|
--team-id "$APPLE_TEAM_ID" || true
|
||||||
|
echo "::endgroup::"
|
||||||
|
|
@ -282,6 +282,14 @@ LANGSMITH_PROJECT=surfsense
|
||||||
# SURFSENSE_ENABLE_ACTION_LOG=false
|
# SURFSENSE_ENABLE_ACTION_LOG=false
|
||||||
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
||||||
|
|
||||||
|
# Streaming parity v2 — opt in to LangChain's structured AIMessageChunk
|
||||||
|
# content (typed reasoning blocks, tool-input deltas) and propagate the
|
||||||
|
# real tool_call_id to the SSE layer. When OFF, the stream falls back to
|
||||||
|
# the str-only text path and synthetic "call_<run_id>" tool-call ids.
|
||||||
|
# Schema migrations 135/136 ship unconditionally because they are
|
||||||
|
# forward-compatible.
|
||||||
|
# SURFSENSE_ENABLE_STREAM_PARITY_V2=false
|
||||||
|
|
||||||
# Plugins
|
# Plugins
|
||||||
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||||
# Comma-separated allowlist of plugin entry-point names
|
# Comma-separated allowlist of plugin entry-point names
|
||||||
|
|
|
||||||
139
surfsense_backend/alembic/versions/134_relax_revision_fks.py
Normal file
139
surfsense_backend/alembic/versions/134_relax_revision_fks.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
"""134_relax_revision_fks
|
||||||
|
|
||||||
|
Revision ID: 134
|
||||||
|
Revises: 133
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Relax the parent FKs on ``document_revisions`` and ``folder_revisions`` so
|
||||||
|
revisions survive the deletes they describe.
|
||||||
|
|
||||||
|
Why: the snapshot/revert pipeline writes a ``DocumentRevision`` BEFORE
|
||||||
|
hard-deleting a document via the ``rm`` tool (and likewise a
|
||||||
|
``FolderRevision`` before ``rmdir``). If the FK is ``ON DELETE CASCADE``
|
||||||
|
the snapshot row is wiped at the exact moment we need it most — revert
|
||||||
|
then has nothing to read and the operation becomes irreversible.
|
||||||
|
|
||||||
|
Migration:
|
||||||
|
|
||||||
|
* ``document_revisions.document_id``: ``NOT NULL`` -> nullable; FK
|
||||||
|
``ON DELETE CASCADE`` -> ``ON DELETE SET NULL``.
|
||||||
|
* ``folder_revisions.folder_id``: same treatment.
|
||||||
|
|
||||||
|
The ``search_space_id`` FK on both tables is left unchanged (still
|
||||||
|
``ON DELETE CASCADE``). When a search space is deleted, all documents,
|
||||||
|
folders, AND their revisions go together — that's the correct teardown
|
||||||
|
story.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "134"
|
||||||
|
down_revision: str | None = "133"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _fk_name(bind, table: str, column: str) -> str | None:
|
||||||
|
"""Return the (single) FK constraint name on ``table.column``, if any."""
|
||||||
|
inspector = inspect(bind)
|
||||||
|
for fk in inspector.get_foreign_keys(table):
|
||||||
|
cols = fk.get("constrained_columns") or []
|
||||||
|
if cols == [column]:
|
||||||
|
return fk.get("name")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
|
||||||
|
# --- document_revisions.document_id -> nullable + SET NULL ---------------
|
||||||
|
fk_name = _fk_name(bind, "document_revisions", "document_id")
|
||||||
|
if fk_name:
|
||||||
|
op.drop_constraint(fk_name, "document_revisions", type_="foreignkey")
|
||||||
|
op.alter_column(
|
||||||
|
"document_revisions",
|
||||||
|
"document_id",
|
||||||
|
existing_type=sa.Integer(),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"document_revisions_document_id_fkey",
|
||||||
|
"document_revisions",
|
||||||
|
"documents",
|
||||||
|
["document_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- folder_revisions.folder_id -> nullable + SET NULL -------------------
|
||||||
|
fk_name = _fk_name(bind, "folder_revisions", "folder_id")
|
||||||
|
if fk_name:
|
||||||
|
op.drop_constraint(fk_name, "folder_revisions", type_="foreignkey")
|
||||||
|
op.alter_column(
|
||||||
|
"folder_revisions",
|
||||||
|
"folder_id",
|
||||||
|
existing_type=sa.Integer(),
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"folder_revisions_folder_id_fkey",
|
||||||
|
"folder_revisions",
|
||||||
|
"folders",
|
||||||
|
["folder_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
|
||||||
|
# Reinstating NOT NULL + CASCADE requires draining orphan rows first
|
||||||
|
# (any revision whose parent doc/folder has already been deleted).
|
||||||
|
op.execute("DELETE FROM document_revisions WHERE document_id IS NULL")
|
||||||
|
op.execute("DELETE FROM folder_revisions WHERE folder_id IS NULL")
|
||||||
|
|
||||||
|
# --- document_revisions.document_id -> NOT NULL + CASCADE ---------------
|
||||||
|
fk_name = _fk_name(bind, "document_revisions", "document_id")
|
||||||
|
if fk_name:
|
||||||
|
op.drop_constraint(fk_name, "document_revisions", type_="foreignkey")
|
||||||
|
op.alter_column(
|
||||||
|
"document_revisions",
|
||||||
|
"document_id",
|
||||||
|
existing_type=sa.Integer(),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"document_revisions_document_id_fkey",
|
||||||
|
"document_revisions",
|
||||||
|
"documents",
|
||||||
|
["document_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- folder_revisions.folder_id -> NOT NULL + CASCADE -------------------
|
||||||
|
fk_name = _fk_name(bind, "folder_revisions", "folder_id")
|
||||||
|
if fk_name:
|
||||||
|
op.drop_constraint(fk_name, "folder_revisions", type_="foreignkey")
|
||||||
|
op.alter_column(
|
||||||
|
"folder_revisions",
|
||||||
|
"folder_id",
|
||||||
|
existing_type=sa.Integer(),
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"folder_revisions_folder_id_fkey",
|
||||||
|
"folder_revisions",
|
||||||
|
"folders",
|
||||||
|
["folder_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
"""135_action_log_correlation_ids
|
||||||
|
|
||||||
|
Revision ID: 135
|
||||||
|
Revises: 134
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Action-log correlation-id cleanup.
|
||||||
|
|
||||||
|
Background
|
||||||
|
----------
|
||||||
|
``agent_action_log.turn_id`` is misnamed. ``ActionLogMiddleware`` writes
|
||||||
|
the LangChain ``tool_call.id`` into that column today (see
|
||||||
|
``action_log.py:_resolve_turn_id``), and ``kb_persistence._find_action_ids_batch``
|
||||||
|
joins on it as such. The real chat-turn id (``f"{chat_id}:{ms}"`` from
|
||||||
|
``stream_new_chat.py``) lives in ``config.configurable.turn_id`` and was
|
||||||
|
never persisted.
|
||||||
|
|
||||||
|
This migration introduces two new, correctly-named columns:
|
||||||
|
|
||||||
|
* ``tool_call_id`` (LangChain tool-call id, what ``turn_id`` actually held)
|
||||||
|
* ``chat_turn_id`` (the per-turn correlation id from
|
||||||
|
``configurable.turn_id`` — used by the per-turn ``revert-turn`` route).
|
||||||
|
|
||||||
|
Backfill copies the current ``turn_id`` values into ``tool_call_id`` so
|
||||||
|
existing joins keep working. The old ``turn_id`` column is left in place
|
||||||
|
for one release as a deprecated alias to give safe rollback. ``ActionLogMiddleware``
|
||||||
|
keeps writing it (= ``tool_call_id``) for the same reason.
|
||||||
|
|
||||||
|
Indexes
|
||||||
|
-------
|
||||||
|
|
||||||
|
* ``ix_agent_action_log_tool_call_id`` — required by
|
||||||
|
``_find_action_ids_batch`` (was on ``turn_id``).
|
||||||
|
* ``ix_agent_action_log_chat_turn_id`` — required by the
|
||||||
|
``revert-turn/{chat_turn_id}`` query.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "135"
|
||||||
|
down_revision: str | None = "134"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"agent_action_log",
|
||||||
|
sa.Column("tool_call_id", sa.String(length=64), nullable=True),
|
||||||
|
)
|
||||||
|
op.add_column(
|
||||||
|
"agent_action_log",
|
||||||
|
sa.Column("chat_turn_id", sa.String(length=64), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_action_log_tool_call_id",
|
||||||
|
"agent_action_log",
|
||||||
|
["tool_call_id"],
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_agent_action_log_chat_turn_id",
|
||||||
|
"agent_action_log",
|
||||||
|
["chat_turn_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"UPDATE agent_action_log SET tool_call_id = turn_id WHERE tool_call_id IS NULL"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_agent_action_log_chat_turn_id", table_name="agent_action_log")
|
||||||
|
op.drop_index("ix_agent_action_log_tool_call_id", table_name="agent_action_log")
|
||||||
|
op.drop_column("agent_action_log", "chat_turn_id")
|
||||||
|
op.drop_column("agent_action_log", "tool_call_id")
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""136_new_chat_message_turn_id
|
||||||
|
|
||||||
|
Revision ID: 136
|
||||||
|
Revises: 135
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Persist the per-turn correlation id on each chat message.
|
||||||
|
|
||||||
|
Background
|
||||||
|
----------
|
||||||
|
LangGraph's checkpointer stores user-provided ``configurable.turn_id``
|
||||||
|
in checkpoint metadata (see
|
||||||
|
``langgraph/checkpoint/base/__init__.py:get_checkpoint_metadata``). To
|
||||||
|
support edit-from-arbitrary-position, the regenerate route needs to map
|
||||||
|
a ``message_id`` -> ``turn_id`` -> checkpoint at request time. Without
|
||||||
|
this column the mapping doesn't exist anywhere, so regenerate would
|
||||||
|
have to hardcode the "last 2 messages" rewind heuristic.
|
||||||
|
|
||||||
|
This migration adds a nullable ``turn_id`` column to ``new_chat_messages``
|
||||||
|
plus an index. Legacy rows have NULL — the regenerate route degrades
|
||||||
|
gracefully to the reload-last-two heuristic for those.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "136"
|
||||||
|
down_revision: str | None = "135"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"new_chat_messages",
|
||||||
|
sa.Column("turn_id", sa.String(length=64), nullable=True),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_new_chat_messages_turn_id",
|
||||||
|
"new_chat_messages",
|
||||||
|
["turn_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_new_chat_messages_turn_id", table_name="new_chat_messages")
|
||||||
|
op.drop_column("new_chat_messages", "turn_id")
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
"""137_unique_reverse_of_in_action_log
|
||||||
|
|
||||||
|
Revision ID: 137
|
||||||
|
Revises: 136
|
||||||
|
Create Date: 2026-04-29
|
||||||
|
|
||||||
|
Protect ``agent_action_log.reverse_of`` against double inserts. Two
|
||||||
|
concurrent revert calls (single-action route + the per-turn batch
|
||||||
|
route, or two batch routes racing) both pass the
|
||||||
|
``_was_already_reverted`` SELECT and both insert their own
|
||||||
|
``_revert:*`` rows. The application-level idempotency check is racy
|
||||||
|
because there's no DB constraint backing it.
|
||||||
|
|
||||||
|
This migration adds a partial unique index on ``reverse_of`` (PostgreSQL
|
||||||
|
``WHERE reverse_of IS NOT NULL``) so the second concurrent insert raises
|
||||||
|
``IntegrityError`` and the route can translate it to ``"already_reverted"``
|
||||||
|
deterministically.
|
||||||
|
|
||||||
|
The plain ``UniqueConstraint`` flavour can't be used because most
|
||||||
|
existing rows have ``reverse_of = NULL`` (only revert rows fill it),
|
||||||
|
and Postgres does treat NULL as distinct in unique indexes — but a
|
||||||
|
partial index is the cleanest expression of intent and works even on
|
||||||
|
older Postgres releases that distinguish NULL handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "137"
|
||||||
|
down_revision: str | None = "136"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
_INDEX_NAME = "ux_agent_action_log_reverse_of"
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Defensively de-dup any pre-existing double-revert rows before
|
||||||
|
# adding the unique index. Keeps the OLDEST row (smallest id) and
|
||||||
|
# NULLs out the duplicates' ``reverse_of`` so they survive as audit
|
||||||
|
# trail but no longer claim to be the canonical revert. We do NOT
|
||||||
|
# delete them — operators can still inspect them via /actions.
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
WITH dups AS (
|
||||||
|
SELECT id,
|
||||||
|
reverse_of,
|
||||||
|
ROW_NUMBER() OVER (
|
||||||
|
PARTITION BY reverse_of ORDER BY id ASC
|
||||||
|
) AS rn
|
||||||
|
FROM agent_action_log
|
||||||
|
WHERE reverse_of IS NOT NULL
|
||||||
|
)
|
||||||
|
UPDATE agent_action_log
|
||||||
|
SET reverse_of = NULL
|
||||||
|
WHERE id IN (SELECT id FROM dups WHERE rn > 1)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_index(
|
||||||
|
_INDEX_NAME,
|
||||||
|
"agent_action_log",
|
||||||
|
["reverse_of"],
|
||||||
|
unique=True,
|
||||||
|
postgresql_where="reverse_of IS NOT NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(_INDEX_NAME, table_name="agent_action_log")
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""138_add_thread_auto_model_pinning_fields
|
||||||
|
|
||||||
|
Revision ID: 138
|
||||||
|
Revises: 137
|
||||||
|
Create Date: 2026-04-30
|
||||||
|
|
||||||
|
Add a single thread-level column to persist the Auto (Fastest) model pin:
|
||||||
|
- pinned_llm_config_id: concrete resolved global LLM config id used for this
|
||||||
|
thread. NULL means "no pin; Auto will resolve on next turn".
|
||||||
|
|
||||||
|
The column is unindexed: all reads are by new_chat_threads.id (primary key),
|
||||||
|
so a secondary index would be dead write amplification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "138"
|
||||||
|
down_revision: str | None = "137"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE new_chat_threads "
|
||||||
|
"ADD COLUMN IF NOT EXISTS pinned_llm_config_id INTEGER"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Drop any shape the thread row may be carrying. The extra columns and
|
||||||
|
# indexes only exist on dev DBs that ran an earlier draft of 138; IF EXISTS
|
||||||
|
# makes each statement a safe no-op on the lean shape.
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_auto_mode")
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_pinned_llm_config_id")
|
||||||
|
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_at")
|
||||||
|
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_auto_mode")
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS pinned_llm_config_id"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
"""add user table to zero_publication with column list
|
||||||
|
|
||||||
|
Adds the "user" table to zero_publication with a column-list publication
|
||||||
|
so that only the 5 fields driving the live usage meters are replicated
|
||||||
|
through WAL -> zero-cache -> browser IndexedDB:
|
||||||
|
|
||||||
|
id, pages_limit, pages_used,
|
||||||
|
premium_tokens_limit, premium_tokens_used
|
||||||
|
|
||||||
|
Sensitive columns (hashed_password, email, oauth_account, display_name,
|
||||||
|
avatar_url, memory_md, refresh_tokens, last_login, etc.) are NOT
|
||||||
|
included in the publication, so they never enter WAL replication.
|
||||||
|
|
||||||
|
Also re-asserts REPLICA IDENTITY DEFAULT on "user" for idempotency
|
||||||
|
(it is already DEFAULT today since "user" was never in the
|
||||||
|
TABLES_WITH_FULL_IDENTITY list of migration 117).
|
||||||
|
|
||||||
|
IMPORTANT - before AND after running this migration:
|
||||||
|
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||||
|
2. Run: alembic upgrade head
|
||||||
|
3. Delete / reset the zero-cache data volume
|
||||||
|
4. Restart zero-cache (it will do a fresh initial sync)
|
||||||
|
|
||||||
|
Revision ID: 139
|
||||||
|
Revises: 138
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "139"
|
||||||
|
down_revision: str | None = "138"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
PUBLICATION_NAME = "zero_publication"
|
||||||
|
|
||||||
|
# Document column list as left by migration 117. Must match exactly.
|
||||||
|
DOCUMENT_COLS = [
|
||||||
|
"id",
|
||||||
|
"title",
|
||||||
|
"document_type",
|
||||||
|
"search_space_id",
|
||||||
|
"folder_id",
|
||||||
|
"created_by_id",
|
||||||
|
"status",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Five fields needed by the live usage meters (sidebar Tokens/Pages,
|
||||||
|
# Buy Tokens content). Keep this list narrow on purpose: anything added
|
||||||
|
# here flows into WAL and IndexedDB for every connected browser.
|
||||||
|
USER_COLS = [
|
||||||
|
"id",
|
||||||
|
"pages_limit",
|
||||||
|
"pages_used",
|
||||||
|
"premium_tokens_limit",
|
||||||
|
"premium_tokens_used",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _terminate_blocked_pids(conn, table: str) -> None:
|
||||||
|
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT pg_terminate_backend(l.pid) "
|
||||||
|
"FROM pg_locks l "
|
||||||
|
"JOIN pg_class c ON c.oid = l.relation "
|
||||||
|
"WHERE c.relname = :tbl "
|
||||||
|
" AND l.pid != pg_backend_pid()"
|
||||||
|
),
|
||||||
|
{"tbl": table},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _has_zero_version(conn, table: str) -> bool:
|
||||||
|
return (
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"SELECT 1 FROM information_schema.columns "
|
||||||
|
"WHERE table_name = :tbl AND column_name = '_0_version'"
|
||||||
|
),
|
||||||
|
{"tbl": table},
|
||||||
|
).fetchone()
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_publication_ddl(
|
||||||
|
documents_has_zero_ver: bool, user_has_zero_ver: bool
|
||||||
|
) -> str:
|
||||||
|
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
|
||||||
|
user_cols = USER_COLS + (['"_0_version"'] if user_has_zero_ver else [])
|
||||||
|
doc_col_list = ", ".join(doc_cols)
|
||||||
|
user_col_list = ", ".join(user_cols)
|
||||||
|
return (
|
||||||
|
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
|
||||||
|
f"notifications, "
|
||||||
|
f"documents ({doc_col_list}), "
|
||||||
|
f"folders, "
|
||||||
|
f"search_source_connectors, "
|
||||||
|
f"new_chat_messages, "
|
||||||
|
f"chat_comments, "
|
||||||
|
f"chat_session_state, "
|
||||||
|
f'"user" ({user_col_list})'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_publication_ddl_without_user(documents_has_zero_ver: bool) -> str:
|
||||||
|
doc_cols = DOCUMENT_COLS + (['"_0_version"'] if documents_has_zero_ver else [])
|
||||||
|
doc_col_list = ", ".join(doc_cols)
|
||||||
|
return (
|
||||||
|
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
|
||||||
|
f"notifications, "
|
||||||
|
f"documents ({doc_col_list}), "
|
||||||
|
f"folders, "
|
||||||
|
f"search_source_connectors, "
|
||||||
|
f"new_chat_messages, "
|
||||||
|
f"chat_comments, "
|
||||||
|
f"chat_session_state"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
# asyncpg requires LOCK TABLE inside a transaction block. Alembic already
|
||||||
|
# opened one via context.begin_transaction(), but the driver still errors
|
||||||
|
# unless we use an explicit SAVEPOINT (nested transaction) for this block.
|
||||||
|
tx = conn.begin_nested() if conn.in_transaction() else conn.begin()
|
||||||
|
with tx:
|
||||||
|
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
||||||
|
|
||||||
|
_terminate_blocked_pids(conn, "user")
|
||||||
|
conn.execute(sa.text('LOCK TABLE "user" IN ACCESS EXCLUSIVE MODE'))
|
||||||
|
|
||||||
|
# Idempotent: "user" was never in TABLES_WITH_FULL_IDENTITY of
|
||||||
|
# migration 117, so this is already DEFAULT. Re-assert anyway so
|
||||||
|
# the column-list publication stays valid (DEFAULT identity only
|
||||||
|
# requires the PK to be in the column list).
|
||||||
|
conn.execute(sa.text('ALTER TABLE "user" REPLICA IDENTITY DEFAULT'))
|
||||||
|
|
||||||
|
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||||
|
|
||||||
|
documents_has_zero_ver = _has_zero_version(conn, "documents")
|
||||||
|
user_has_zero_ver = _has_zero_version(conn, "user")
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
sa.text(_build_publication_ddl(documents_has_zero_ver, user_has_zero_ver))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
conn = op.get_bind()
|
||||||
|
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||||
|
documents_has_zero_ver = _has_zero_version(conn, "documents")
|
||||||
|
conn.execute(sa.text(_build_publication_ddl_without_user(documents_has_zero_ver)))
|
||||||
|
|
@ -10,7 +10,9 @@ We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
|
||||||
This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
|
This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
|
||||||
subclass of the default ``FilesystemMiddleware`` — while preserving every
|
subclass of the default ``FilesystemMiddleware`` — while preserving every
|
||||||
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
|
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
|
||||||
summarisation, prompt-caching, etc.).
|
summarisation, etc.). Prompt caching is configured at LLM-build time via
|
||||||
|
``apply_litellm_prompt_caching`` (LiteLLM-native, multi-provider) rather
|
||||||
|
than as a middleware.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
@ -33,7 +35,6 @@ from langchain.agents.middleware import (
|
||||||
TodoListMiddleware,
|
TodoListMiddleware,
|
||||||
ToolCallLimitMiddleware,
|
ToolCallLimitMiddleware,
|
||||||
)
|
)
|
||||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
@ -74,6 +75,7 @@ from app.agents.new_chat.plugin_loader import (
|
||||||
load_allowed_plugin_names_from_env,
|
load_allowed_plugin_names_from_env,
|
||||||
load_plugin_middlewares,
|
load_plugin_middlewares,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||||
from app.agents.new_chat.subagents import build_specialized_subagents
|
from app.agents.new_chat.subagents import build_specialized_subagents
|
||||||
from app.agents.new_chat.system_prompt import (
|
from app.agents.new_chat.system_prompt import (
|
||||||
build_configurable_system_prompt,
|
build_configurable_system_prompt,
|
||||||
|
|
@ -94,6 +96,39 @@ from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
_perf_log = get_perf_logger()
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_prompt_model_name(
|
||||||
|
agent_config: AgentConfig | None,
|
||||||
|
llm: BaseChatModel,
|
||||||
|
) -> str | None:
|
||||||
|
"""Resolve the model id to feed to provider-variant detection.
|
||||||
|
|
||||||
|
Preference order (matches the established idiom in
|
||||||
|
``llm_router_service.py`` — see ``params.get("base_model") or
|
||||||
|
params.get("model", "")`` usages there):
|
||||||
|
|
||||||
|
1. ``agent_config.litellm_params["base_model"]`` — required for Azure
|
||||||
|
deployments where ``model_name`` is the deployment slug, not the
|
||||||
|
underlying family. Without this, a deployment named e.g.
|
||||||
|
``"prod-chat-001"`` would silently miss every provider regex.
|
||||||
|
2. ``agent_config.model_name`` — the user's configured model id.
|
||||||
|
3. ``getattr(llm, "model", None)`` — fallback for direct callers that
|
||||||
|
don't supply an ``AgentConfig`` (currently a defensive path; all
|
||||||
|
production callers pass ``agent_config``).
|
||||||
|
|
||||||
|
Returns ``None`` when nothing is available; ``compose_system_prompt``
|
||||||
|
treats that as the ``"default"`` variant (no provider block emitted).
|
||||||
|
"""
|
||||||
|
if agent_config is not None:
|
||||||
|
params = agent_config.litellm_params or {}
|
||||||
|
base_model = params.get("base_model")
|
||||||
|
if isinstance(base_model, str) and base_model.strip():
|
||||||
|
return base_model
|
||||||
|
if agent_config.model_name:
|
||||||
|
return agent_config.model_name
|
||||||
|
return getattr(llm, "model", None)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Connector Type Mapping
|
# Connector Type Mapping
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -279,6 +314,14 @@ async def create_surfsense_deep_agent(
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
_t_agent_total = time.perf_counter()
|
_t_agent_total = time.perf_counter()
|
||||||
|
|
||||||
|
# Layer thread-aware prompt caching onto the LLM. Idempotent with the
|
||||||
|
# build-time call in ``llm_config.py``; this run merely adds
|
||||||
|
# ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` for OpenAI-family
|
||||||
|
# configs now that ``thread_id`` is known. No-op when ``thread_id`` is
|
||||||
|
# None or the provider is non-OpenAI-family.
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=agent_config, thread_id=thread_id)
|
||||||
|
|
||||||
filesystem_selection = filesystem_selection or FilesystemSelection()
|
filesystem_selection = filesystem_selection or FilesystemSelection()
|
||||||
backend_resolver = build_backend_resolver(
|
backend_resolver = build_backend_resolver(
|
||||||
filesystem_selection,
|
filesystem_selection,
|
||||||
|
|
@ -398,6 +441,7 @@ async def create_surfsense_deep_agent(
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
mcp_connector_tools=_mcp_connector_tools,
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
|
model_name=_resolve_prompt_model_name(agent_config, llm),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
system_prompt = build_surfsense_system_prompt(
|
system_prompt = build_surfsense_system_prompt(
|
||||||
|
|
@ -405,6 +449,7 @@ async def create_surfsense_deep_agent(
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
mcp_connector_tools=_mcp_connector_tools,
|
mcp_connector_tools=_mcp_connector_tools,
|
||||||
|
model_name=_resolve_prompt_model_name(agent_config, llm),
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||||
|
|
@ -568,7 +613,6 @@ def _build_compiled_agent_blocking(
|
||||||
),
|
),
|
||||||
create_surfsense_compaction_middleware(llm, StateBackend),
|
create_surfsense_compaction_middleware(llm, StateBackend),
|
||||||
PatchToolCallsMiddleware(),
|
PatchToolCallsMiddleware(),
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
|
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
|
||||||
|
|
@ -724,7 +768,8 @@ def _build_compiled_agent_blocking(
|
||||||
repair_mw = None
|
repair_mw = None
|
||||||
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
||||||
registered_names: set[str] = {t.name for t in tools}
|
registered_names: set[str] = {t.name for t in tools}
|
||||||
# Tools owned by the standard deepagents middleware stack.
|
# Tools owned by the standard deepagents middleware stack and the
|
||||||
|
# SurfSense filesystem extension.
|
||||||
registered_names |= {
|
registered_names |= {
|
||||||
"write_todos",
|
"write_todos",
|
||||||
"ls",
|
"ls",
|
||||||
|
|
@ -735,6 +780,14 @@ def _build_compiled_agent_blocking(
|
||||||
"grep",
|
"grep",
|
||||||
"execute",
|
"execute",
|
||||||
"task",
|
"task",
|
||||||
|
"mkdir",
|
||||||
|
"cd",
|
||||||
|
"pwd",
|
||||||
|
"move_file",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
|
"list_tree",
|
||||||
|
"execute_code",
|
||||||
}
|
}
|
||||||
repair_mw = ToolCallNameRepairMiddleware(
|
repair_mw = ToolCallNameRepairMiddleware(
|
||||||
registered_tool_names=registered_names,
|
registered_tool_names=registered_names,
|
||||||
|
|
@ -763,25 +816,51 @@ def _build_compiled_agent_blocking(
|
||||||
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
|
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
|
||||||
# ``glob``, ``web_search`` …) and, on resume, replay the previous
|
# ``glob``, ``web_search`` …) and, on resume, replay the previous
|
||||||
# reject decision into innocent calls.
|
# reject decision into innocent calls.
|
||||||
# 2. ``connector_synthesized`` — deny rules for tools whose required
|
# 2. ``desktop_safety`` — ``ask`` for destructive filesystem ops when
|
||||||
# connector is not connected to this space. Overrides #1.
|
# the agent is operating against the user's real disk. Cloud mode
|
||||||
# 3. (future) user-defined rules from ``agent_permission_rules`` table
|
# has full revision-based revert via ``revert_service``, but
|
||||||
# via the Agent Permissions UI. Loaded last so they override both.
|
# desktop mode hits disk immediately with no undo, so an
|
||||||
|
# accidental ``rm`` / ``rmdir`` / ``move_file`` / ``edit_file`` /
|
||||||
|
# ``write_file`` is unrecoverable. This layer is forced on in
|
||||||
|
# desktop mode regardless of ``enable_permission`` because the
|
||||||
|
# safety net is non-negotiable.
|
||||||
|
# 3. ``connector_synthesized`` — deny rules for tools whose required
|
||||||
|
# connector is not connected to this space. Overrides #1/#2.
|
||||||
|
# 4. (future) user-defined rules from ``agent_permission_rules`` table
|
||||||
|
# via the Agent Permissions UI. Loaded last so they override all.
|
||||||
permission_mw: PermissionMiddleware | None = None
|
permission_mw: PermissionMiddleware | None = None
|
||||||
if flags.enable_permission and not flags.disable_new_agent_stack:
|
is_desktop_fs = filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||||
synthesized = _synthesize_connector_deny_rules(
|
permission_enabled = flags.enable_permission and not flags.disable_new_agent_stack
|
||||||
available_connectors=available_connectors,
|
# Build the middleware whenever it has work to do: either the user
|
||||||
enabled_tool_names={t.name for t in tools},
|
# opted into the rule engine, OR we're in desktop mode and need the
|
||||||
)
|
# safety rules unconditionally.
|
||||||
permission_mw = PermissionMiddleware(
|
if permission_enabled or is_desktop_fs:
|
||||||
rulesets=[
|
rulesets: list[Ruleset] = [
|
||||||
Ruleset(
|
Ruleset(
|
||||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||||
origin="surfsense_defaults",
|
origin="surfsense_defaults",
|
||||||
),
|
),
|
||||||
Ruleset(rules=synthesized, origin="connector_synthesized"),
|
]
|
||||||
|
if is_desktop_fs:
|
||||||
|
rulesets.append(
|
||||||
|
Ruleset(
|
||||||
|
rules=[
|
||||||
|
Rule(permission="rm", pattern="*", action="ask"),
|
||||||
|
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||||
|
Rule(permission="move_file", pattern="*", action="ask"),
|
||||||
|
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||||
|
Rule(permission="write_file", pattern="*", action="ask"),
|
||||||
],
|
],
|
||||||
|
origin="desktop_safety",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
if permission_enabled:
|
||||||
|
synthesized = _synthesize_connector_deny_rules(
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
enabled_tool_names={t.name for t in tools},
|
||||||
|
)
|
||||||
|
rulesets.append(Ruleset(rules=synthesized, origin="connector_synthesized"))
|
||||||
|
permission_mw = PermissionMiddleware(rulesets=rulesets)
|
||||||
|
|
||||||
# ActionLogMiddleware. Off by default until the ``agent_action_log``
|
# ActionLogMiddleware. Off by default until the ``agent_action_log``
|
||||||
# table is migrated. When enabled, persists one row per tool call
|
# table is migrated. When enabled, persists one row per tool call
|
||||||
|
|
@ -938,6 +1017,7 @@ def _build_compiled_agent_blocking(
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
created_by_id=user_id,
|
created_by_id=user_id,
|
||||||
filesystem_mode=filesystem_mode,
|
filesystem_mode=filesystem_mode,
|
||||||
|
thread_id=thread_id,
|
||||||
)
|
)
|
||||||
if filesystem_mode == FilesystemMode.CLOUD
|
if filesystem_mode == FilesystemMode.CLOUD
|
||||||
else None,
|
else None,
|
||||||
|
|
@ -970,12 +1050,12 @@ def _build_compiled_agent_blocking(
|
||||||
action_log_mw,
|
action_log_mw,
|
||||||
PatchToolCallsMiddleware(),
|
PatchToolCallsMiddleware(),
|
||||||
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
|
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
|
||||||
# Plugin slot — sits just before AnthropicCache so plugin-side
|
# Plugin slot — sits at the tail so plugin-side transforms see the
|
||||||
# transforms see the final tool result and run before any
|
# final tool result. Prompt caching is now applied at LLM build time
|
||||||
# caching heuristics. Multiple plugins in declared order; loader
|
# via ``apply_litellm_prompt_caching`` (see prompt_caching.py), so no
|
||||||
# filtered by the admin allowlist already.
|
# caching middleware is needed here. Multiple plugins run in declared
|
||||||
|
# order; loader filtered by the admin allowlist already.
|
||||||
*plugin_middlewares,
|
*plugin_middlewares,
|
||||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
|
||||||
]
|
]
|
||||||
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
|
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ Local development (recommended for trying everything except doom-loop / selector
|
||||||
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
|
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
|
||||||
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
|
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
|
||||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
|
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
|
||||||
|
SURFSENSE_ENABLE_STREAM_PARITY_V2=false # structured streaming events
|
||||||
|
|
||||||
Master kill-switch (overrides everything else):
|
Master kill-switch (overrides everything else):
|
||||||
|
|
||||||
|
|
@ -86,6 +87,15 @@ class AgentFeatureFlags:
|
||||||
False # Backend ships before UI; route returns 503 until this flips
|
False # Backend ships before UI; route returns 503 until this flips
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Streaming parity v2 — opt in to LangChain's structured
|
||||||
|
# ``AIMessageChunk`` content (typed reasoning blocks, tool-input
|
||||||
|
# deltas) and propagate the real ``tool_call_id`` to the SSE layer.
|
||||||
|
# When OFF the ``stream_new_chat`` task falls back to the str-only
|
||||||
|
# text path and the synthetic ``call_<run_id>`` tool-call id (no
|
||||||
|
# ``langchainToolCallId`` propagation). Schema migrations 135/136
|
||||||
|
# ship unconditionally because they're forward-compatible.
|
||||||
|
enable_stream_parity_v2: bool = False
|
||||||
|
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader: bool = False
|
enable_plugin_loader: bool = False
|
||||||
|
|
||||||
|
|
@ -139,6 +149,10 @@ class AgentFeatureFlags:
|
||||||
# Snapshot / revert
|
# Snapshot / revert
|
||||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
||||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
||||||
|
# Streaming parity v2
|
||||||
|
enable_stream_parity_v2=_env_bool(
|
||||||
|
"SURFSENSE_ENABLE_STREAM_PARITY_V2", False
|
||||||
|
),
|
||||||
# Plugins
|
# Plugins
|
||||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||||
# Observability
|
# Observability
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,14 @@ extra fields needed to implement Postgres-backed virtual filesystem semantics:
|
||||||
|
|
||||||
* ``cwd`` — current working directory (per-thread checkpointed).
|
* ``cwd`` — current working directory (per-thread checkpointed).
|
||||||
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
* ``staged_dirs`` — pending mkdir requests (cloud only).
|
||||||
|
* ``staged_dir_tool_calls`` — sidecar map ``path -> tool_call_id`` for staged dirs.
|
||||||
* ``pending_moves`` — pending move_file requests (cloud only).
|
* ``pending_moves`` — pending move_file requests (cloud only).
|
||||||
|
* ``pending_deletes`` — pending ``rm`` requests (cloud only).
|
||||||
|
* ``pending_dir_deletes`` — pending ``rmdir`` requests (cloud only).
|
||||||
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
* ``doc_id_by_path`` — virtual_path -> Document.id, populated by lazy reads.
|
||||||
* ``dirty_paths`` — paths whose state file content differs from DB.
|
* ``dirty_paths`` — paths whose state file content differs from DB.
|
||||||
|
* ``dirty_path_tool_calls`` — sidecar map ``path -> latest tool_call_id`` for
|
||||||
|
dirty paths; used to bind the per-path snapshot to an action_id.
|
||||||
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
* ``kb_priority`` — top-K priority hints rendered into a system message.
|
||||||
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
* ``kb_matched_chunk_ids`` — internal hand-off for matched-chunk highlighting.
|
||||||
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
* ``kb_anon_doc`` — Redis-loaded anonymous document (if any).
|
||||||
|
|
@ -32,12 +37,31 @@ from app.agents.new_chat.state_reducers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PendingMove(TypedDict):
|
class PendingMove(TypedDict, total=False):
|
||||||
"""A staged move_file operation pending end-of-turn commit."""
|
"""A staged move_file operation pending end-of-turn commit.
|
||||||
|
|
||||||
|
``tool_call_id`` is optional for backward compatibility with checkpoints
|
||||||
|
written before the snapshot/revert pipeline was wired up; new entries
|
||||||
|
always include it so the persistence body can resolve an action_id.
|
||||||
|
"""
|
||||||
|
|
||||||
source: str
|
source: str
|
||||||
dest: str
|
dest: str
|
||||||
overwrite: bool
|
overwrite: bool
|
||||||
|
tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class PendingDelete(TypedDict, total=False):
|
||||||
|
"""A staged ``rm`` or ``rmdir`` operation pending end-of-turn commit.
|
||||||
|
|
||||||
|
``tool_call_id`` is required for new entries (it's the binding key used
|
||||||
|
by :class:`KnowledgeBasePersistenceMiddleware` to find the matching
|
||||||
|
:class:`AgentActionLog` row and bind the snapshot to it). Marked
|
||||||
|
``total=False`` only to tolerate older checkpoint payloads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path: str
|
||||||
|
tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
class KbPriorityEntry(TypedDict, total=False):
|
class KbPriorityEntry(TypedDict, total=False):
|
||||||
|
|
@ -76,9 +100,38 @@ class SurfSenseFilesystemState(FilesystemState):
|
||||||
staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
staged_dirs: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
||||||
"""mkdir paths staged for end-of-turn folder creation (cloud only)."""
|
"""mkdir paths staged for end-of-turn folder creation (cloud only)."""
|
||||||
|
|
||||||
|
staged_dir_tool_calls: NotRequired[
|
||||||
|
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
|
||||||
|
]
|
||||||
|
"""``path -> tool_call_id`` sidecar for ``staged_dirs``.
|
||||||
|
|
||||||
|
Used by :class:`KnowledgeBasePersistenceMiddleware` to bind the
|
||||||
|
:class:`FolderRevision` snapshot to the originating ``mkdir`` action.
|
||||||
|
Kept separate from ``staged_dirs`` (which stays a unique-string list)
|
||||||
|
to avoid breaking ``_add_unique_reducer`` semantics.
|
||||||
|
"""
|
||||||
|
|
||||||
pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
|
pending_moves: NotRequired[Annotated[list[PendingMove], _list_append_reducer]]
|
||||||
"""move_file ops staged for end-of-turn commit (cloud only)."""
|
"""move_file ops staged for end-of-turn commit (cloud only)."""
|
||||||
|
|
||||||
|
pending_deletes: NotRequired[Annotated[list[PendingDelete], _list_append_reducer]]
|
||||||
|
"""``rm`` ops staged for end-of-turn ``DELETE FROM documents`` (cloud only).
|
||||||
|
|
||||||
|
Each entry is a dict ``{"path": ..., "tool_call_id": ...}``. Per-path
|
||||||
|
uniqueness is enforced inside the commit body, not the reducer (we keep
|
||||||
|
``tool_call_id`` per occurrence so snapshot binding works).
|
||||||
|
"""
|
||||||
|
|
||||||
|
pending_dir_deletes: NotRequired[
|
||||||
|
Annotated[list[PendingDelete], _list_append_reducer]
|
||||||
|
]
|
||||||
|
"""``rmdir`` ops staged for end-of-turn ``DELETE FROM folders`` (cloud only).
|
||||||
|
|
||||||
|
Same shape as :data:`pending_deletes`. Commit body re-verifies the
|
||||||
|
folder is empty (in-DB AND with this turn's pending changes accounted
|
||||||
|
for) before issuing the DELETE.
|
||||||
|
"""
|
||||||
|
|
||||||
doc_id_by_path: NotRequired[
|
doc_id_by_path: NotRequired[
|
||||||
Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
|
Annotated[dict[str, int], _dict_merge_with_tombstones_reducer]
|
||||||
]
|
]
|
||||||
|
|
@ -92,6 +145,17 @@ class SurfSenseFilesystemState(FilesystemState):
|
||||||
dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
dirty_paths: NotRequired[Annotated[list[str], _add_unique_reducer]]
|
||||||
"""Paths whose ``state["files"]`` content has been modified this turn."""
|
"""Paths whose ``state["files"]`` content has been modified this turn."""
|
||||||
|
|
||||||
|
dirty_path_tool_calls: NotRequired[
|
||||||
|
Annotated[dict[str, str], _dict_merge_with_tombstones_reducer]
|
||||||
|
]
|
||||||
|
"""``path -> latest tool_call_id`` sidecar for ``dirty_paths``.
|
||||||
|
|
||||||
|
The persistence body coalesces multiple writes/edits to the same path
|
||||||
|
into one snapshot per turn. This map captures the most-recent
|
||||||
|
``tool_call_id`` so the resulting :class:`DocumentRevision` is bound
|
||||||
|
to the latest action_id (the one the user is most likely to revert).
|
||||||
|
"""
|
||||||
|
|
||||||
kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
|
kb_priority: NotRequired[Annotated[list[KbPriorityEntry], _replace_reducer]]
|
||||||
"""Top-K priority hints rendered as a system message before the user turn."""
|
"""Top-K priority hints rendered as a system message before the user turn."""
|
||||||
|
|
||||||
|
|
@ -108,6 +172,7 @@ class SurfSenseFilesystemState(FilesystemState):
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"KbAnonDoc",
|
"KbAnonDoc",
|
||||||
"KbPriorityEntry",
|
"KbPriorityEntry",
|
||||||
|
"PendingDelete",
|
||||||
"PendingMove",
|
"PendingMove",
|
||||||
"SurfSenseFilesystemState",
|
"SurfSenseFilesystemState",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from litellm import get_model_info
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||||
from app.services.llm_router_service import (
|
from app.services.llm_router_service import (
|
||||||
AUTO_MODE_ID,
|
AUTO_MODE_ID,
|
||||||
ChatLiteLLMRouter,
|
ChatLiteLLMRouter,
|
||||||
|
|
@ -494,6 +495,11 @@ def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
_attach_model_profile(llm, model_string)
|
_attach_model_profile(llm, model_string)
|
||||||
|
# Configure LiteLLM-native prompt caching (cache_control_injection_points
|
||||||
|
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
|
||||||
|
# ``agent_config=None`` here — the YAML path doesn't have provider intent
|
||||||
|
# in a structured form, so we set only the universal injection points.
|
||||||
|
apply_litellm_prompt_caching(llm)
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -518,7 +524,16 @@ def create_chat_litellm_from_agent_config(
|
||||||
print("Error: Auto mode requested but LLM Router not initialized")
|
print("Error: Auto mode requested but LLM Router not initialized")
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
return get_auto_mode_llm()
|
router_llm = get_auto_mode_llm()
|
||||||
|
if router_llm is not None:
|
||||||
|
# Universal cache_control_injection_points only — auto-mode
|
||||||
|
# fans out across providers, so OpenAI-only kwargs (e.g.
|
||||||
|
# ``prompt_cache_key``) are left off here. ``drop_params``
|
||||||
|
# would strip them at the provider boundary anyway, but
|
||||||
|
# there's no point setting them when we don't know the
|
||||||
|
# destination.
|
||||||
|
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||||
|
return router_llm
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
@ -549,4 +564,9 @@ def create_chat_litellm_from_agent_config(
|
||||||
|
|
||||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||||
_attach_model_profile(llm, model_string)
|
_attach_model_profile(llm, model_string)
|
||||||
|
# Build-time prompt caching: sets ``cache_control_injection_points`` for
|
||||||
|
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
|
||||||
|
# Per-thread ``prompt_cache_key`` is layered on later in
|
||||||
|
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=agent_config)
|
||||||
return llm
|
return llm
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from collections.abc import Awaitable, Callable
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain_core.callbacks import adispatch_custom_event
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
|
|
||||||
from app.agents.new_chat.feature_flags import get_flags
|
from app.agents.new_chat.feature_flags import get_flags
|
||||||
|
|
@ -144,11 +145,19 @@ class ActionLogMiddleware(AgentMiddleware):
|
||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tool_call_id = _resolve_tool_call_id(request)
|
||||||
|
chat_turn_id = _resolve_chat_turn_id(request)
|
||||||
|
|
||||||
row = AgentActionLog(
|
row = AgentActionLog(
|
||||||
thread_id=self._thread_id,
|
thread_id=self._thread_id,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
search_space_id=self._search_space_id,
|
search_space_id=self._search_space_id,
|
||||||
turn_id=_resolve_turn_id(request),
|
# ``turn_id`` is the deprecated alias of ``tool_call_id``
|
||||||
|
# kept for one release for safe rollback. New consumers
|
||||||
|
# should read ``tool_call_id`` directly.
|
||||||
|
turn_id=tool_call_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
chat_turn_id=chat_turn_id,
|
||||||
message_id=_resolve_message_id(request),
|
message_id=_resolve_message_id(request),
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
args=args_payload,
|
args=args_payload,
|
||||||
|
|
@ -160,11 +169,41 @@ class ActionLogMiddleware(AgentMiddleware):
|
||||||
async with shielded_async_session() as session:
|
async with shielded_async_session() as session:
|
||||||
session.add(row)
|
session.add(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
row_id = int(row.id) if row.id is not None else None
|
||||||
|
row_created_at = row.created_at
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ActionLogMiddleware failed to persist action log row",
|
"ActionLogMiddleware failed to persist action log row",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Surface a side-channel SSE event so the chat tool card can
|
||||||
|
# render a Revert button immediately after the row is durable.
|
||||||
|
# ``stream_new_chat`` translates this into a
|
||||||
|
# ``data-action-log`` SSE event. We DO NOT include the
|
||||||
|
# ``reverse_descriptor`` payload here; only a presence flag.
|
||||||
|
try:
|
||||||
|
await adispatch_custom_event(
|
||||||
|
"action_log",
|
||||||
|
{
|
||||||
|
"id": row_id,
|
||||||
|
"lc_tool_call_id": tool_call_id,
|
||||||
|
"chat_turn_id": chat_turn_id,
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"reversible": bool(reversible),
|
||||||
|
"reverse_descriptor_present": reverse_descriptor is not None,
|
||||||
|
"created_at": row_created_at.isoformat()
|
||||||
|
if row_created_at
|
||||||
|
else None,
|
||||||
|
"error": error_payload is not None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"ActionLogMiddleware failed to dispatch action_log event",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _render_reverse(
|
def _render_reverse(
|
||||||
self,
|
self,
|
||||||
|
|
@ -254,7 +293,8 @@ def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _resolve_turn_id(request: Any) -> str | None:
|
def _resolve_tool_call_id(request: Any) -> str | None:
|
||||||
|
"""Return the LangChain ``tool_call.id`` for this request, if any."""
|
||||||
try:
|
try:
|
||||||
call = getattr(request, "tool_call", None) or {}
|
call = getattr(request, "tool_call", None) or {}
|
||||||
if isinstance(call, dict):
|
if isinstance(call, dict):
|
||||||
|
|
@ -266,9 +306,40 @@ def _resolve_turn_id(request: Any) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Deprecated alias kept for one release. Old callers and tests treated
|
||||||
|
# ``turn_id`` as if it carried the LangChain tool_call id; the new column
|
||||||
|
# lives under ``tool_call_id``. Both resolve to the same value today.
|
||||||
|
_resolve_turn_id = _resolve_tool_call_id
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_chat_turn_id(request: Any) -> str | None:
|
||||||
|
"""Return ``configurable.turn_id`` for this request, if accessible.
|
||||||
|
|
||||||
|
``ToolRuntime.config`` is exposed by LangGraph (see
|
||||||
|
``langgraph/prebuilt/tool_node.py``); the chat-turn correlation id
|
||||||
|
lives at ``runtime.config["configurable"]["turn_id"]``.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
runtime = getattr(request, "runtime", None)
|
||||||
|
if runtime is None:
|
||||||
|
return None
|
||||||
|
config = getattr(runtime, "config", None)
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return None
|
||||||
|
configurable = config.get("configurable")
|
||||||
|
if not isinstance(configurable, dict):
|
||||||
|
return None
|
||||||
|
value = configurable.get("turn_id")
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
return value
|
||||||
|
except Exception: # pragma: no cover - defensive
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _resolve_message_id(request: Any) -> str | None:
|
def _resolve_message_id(request: Any) -> str | None:
|
||||||
"""Tool-call IDs serve as best-available message correlator at this layer."""
|
"""Tool-call IDs serve as best-available message correlator at this layer."""
|
||||||
return _resolve_turn_id(request)
|
return _resolve_tool_call_id(request)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_result_id(result: Any) -> str | None:
|
def _resolve_result_id(result: Any) -> str | None:
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -58,6 +59,11 @@ class _ThreadLockManager:
|
||||||
weakref.WeakValueDictionary()
|
weakref.WeakValueDictionary()
|
||||||
)
|
)
|
||||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||||
|
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||||
|
self._cancel_attempt_count: dict[str, int] = {}
|
||||||
|
# Monotonic per-thread epoch used to prevent stale middleware
|
||||||
|
# teardown from releasing a newer turn's lock.
|
||||||
|
self._turn_epoch: dict[str, int] = {}
|
||||||
|
|
||||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||||
lock = self._locks.get(thread_id)
|
lock = self._locks.get(thread_id)
|
||||||
|
|
@ -76,14 +82,57 @@ class _ThreadLockManager:
|
||||||
def request_cancel(self, thread_id: str) -> bool:
|
def request_cancel(self, thread_id: str) -> bool:
|
||||||
event = self._cancel_events.get(thread_id)
|
event = self._cancel_events.get(thread_id)
|
||||||
if event is None:
|
if event is None:
|
||||||
return False
|
event = asyncio.Event()
|
||||||
|
self._cancel_events[thread_id] = event
|
||||||
event.set()
|
event.set()
|
||||||
|
now_ms = int(time.time() * 1000)
|
||||||
|
self._cancel_requested_at_ms[thread_id] = now_ms
|
||||||
|
self._cancel_attempt_count[thread_id] = (
|
||||||
|
self._cancel_attempt_count.get(thread_id, 0) + 1
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def is_cancel_requested(self, thread_id: str) -> bool:
|
||||||
|
event = self._cancel_events.get(thread_id)
|
||||||
|
return bool(event and event.is_set())
|
||||||
|
|
||||||
|
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
|
||||||
|
if not self.is_cancel_requested(thread_id):
|
||||||
|
return None
|
||||||
|
attempts = self._cancel_attempt_count.get(thread_id, 1)
|
||||||
|
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
|
||||||
|
return attempts, requested_at_ms
|
||||||
|
|
||||||
def reset(self, thread_id: str) -> None:
|
def reset(self, thread_id: str) -> None:
|
||||||
event = self._cancel_events.get(thread_id)
|
event = self._cancel_events.get(thread_id)
|
||||||
if event is not None:
|
if event is not None:
|
||||||
event.clear()
|
event.clear()
|
||||||
|
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||||
|
self._cancel_attempt_count.pop(thread_id, None)
|
||||||
|
|
||||||
|
def bump_turn_epoch(self, thread_id: str) -> int:
|
||||||
|
epoch = self._turn_epoch.get(thread_id, 0) + 1
|
||||||
|
self._turn_epoch[thread_id] = epoch
|
||||||
|
return epoch
|
||||||
|
|
||||||
|
def current_turn_epoch(self, thread_id: str) -> int:
|
||||||
|
return self._turn_epoch.get(thread_id, 0)
|
||||||
|
|
||||||
|
def end_turn(self, thread_id: str) -> None:
|
||||||
|
"""Best-effort terminal cleanup for a thread turn.
|
||||||
|
|
||||||
|
This is intentionally idempotent and safe to call from outer stream
|
||||||
|
finally-blocks where middleware teardown might be skipped due to abort
|
||||||
|
or disconnect edge-cases.
|
||||||
|
"""
|
||||||
|
# Invalidate any in-flight middleware holder first. This guarantees a
|
||||||
|
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
|
||||||
|
# retry that already acquired the lock for the same thread.
|
||||||
|
self.bump_turn_epoch(thread_id)
|
||||||
|
lock = self._locks.get(thread_id)
|
||||||
|
if lock is not None and lock.locked():
|
||||||
|
lock.release()
|
||||||
|
self.reset(thread_id)
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — process-local but reused across all agent
|
# Module-level singleton — process-local but reused across all agent
|
||||||
|
|
@ -98,15 +147,30 @@ def get_cancel_event(thread_id: str) -> asyncio.Event:
|
||||||
|
|
||||||
|
|
||||||
def request_cancel(thread_id: str) -> bool:
|
def request_cancel(thread_id: str) -> bool:
|
||||||
"""Trip the cancel event for ``thread_id``. Returns True if found."""
|
"""Trip the cancel event for ``thread_id``. Always returns True."""
|
||||||
return manager.request_cancel(thread_id)
|
return manager.request_cancel(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cancel_requested(thread_id: str) -> bool:
|
||||||
|
"""Return whether ``thread_id`` currently has a pending cancel signal."""
|
||||||
|
return manager.is_cancel_requested(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
|
||||||
|
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
|
||||||
|
return manager.cancel_state(thread_id)
|
||||||
|
|
||||||
|
|
||||||
def reset_cancel(thread_id: str) -> None:
|
def reset_cancel(thread_id: str) -> None:
|
||||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||||
manager.reset(thread_id)
|
manager.reset(thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def end_turn(thread_id: str) -> None:
|
||||||
|
"""Force end-of-turn cleanup for lock + cancel state."""
|
||||||
|
manager.end_turn(thread_id)
|
||||||
|
|
||||||
|
|
||||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||||
"""Block concurrent prompts on the same thread.
|
"""Block concurrent prompts on the same thread.
|
||||||
|
|
||||||
|
|
@ -129,10 +193,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._require_thread_id = require_thread_id
|
self._require_thread_id = require_thread_id
|
||||||
self.tools = []
|
self.tools = []
|
||||||
# Per-call locks owned by this middleware. We track them as
|
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
|
||||||
# an instance attribute so ``aafter_agent`` knows which lock
|
# only releases when its epoch still matches the manager's current
|
||||||
# to release.
|
# epoch for the thread, preventing stale unlock races.
|
||||||
self._held_locks: dict[str, asyncio.Lock] = {}
|
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||||
|
|
@ -183,7 +247,8 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
if lock.locked():
|
if lock.locked():
|
||||||
raise BusyError(request_id=thread_id)
|
raise BusyError(request_id=thread_id)
|
||||||
await lock.acquire()
|
await lock.acquire()
|
||||||
self._held_locks[thread_id] = lock
|
epoch = manager.bump_turn_epoch(thread_id)
|
||||||
|
self._held_locks[thread_id] = (lock, epoch)
|
||||||
# Reset the cancel event so this turn starts fresh
|
# Reset the cancel event so this turn starts fresh
|
||||||
reset_cancel(thread_id)
|
reset_cancel(thread_id)
|
||||||
return None
|
return None
|
||||||
|
|
@ -197,8 +262,15 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
thread_id = self._thread_id(runtime)
|
thread_id = self._thread_id(runtime)
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
return None
|
return None
|
||||||
lock = self._held_locks.pop(thread_id, None)
|
held = self._held_locks.pop(thread_id, None)
|
||||||
if lock is not None and lock.locked():
|
if held is None:
|
||||||
|
return None
|
||||||
|
lock, held_epoch = held
|
||||||
|
if held_epoch != manager.current_turn_epoch(thread_id):
|
||||||
|
# Stale teardown from an older attempt (e.g. runtime-recovery path
|
||||||
|
# already advanced epoch). Do not touch current lock/cancel state.
|
||||||
|
return None
|
||||||
|
if lock.locked():
|
||||||
lock.release()
|
lock.release()
|
||||||
# Always clear cancel event between turns so a stale signal
|
# Always clear cancel event between turns so a stale signal
|
||||||
# doesn't leak into the next request.
|
# doesn't leak into the next request.
|
||||||
|
|
@ -229,7 +301,10 @@ class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, Respo
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BusyMutexMiddleware",
|
"BusyMutexMiddleware",
|
||||||
|
"end_turn",
|
||||||
"get_cancel_event",
|
"get_cancel_event",
|
||||||
|
"get_cancel_state",
|
||||||
|
"is_cancel_requested",
|
||||||
"manager",
|
"manager",
|
||||||
"request_cancel",
|
"request_cancel",
|
||||||
"reset_cancel",
|
"reset_cancel",
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,8 @@ current working directory (`cwd`, default `/documents`).
|
||||||
- cd(path): change the current working directory.
|
- cd(path): change the current working directory.
|
||||||
- pwd(): print the current working directory.
|
- pwd(): print the current working directory.
|
||||||
- move_file(source, dest): move/rename a file under `/documents/`.
|
- move_file(source, dest): move/rename a file under `/documents/`.
|
||||||
|
- rm(path): delete a single file under `/documents/` (no `-r`).
|
||||||
|
- rmdir(path): delete an empty directory under `/documents/`.
|
||||||
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
||||||
|
|
||||||
## Persistence Rules
|
## Persistence Rules
|
||||||
|
|
@ -112,8 +114,9 @@ current working directory (`cwd`, default `/documents`).
|
||||||
`/documents/temp_scratch.md`) are **discarded** at end of turn — use this
|
`/documents/temp_scratch.md`) are **discarded** at end of turn — use this
|
||||||
prefix for any scratch/working content you do NOT want saved.
|
prefix for any scratch/working content you do NOT want saved.
|
||||||
- All other paths (outside `/documents/` and not `temp_*`) are rejected.
|
- All other paths (outside `/documents/` and not `temp_*`) are rejected.
|
||||||
- mkdir/move_file are staged this turn and committed at end of turn alongside
|
- mkdir/move_file/rm/rmdir are staged this turn and committed at end of
|
||||||
any new/edited documents.
|
turn alongside any new/edited documents. Snapshot/revert is enabled
|
||||||
|
for every destructive operation when action logging is on.
|
||||||
|
|
||||||
## Reading Documents Efficiently
|
## Reading Documents Efficiently
|
||||||
|
|
||||||
|
|
@ -176,6 +179,8 @@ directory (`cwd`).
|
||||||
- cd(path): change the current working directory.
|
- cd(path): change the current working directory.
|
||||||
- pwd(): print the current working directory.
|
- pwd(): print the current working directory.
|
||||||
- move_file(source, dest): move/rename a file.
|
- move_file(source, dest): move/rename a file.
|
||||||
|
- rm(path): delete a single file from disk (no `-r`). NOT reversible.
|
||||||
|
- rmdir(path): delete an empty directory from disk. NOT reversible.
|
||||||
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
- list_tree(path, max_depth, page_size): recursively list files/folders.
|
||||||
|
|
||||||
## Workflow Tips
|
## Workflow Tips
|
||||||
|
|
@ -184,6 +189,8 @@ directory (`cwd`).
|
||||||
- For large trees, prefer `list_tree` then `grep` then `read_file` over
|
- For large trees, prefer `list_tree` then `grep` then `read_file` over
|
||||||
brute-force directory traversal.
|
brute-force directory traversal.
|
||||||
- Cross-mount moves are not supported.
|
- Cross-mount moves are not supported.
|
||||||
|
- Desktop deletes hit disk immediately and cannot be undone via the
|
||||||
|
agent's revert flow — confirm before calling `rm`/`rmdir`.
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -355,6 +362,42 @@ Notes:
|
||||||
- Parent folders are created as needed.
|
- Parent folders are created as needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_CLOUD_RM_TOOL_DESCRIPTION = """Deletes a single file under `/documents/`.
|
||||||
|
|
||||||
|
Mirrors POSIX `rm path` (no `-r`, no glob expansion). Stages the deletion
|
||||||
|
for end-of-turn commit; the row is removed only after the agent's turn
|
||||||
|
finishes successfully.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- path: absolute or relative file path. Cannot point at a directory — use
|
||||||
|
`rmdir` for empty folders. Cannot target the root or `/documents`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- The action is reversible via the per-action revert flow when action
|
||||||
|
logging is enabled.
|
||||||
|
- The anonymous uploaded document is read-only and cannot be deleted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_CLOUD_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory under `/documents/`.
|
||||||
|
|
||||||
|
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
|
||||||
|
deletion (`rm -r`) is intentionally NOT supported — clear contents with
|
||||||
|
`rm` first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- path: absolute or relative directory path. Cannot target the root,
|
||||||
|
`/documents`, the current cwd, or any ancestor of cwd (use `cd` to
|
||||||
|
move out first).
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Emptiness is evaluated against the post-staged view, so a same-turn
|
||||||
|
`rm /a/x.md` followed by `rmdir /a` is fine.
|
||||||
|
- If the directory was added in this same turn via `mkdir` and never
|
||||||
|
committed, the staged mkdir is dropped instead of issuing a delete.
|
||||||
|
- The action is reversible via the per-action revert flow when action
|
||||||
|
logging is enabled.
|
||||||
|
"""
|
||||||
|
|
||||||
# --- desktop-only ----------------------------------------------------------
|
# --- desktop-only ----------------------------------------------------------
|
||||||
|
|
||||||
_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
|
_DESKTOP_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
|
||||||
|
|
@ -421,6 +464,28 @@ Notes:
|
||||||
- Parent folders are created as needed.
|
- Parent folders are created as needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_DESKTOP_RM_TOOL_DESCRIPTION = """Deletes a single file from disk.
|
||||||
|
|
||||||
|
Mirrors POSIX `rm path` (no `-r`, no glob expansion). The deletion hits
|
||||||
|
disk immediately. Desktop deletes are NOT reversible via the agent's
|
||||||
|
revert flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- path: absolute mount-prefixed file path. Cannot point at a directory —
|
||||||
|
use `rmdir` for empty folders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DESKTOP_RMDIR_TOOL_DESCRIPTION = """Deletes an empty directory from disk.
|
||||||
|
|
||||||
|
Mirrors POSIX `rmdir path`: refuses non-empty directories. Recursive
|
||||||
|
deletion is NOT supported. The deletion hits disk immediately and is
|
||||||
|
NOT reversible via the agent's revert flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- path: absolute mount-prefixed directory path. Cannot target the mount
|
||||||
|
root or any directory containing files/subfolders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
||||||
"""Pick the active-mode description for every filesystem tool."""
|
"""Pick the active-mode description for every filesystem tool."""
|
||||||
|
|
@ -437,6 +502,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
||||||
"mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION,
|
"mkdir": _CLOUD_MKDIR_TOOL_DESCRIPTION,
|
||||||
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
||||||
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
||||||
|
"rm": _CLOUD_RM_TOOL_DESCRIPTION,
|
||||||
|
"rmdir": _CLOUD_RMDIR_TOOL_DESCRIPTION,
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
"ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION,
|
"ls": _DESKTOP_LIST_FILES_TOOL_DESCRIPTION,
|
||||||
|
|
@ -450,6 +517,8 @@ def _build_tool_descriptions(filesystem_mode: FilesystemMode) -> dict[str, str]:
|
||||||
"mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION,
|
"mkdir": _DESKTOP_MKDIR_TOOL_DESCRIPTION,
|
||||||
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
"cd": SURFSENSE_CD_TOOL_DESCRIPTION,
|
||||||
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
"pwd": SURFSENSE_PWD_TOOL_DESCRIPTION,
|
||||||
|
"rm": _DESKTOP_RM_TOOL_DESCRIPTION,
|
||||||
|
"rmdir": _DESKTOP_RMDIR_TOOL_DESCRIPTION,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -476,6 +545,21 @@ def _basename(path: str) -> str:
|
||||||
return path.rsplit("/", 1)[-1]
|
return path.rsplit("/", 1)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ancestor_of(candidate: str, target: str) -> bool:
|
||||||
|
"""True iff ``candidate`` is a strict ancestor directory of ``target``.
|
||||||
|
|
||||||
|
``target`` itself is NOT considered an ancestor (use equality for that).
|
||||||
|
Both paths are assumed to be canonicalised, absolute, and free of
|
||||||
|
trailing slashes (except the root ``/``).
|
||||||
|
"""
|
||||||
|
if not candidate.startswith("/") or not target.startswith("/"):
|
||||||
|
return False
|
||||||
|
if candidate == target:
|
||||||
|
return False
|
||||||
|
prefix = candidate.rstrip("/") + "/"
|
||||||
|
return target.startswith(prefix)
|
||||||
|
|
||||||
|
|
||||||
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
|
"""SurfSense-specific filesystem middleware (cloud + desktop)."""
|
||||||
|
|
||||||
|
|
@ -519,6 +603,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
self.tools.append(self._create_cd_tool())
|
self.tools.append(self._create_cd_tool())
|
||||||
self.tools.append(self._create_pwd_tool())
|
self.tools.append(self._create_pwd_tool())
|
||||||
self.tools.append(self._create_move_file_tool())
|
self.tools.append(self._create_move_file_tool())
|
||||||
|
self.tools.append(self._create_rm_tool())
|
||||||
|
self.tools.append(self._create_rmdir_tool())
|
||||||
self.tools.append(self._create_list_tree_tool())
|
self.tools.append(self._create_list_tree_tool())
|
||||||
if self._sandbox_available:
|
if self._sandbox_available:
|
||||||
self.tools.append(self._create_execute_code_tool())
|
self.tools.append(self._create_execute_code_tool())
|
||||||
|
|
@ -941,6 +1027,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
}
|
}
|
||||||
if self._is_cloud():
|
if self._is_cloud():
|
||||||
update["dirty_paths"] = [path]
|
update["dirty_paths"] = [path]
|
||||||
|
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
|
||||||
return Command(update=update)
|
return Command(update=update)
|
||||||
|
|
||||||
def sync_write_file(
|
def sync_write_file(
|
||||||
|
|
@ -1036,6 +1123,7 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
}
|
}
|
||||||
if self._is_cloud():
|
if self._is_cloud():
|
||||||
update["dirty_paths"] = [path]
|
update["dirty_paths"] = [path]
|
||||||
|
update["dirty_path_tool_calls"] = {path: runtime.tool_call_id}
|
||||||
if doc_id_to_attach is not None:
|
if doc_id_to_attach is not None:
|
||||||
update["doc_id_by_path"] = {path: doc_id_to_attach}
|
update["doc_id_by_path"] = {path: doc_id_to_attach}
|
||||||
return Command(update=update)
|
return Command(update=update)
|
||||||
|
|
@ -1103,6 +1191,9 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={
|
||||||
"staged_dirs": [validated],
|
"staged_dirs": [validated],
|
||||||
|
"staged_dir_tool_calls": {
|
||||||
|
validated: runtime.tool_call_id,
|
||||||
|
},
|
||||||
"messages": [
|
"messages": [
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=(
|
content=(
|
||||||
|
|
@ -1372,7 +1463,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
files_update: dict[str, Any] = {source: None, dest: source_file_data}
|
files_update: dict[str, Any] = {source: None, dest: source_file_data}
|
||||||
update: dict[str, Any] = {
|
update: dict[str, Any] = {
|
||||||
"files": files_update,
|
"files": files_update,
|
||||||
"pending_moves": [{"source": source, "dest": dest, "overwrite": False}],
|
"pending_moves": [
|
||||||
|
{
|
||||||
|
"source": source,
|
||||||
|
"dest": dest,
|
||||||
|
"overwrite": False,
|
||||||
|
"tool_call_id": runtime.tool_call_id,
|
||||||
|
}
|
||||||
|
],
|
||||||
"messages": [
|
"messages": [
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content=(
|
content=(
|
||||||
|
|
@ -1396,6 +1494,323 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
update["dirty_paths"] = new_dirty
|
update["dirty_paths"] = new_dirty
|
||||||
return Command(update=update)
|
return Command(update=update)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ tool: rm
|
||||||
|
|
||||||
|
def _create_rm_tool(self) -> BaseTool:
|
||||||
|
tool_description = (
|
||||||
|
self._custom_tool_descriptions.get("rm") or _CLOUD_RM_TOOL_DESCRIPTION
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_rm(
|
||||||
|
path: Annotated[
|
||||||
|
str,
|
||||||
|
"Absolute or relative path to the file to delete.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||||
|
) -> Command | str:
|
||||||
|
if not path or not path.strip():
|
||||||
|
return "Error: path is required."
|
||||||
|
|
||||||
|
target = self._resolve_relative(path, runtime)
|
||||||
|
try:
|
||||||
|
validated = validate_path(target)
|
||||||
|
except ValueError as exc:
|
||||||
|
return f"Error: {exc}"
|
||||||
|
|
||||||
|
if self._is_cloud():
|
||||||
|
if validated in ("/", DOCUMENTS_ROOT):
|
||||||
|
return f"Error: refusing to rm '{validated}'."
|
||||||
|
if not validated.startswith(DOCUMENTS_ROOT + "/"):
|
||||||
|
return (
|
||||||
|
"Error: cloud rm must target a path under /documents/ "
|
||||||
|
f"(got '{validated}')."
|
||||||
|
)
|
||||||
|
|
||||||
|
anon = runtime.state.get("kb_anon_doc") or {}
|
||||||
|
if isinstance(anon, dict) and str(anon.get("path") or "") == validated:
|
||||||
|
return "Error: the anonymous uploaded document is read-only."
|
||||||
|
|
||||||
|
# Refuse if the path looks like a directory.
|
||||||
|
staged_dirs = list(runtime.state.get("staged_dirs") or [])
|
||||||
|
if validated in staged_dirs:
|
||||||
|
return (
|
||||||
|
f"Error: '{validated}' is a directory. Use rmdir for "
|
||||||
|
"empty directories."
|
||||||
|
)
|
||||||
|
pending_dir_deletes = list(
|
||||||
|
runtime.state.get("pending_dir_deletes") or []
|
||||||
|
)
|
||||||
|
if any(
|
||||||
|
isinstance(d, dict) and d.get("path") == validated
|
||||||
|
for d in pending_dir_deletes
|
||||||
|
):
|
||||||
|
return f"Error: '{validated}' is already queued for rmdir."
|
||||||
|
|
||||||
|
backend = self._get_backend(runtime)
|
||||||
|
if isinstance(backend, KBPostgresBackend):
|
||||||
|
# Detect "is a directory" via `ls`: if the path lists
|
||||||
|
# children we know it's a folder. Otherwise we still
|
||||||
|
# need to confirm it's a real file before staging.
|
||||||
|
children = await backend.als_info(validated)
|
||||||
|
if children:
|
||||||
|
return (
|
||||||
|
f"Error: '{validated}' is a directory. Use rmdir for "
|
||||||
|
"empty directories."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Already queued for delete this turn?
|
||||||
|
pending_deletes = list(runtime.state.get("pending_deletes") or [])
|
||||||
|
if any(
|
||||||
|
isinstance(d, dict) and d.get("path") == validated
|
||||||
|
for d in pending_deletes
|
||||||
|
):
|
||||||
|
return f"'{validated}' is already queued for deletion."
|
||||||
|
|
||||||
|
# Resolve doc_id (best-effort): file in state or DB.
|
||||||
|
files_state = runtime.state.get("files") or {}
|
||||||
|
doc_id_by_path = runtime.state.get("doc_id_by_path") or {}
|
||||||
|
resolved_doc_id: int | None = doc_id_by_path.get(validated)
|
||||||
|
if (
|
||||||
|
validated not in files_state
|
||||||
|
and resolved_doc_id is None
|
||||||
|
and isinstance(backend, KBPostgresBackend)
|
||||||
|
):
|
||||||
|
loaded = await backend._load_file_data(validated)
|
||||||
|
if loaded is None:
|
||||||
|
return f"Error: file '{validated}' not found."
|
||||||
|
_, resolved_doc_id = loaded
|
||||||
|
|
||||||
|
files_update: dict[str, Any] = {validated: None}
|
||||||
|
update: dict[str, Any] = {
|
||||||
|
"pending_deletes": [
|
||||||
|
{
|
||||||
|
"path": validated,
|
||||||
|
"tool_call_id": runtime.tool_call_id,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"files": files_update,
|
||||||
|
"doc_id_by_path": {validated: None},
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=(
|
||||||
|
f"Staged delete of '{validated}' (will commit at "
|
||||||
|
"end of turn)."
|
||||||
|
),
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Drop the path from dirty_paths so a same-turn write+rm
|
||||||
|
# doesn't recreate the doc at commit time.
|
||||||
|
dirty_paths = list(runtime.state.get("dirty_paths") or [])
|
||||||
|
if validated in dirty_paths:
|
||||||
|
new_dirty: list[Any] = [_CLEAR]
|
||||||
|
for entry in dirty_paths:
|
||||||
|
if entry != validated:
|
||||||
|
new_dirty.append(entry)
|
||||||
|
update["dirty_paths"] = new_dirty
|
||||||
|
update["dirty_path_tool_calls"] = {validated: None}
|
||||||
|
|
||||||
|
return Command(update=update)
|
||||||
|
|
||||||
|
# Desktop mode — hit disk immediately.
|
||||||
|
backend = self._get_backend(runtime)
|
||||||
|
adelete = getattr(backend, "adelete_file", None)
|
||||||
|
if not callable(adelete):
|
||||||
|
return "Error: rm is not supported by the active backend."
|
||||||
|
res: WriteResult = await adelete(validated)
|
||||||
|
if res.error:
|
||||||
|
return res.error
|
||||||
|
update_desktop: dict[str, Any] = {
|
||||||
|
"files": {validated: None},
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=f"Deleted file '{res.path or validated}'",
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return Command(update=update_desktop)
|
||||||
|
|
||||||
|
def sync_rm(
|
||||||
|
path: Annotated[
|
||||||
|
str,
|
||||||
|
"Absolute or relative path to the file to delete.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||||
|
) -> Command | str:
|
||||||
|
return self._run_async_blocking(async_rm(path, runtime))
|
||||||
|
|
||||||
|
return StructuredTool.from_function(
|
||||||
|
name="rm",
|
||||||
|
description=tool_description,
|
||||||
|
func=sync_rm,
|
||||||
|
coroutine=async_rm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------ tool: rmdir
|
||||||
|
|
||||||
|
def _create_rmdir_tool(self) -> BaseTool:
|
||||||
|
tool_description = (
|
||||||
|
self._custom_tool_descriptions.get("rmdir") or _CLOUD_RMDIR_TOOL_DESCRIPTION
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_rmdir(
|
||||||
|
path: Annotated[
|
||||||
|
str,
|
||||||
|
"Absolute or relative path of the empty directory to delete.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||||
|
) -> Command | str:
|
||||||
|
if not path or not path.strip():
|
||||||
|
return "Error: path is required."
|
||||||
|
|
||||||
|
target = self._resolve_relative(path, runtime)
|
||||||
|
try:
|
||||||
|
validated = validate_path(target)
|
||||||
|
except ValueError as exc:
|
||||||
|
return f"Error: {exc}"
|
||||||
|
|
||||||
|
if self._is_cloud():
|
||||||
|
if validated in ("/", DOCUMENTS_ROOT):
|
||||||
|
return f"Error: refusing to rmdir '{validated}'."
|
||||||
|
if not validated.startswith(DOCUMENTS_ROOT + "/"):
|
||||||
|
return (
|
||||||
|
"Error: cloud rmdir must target a path under /documents/ "
|
||||||
|
f"(got '{validated}')."
|
||||||
|
)
|
||||||
|
|
||||||
|
cwd = self._current_cwd(runtime)
|
||||||
|
if validated == cwd or _is_ancestor_of(validated, cwd):
|
||||||
|
return (
|
||||||
|
f"Error: cannot rmdir '{validated}' because the current "
|
||||||
|
"cwd is at or under it. cd out first."
|
||||||
|
)
|
||||||
|
|
||||||
|
staged_dirs = list(runtime.state.get("staged_dirs") or [])
|
||||||
|
pending_dir_deletes = list(
|
||||||
|
runtime.state.get("pending_dir_deletes") or []
|
||||||
|
)
|
||||||
|
if any(
|
||||||
|
isinstance(d, dict) and d.get("path") == validated
|
||||||
|
for d in pending_dir_deletes
|
||||||
|
):
|
||||||
|
return f"'{validated}' is already queued for deletion."
|
||||||
|
|
||||||
|
backend = self._get_backend(runtime)
|
||||||
|
|
||||||
|
# The path must currently exist either in DB folder paths or
|
||||||
|
# in staged_dirs. We rely on KBPostgresBackend.als_info (which
|
||||||
|
# already accounts for pending deletes/moves) to evaluate
|
||||||
|
# both existence and emptiness against the post-staged view.
|
||||||
|
exists_in_staged = validated in staged_dirs
|
||||||
|
children: list[Any] = []
|
||||||
|
if isinstance(backend, KBPostgresBackend):
|
||||||
|
children = list(await backend.als_info(validated))
|
||||||
|
|
||||||
|
# Detect "is a file" — if als_info returns no children but
|
||||||
|
# the path is actually a file, we should reject. We use
|
||||||
|
# _load_file_data to disambiguate file vs missing folder.
|
||||||
|
if (
|
||||||
|
isinstance(backend, KBPostgresBackend)
|
||||||
|
and not children
|
||||||
|
and not exists_in_staged
|
||||||
|
):
|
||||||
|
loaded = await backend._load_file_data(validated)
|
||||||
|
if loaded is not None:
|
||||||
|
return (
|
||||||
|
f"Error: '{validated}' is a file. Use rm to delete files."
|
||||||
|
)
|
||||||
|
# Confirm folder exists in DB by checking the parent listing.
|
||||||
|
parent = posixpath.dirname(validated) or "/"
|
||||||
|
parent_listing = await backend.als_info(parent)
|
||||||
|
parent_has_dir = any(
|
||||||
|
info.get("path") == validated and info.get("is_dir")
|
||||||
|
for info in parent_listing
|
||||||
|
)
|
||||||
|
if not parent_has_dir:
|
||||||
|
return f"Error: directory '{validated}' not found."
|
||||||
|
|
||||||
|
if children:
|
||||||
|
return (
|
||||||
|
f"Error: directory '{validated}' is not empty. "
|
||||||
|
"Remove contents first."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Same-turn mkdir un-stage: drop the staged_dirs entry
|
||||||
|
# entirely and skip queuing a DB delete (nothing was ever
|
||||||
|
# committed).
|
||||||
|
if exists_in_staged:
|
||||||
|
rest = [d for d in staged_dirs if d != validated]
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"staged_dirs": [_CLEAR, *rest],
|
||||||
|
"staged_dir_tool_calls": {validated: None},
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=(f"Un-staged directory '{validated}'."),
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"pending_dir_deletes": [
|
||||||
|
{
|
||||||
|
"path": validated,
|
||||||
|
"tool_call_id": runtime.tool_call_id,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=(
|
||||||
|
f"Staged rmdir of '{validated}' (will commit "
|
||||||
|
"at end of turn)."
|
||||||
|
),
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Desktop mode — hit disk immediately.
|
||||||
|
backend = self._get_backend(runtime)
|
||||||
|
armdir = getattr(backend, "armdir", None)
|
||||||
|
if not callable(armdir):
|
||||||
|
return "Error: rmdir is not supported by the active backend."
|
||||||
|
res: WriteResult = await armdir(validated)
|
||||||
|
if res.error:
|
||||||
|
return res.error
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=f"Deleted directory '{res.path or validated}'",
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def sync_rmdir(
|
||||||
|
path: Annotated[
|
||||||
|
str,
|
||||||
|
"Absolute or relative path of the empty directory to delete.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, SurfSenseFilesystemState],
|
||||||
|
) -> Command | str:
|
||||||
|
return self._run_async_blocking(async_rmdir(path, runtime))
|
||||||
|
|
||||||
|
return StructuredTool.from_function(
|
||||||
|
name="rmdir",
|
||||||
|
description=tool_description,
|
||||||
|
func=sync_rmdir,
|
||||||
|
coroutine=async_rmdir,
|
||||||
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------ tool: list_tree
|
# ------------------------------------------------------------------ tool: list_tree
|
||||||
|
|
||||||
def _create_list_tree_tool(self) -> BaseTool:
|
def _create_list_tree_tool(self) -> BaseTool:
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -115,6 +115,12 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
def _pending_moves(self) -> list[dict[str, Any]]:
|
def _pending_moves(self) -> list[dict[str, Any]]:
|
||||||
return list(self.state.get("pending_moves") or [])
|
return list(self.state.get("pending_moves") or [])
|
||||||
|
|
||||||
|
def _pending_deletes(self) -> list[dict[str, Any]]:
|
||||||
|
return list(self.state.get("pending_deletes") or [])
|
||||||
|
|
||||||
|
def _pending_dir_deletes(self) -> list[dict[str, Any]]:
|
||||||
|
return list(self.state.get("pending_dir_deletes") or [])
|
||||||
|
|
||||||
def _kb_anon_doc(self) -> dict[str, Any] | None:
|
def _kb_anon_doc(self) -> dict[str, Any] | None:
|
||||||
anon = self.state.get("kb_anon_doc")
|
anon = self.state.get("kb_anon_doc")
|
||||||
return anon if isinstance(anon, dict) else None
|
return anon if isinstance(anon, dict) else None
|
||||||
|
|
@ -140,18 +146,28 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
return path
|
return path
|
||||||
return path.rstrip("/") if path != "/" else path
|
return path.rstrip("/") if path != "/" else path
|
||||||
|
|
||||||
def _moved_view_paths(
|
def _pending_filesystem_view(
|
||||||
self,
|
self,
|
||||||
existing: dict[str, dict[str, Any]],
|
existing: dict[str, dict[str, Any]],
|
||||||
) -> tuple[set[str], dict[str, str]]:
|
) -> tuple[set[str], dict[str, str], set[str]]:
|
||||||
"""Apply ``pending_moves`` to a path set and return ``(removed, alias)``.
|
"""Compute removed/aliased/dir-suppressed paths from staged ops.
|
||||||
|
|
||||||
Removed paths should disappear from listings; ``alias[source] = dest``
|
Returns ``(removed, alias, deleted_dirs)`` where:
|
||||||
means a virtual entry should appear at ``dest`` even if no DB row is
|
|
||||||
yet there.
|
* ``removed`` — paths to drop from listings (sources of pending moves
|
||||||
|
AND paths queued for ``rm``).
|
||||||
|
* ``alias`` — ``{source: dest}`` for pending moves; the dest should
|
||||||
|
appear as a virtual entry even when no DB row is at that path yet.
|
||||||
|
* ``deleted_dirs`` — folder paths queued for ``rmdir``; their entire
|
||||||
|
subtree (descendants) is suppressed from listings/glob/grep.
|
||||||
|
|
||||||
|
Entries in ``existing`` (the ``files`` state cache) keyed by a
|
||||||
|
removed path are popped so a same-turn delete-after-write doesn't
|
||||||
|
leave a stale virtual file in listings.
|
||||||
"""
|
"""
|
||||||
removed: set[str] = set()
|
removed: set[str] = set()
|
||||||
alias: dict[str, str] = {}
|
alias: dict[str, str] = {}
|
||||||
|
deleted_dirs: set[str] = set()
|
||||||
for move in self._pending_moves():
|
for move in self._pending_moves():
|
||||||
src = move.get("source")
|
src = move.get("source")
|
||||||
dst = move.get("dest")
|
dst = move.get("dest")
|
||||||
|
|
@ -160,7 +176,23 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
removed.add(src)
|
removed.add(src)
|
||||||
alias[src] = dst
|
alias[src] = dst
|
||||||
existing.pop(src, None)
|
existing.pop(src, None)
|
||||||
return removed, alias
|
for entry in self._pending_deletes():
|
||||||
|
path = entry.get("path") if isinstance(entry, dict) else None
|
||||||
|
if not path:
|
||||||
|
continue
|
||||||
|
removed.add(path)
|
||||||
|
existing.pop(path, None)
|
||||||
|
for entry in self._pending_dir_deletes():
|
||||||
|
path = entry.get("path") if isinstance(entry, dict) else None
|
||||||
|
if not path:
|
||||||
|
continue
|
||||||
|
deleted_dirs.add(path)
|
||||||
|
return removed, alias, deleted_dirs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_dir_suppressed(path: str, deleted_dirs: set[str]) -> bool:
|
||||||
|
"""Return True iff ``path`` is at-or-under any directory in ``deleted_dirs``."""
|
||||||
|
return any(path == d or _is_under(path, d) for d in deleted_dirs)
|
||||||
|
|
||||||
# ------------------------------------------------------------------ ls/read
|
# ------------------------------------------------------------------ ls/read
|
||||||
|
|
||||||
|
|
@ -189,7 +221,7 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
seen.add(anon_path)
|
seen.add(anon_path)
|
||||||
|
|
||||||
files = self._state_files()
|
files = self._state_files()
|
||||||
moved_removed, moved_alias = self._moved_view_paths(files)
|
moved_removed, moved_alias, deleted_dirs = self._pending_filesystem_view(files)
|
||||||
|
|
||||||
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
|
if normalized.startswith(DOCUMENTS_ROOT) or normalized == "/":
|
||||||
try:
|
try:
|
||||||
|
|
@ -203,7 +235,12 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
|
|
||||||
for info in db_infos:
|
for info in db_infos:
|
||||||
p = info.get("path", "")
|
p = info.get("path", "")
|
||||||
if not p or p in seen or p in moved_removed:
|
if (
|
||||||
|
not p
|
||||||
|
or p in seen
|
||||||
|
or p in moved_removed
|
||||||
|
or self._is_dir_suppressed(p, deleted_dirs)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
infos.append(info)
|
infos.append(info)
|
||||||
seen.add(p)
|
seen.add(p)
|
||||||
|
|
@ -212,6 +249,8 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
if src not in seen:
|
if src not in seen:
|
||||||
if not _is_under(dst, normalized):
|
if not _is_under(dst, normalized):
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(dst, deleted_dirs):
|
||||||
|
continue
|
||||||
rel = (
|
rel = (
|
||||||
dst[len(normalized) :].lstrip("/")
|
dst[len(normalized) :].lstrip("/")
|
||||||
if normalized != "/"
|
if normalized != "/"
|
||||||
|
|
@ -247,6 +286,8 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
continue
|
continue
|
||||||
if not _is_under(staged, normalized):
|
if not _is_under(staged, normalized):
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(staged, deleted_dirs):
|
||||||
|
continue
|
||||||
rel = (
|
rel = (
|
||||||
staged[len(normalized) :].lstrip("/")
|
staged[len(normalized) :].lstrip("/")
|
||||||
if normalized != "/"
|
if normalized != "/"
|
||||||
|
|
@ -265,14 +306,26 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
for sub in sorted(subdir_paths):
|
for sub in sorted(subdir_paths):
|
||||||
if sub in seen:
|
if sub in seen:
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(sub, deleted_dirs):
|
||||||
|
continue
|
||||||
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
|
infos.append(FileInfo(path=sub, is_dir=True, size=0, modified_at=""))
|
||||||
seen.add(sub)
|
seen.add(sub)
|
||||||
|
|
||||||
for path_key, fd in files.items():
|
for path_key, fd in files.items():
|
||||||
if not isinstance(path_key, str) or path_key in seen:
|
if not isinstance(path_key, str) or path_key in seen:
|
||||||
continue
|
continue
|
||||||
|
# Tombstones (None values) are deletion markers from `rm`. The
|
||||||
|
# deepagents reducer normally pops them, but a stale tombstone
|
||||||
|
# surviving a checkpoint must NOT be reported as a child here —
|
||||||
|
# otherwise rmdir mistakenly sees the deleted file as content.
|
||||||
|
if fd is None:
|
||||||
|
continue
|
||||||
if not _is_under(path_key, normalized) or path_key == normalized:
|
if not _is_under(path_key, normalized) or path_key == normalized:
|
||||||
continue
|
continue
|
||||||
|
if path_key in moved_removed or self._is_dir_suppressed(
|
||||||
|
path_key, deleted_dirs
|
||||||
|
):
|
||||||
|
continue
|
||||||
if normalized == "/":
|
if normalized == "/":
|
||||||
rel = path_key.lstrip("/")
|
rel = path_key.lstrip("/")
|
||||||
else:
|
else:
|
||||||
|
|
@ -550,10 +603,12 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
seen: set[str] = set()
|
seen: set[str] = set()
|
||||||
|
|
||||||
files = self._state_files()
|
files = self._state_files()
|
||||||
moved_removed, _ = self._moved_view_paths(files)
|
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||||
regex = re.compile(fnmatch.translate(pattern))
|
regex = re.compile(fnmatch.translate(pattern))
|
||||||
for path_key, fd in files.items():
|
for path_key, fd in files.items():
|
||||||
if path_key in moved_removed:
|
if path_key in moved_removed or self._is_dir_suppressed(
|
||||||
|
path_key, deleted_dirs
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(path_key, normalized):
|
if not _is_under(path_key, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -595,7 +650,11 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
folder_id=row.folder_id,
|
folder_id=row.folder_id,
|
||||||
index=index,
|
index=index,
|
||||||
)
|
)
|
||||||
if candidate in seen or candidate in moved_removed:
|
if (
|
||||||
|
candidate in seen
|
||||||
|
or candidate in moved_removed
|
||||||
|
or self._is_dir_suppressed(candidate, deleted_dirs)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(candidate, normalized):
|
if not _is_under(candidate, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -634,10 +693,12 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
matches: list[GrepMatch] = []
|
matches: list[GrepMatch] = []
|
||||||
|
|
||||||
files = self._state_files()
|
files = self._state_files()
|
||||||
moved_removed, _ = self._moved_view_paths(files)
|
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||||
glob_re = re.compile(fnmatch.translate(glob)) if glob else None
|
glob_re = re.compile(fnmatch.translate(glob)) if glob else None
|
||||||
for path_key, fd in files.items():
|
for path_key, fd in files.items():
|
||||||
if path_key in moved_removed:
|
if path_key in moved_removed or self._is_dir_suppressed(
|
||||||
|
path_key, deleted_dirs
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(path_key, normalized):
|
if not _is_under(path_key, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -695,7 +756,11 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
)
|
)
|
||||||
for doc_id, chunk_id, content in chunk_buffer:
|
for doc_id, chunk_id, content in chunk_buffer:
|
||||||
candidate = doc_id_to_path.get(doc_id)
|
candidate = doc_id_to_path.get(doc_id)
|
||||||
if not candidate or candidate in moved_removed:
|
if (
|
||||||
|
not candidate
|
||||||
|
or candidate in moved_removed
|
||||||
|
or self._is_dir_suppressed(candidate, deleted_dirs)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(candidate, normalized):
|
if not _is_under(candidate, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -769,7 +834,7 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
return {"entries": [], "truncated": False}
|
return {"entries": [], "truncated": False}
|
||||||
|
|
||||||
files = self._state_files()
|
files = self._state_files()
|
||||||
moved_removed, _ = self._moved_view_paths(files)
|
moved_removed, _, deleted_dirs = self._pending_filesystem_view(files)
|
||||||
anon = self._kb_anon_doc()
|
anon = self._kb_anon_doc()
|
||||||
anon_path = str(anon.get("path") or "") if anon else ""
|
anon_path = str(anon.get("path") or "") if anon else ""
|
||||||
|
|
||||||
|
|
@ -795,6 +860,8 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
|
for _fid, fpath in sorted(index.folder_paths.items(), key=lambda kv: kv[1]):
|
||||||
if not _is_under(fpath, normalized):
|
if not _is_under(fpath, normalized):
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(fpath, deleted_dirs):
|
||||||
|
continue
|
||||||
depth = _depth_of(fpath)
|
depth = _depth_of(fpath)
|
||||||
if max_depth is not None and depth > max_depth:
|
if max_depth is not None and depth > max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
@ -811,6 +878,8 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
for staged in self._staged_dirs():
|
for staged in self._staged_dirs():
|
||||||
if not _is_under(staged, normalized):
|
if not _is_under(staged, normalized):
|
||||||
continue
|
continue
|
||||||
|
if self._is_dir_suppressed(staged, deleted_dirs):
|
||||||
|
continue
|
||||||
depth = _depth_of(staged)
|
depth = _depth_of(staged)
|
||||||
if max_depth is not None and depth > max_depth:
|
if max_depth is not None and depth > max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
@ -835,7 +904,9 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
folder_id=row.folder_id,
|
folder_id=row.folder_id,
|
||||||
index=index,
|
index=index,
|
||||||
)
|
)
|
||||||
if candidate in moved_removed:
|
if candidate in moved_removed or self._is_dir_suppressed(
|
||||||
|
candidate, deleted_dirs
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if not _is_under(candidate, normalized):
|
if not _is_under(candidate, normalized):
|
||||||
continue
|
continue
|
||||||
|
|
@ -875,6 +946,10 @@ class KBPostgresBackend(BackendProtocol):
|
||||||
continue
|
continue
|
||||||
if not _is_under(path_key, normalized):
|
if not _is_under(path_key, normalized):
|
||||||
continue
|
continue
|
||||||
|
if path_key in moved_removed or self._is_dir_suppressed(
|
||||||
|
path_key, deleted_dirs
|
||||||
|
):
|
||||||
|
continue
|
||||||
if any(e["path"] == path_key for e in entries):
|
if any(e["path"] == path_key for e in entries):
|
||||||
continue
|
continue
|
||||||
if not (
|
if not (
|
||||||
|
|
|
||||||
|
|
@ -201,6 +201,12 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
)
|
)
|
||||||
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
all_paths = sorted(set(folder_paths + doc_paths + [DOCUMENTS_ROOT]))
|
||||||
|
|
||||||
|
# Pre-compute which folders have at least one descendant (folder or doc).
|
||||||
|
# A folder is "empty" iff no path in `all_paths` is strictly under it.
|
||||||
|
# Used to emit an explicit "(empty)" marker so the LLM doesn't have to
|
||||||
|
# infer emptiness from indentation alone.
|
||||||
|
non_empty_folders = self._compute_non_empty_folders(folder_paths, doc_paths)
|
||||||
|
|
||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
for path in all_paths:
|
for path in all_paths:
|
||||||
depth = (
|
depth = (
|
||||||
|
|
@ -214,6 +220,9 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
path.rsplit("/", 1)[-1] if path != DOCUMENTS_ROOT else "/documents"
|
||||||
)
|
)
|
||||||
if is_dir:
|
if is_dir:
|
||||||
|
if path != DOCUMENTS_ROOT and path not in non_empty_folders:
|
||||||
|
lines.append(f"{indent}{display}/ (empty)")
|
||||||
|
else:
|
||||||
lines.append(f"{indent}{display}/")
|
lines.append(f"{indent}{display}/")
|
||||||
else:
|
else:
|
||||||
lines.append(f"{indent}{display}")
|
lines.append(f"{indent}{display}")
|
||||||
|
|
@ -235,6 +244,35 @@ class KnowledgeTreeMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
|
||||||
return self._format_root_summary(folder_paths, doc_paths)
|
return self._format_root_summary(folder_paths, doc_paths)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_non_empty_folders(
|
||||||
|
folder_paths: list[str], doc_paths: list[str]
|
||||||
|
) -> set[str]:
|
||||||
|
"""Return the set of folder paths that contain at least one descendant.
|
||||||
|
|
||||||
|
A folder is "non-empty" if any document path or any other folder path
|
||||||
|
is strictly under it. Documents propagate emptiness up to every
|
||||||
|
ancestor folder, while a sub-folder only marks its direct ancestors
|
||||||
|
non-empty (so a chain of empty folders all read ``(empty)``).
|
||||||
|
"""
|
||||||
|
non_empty: set[str] = set()
|
||||||
|
folder_set = set(folder_paths)
|
||||||
|
|
||||||
|
for doc_path in doc_paths:
|
||||||
|
parent = doc_path.rsplit("/", 1)[0]
|
||||||
|
while parent and parent != DOCUMENTS_ROOT:
|
||||||
|
if parent in folder_set:
|
||||||
|
non_empty.add(parent)
|
||||||
|
parent = parent.rsplit("/", 1)[0]
|
||||||
|
|
||||||
|
for child in folder_paths:
|
||||||
|
parent = child.rsplit("/", 1)[0]
|
||||||
|
while parent and parent != DOCUMENTS_ROOT and parent in folder_set:
|
||||||
|
non_empty.add(parent)
|
||||||
|
parent = parent.rsplit("/", 1)[0]
|
||||||
|
|
||||||
|
return non_empty
|
||||||
|
|
||||||
def _format_root_summary(
|
def _format_root_summary(
|
||||||
self, folder_paths: list[str], doc_paths: list[str]
|
self, folder_paths: list[str], doc_paths: list[str]
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
|
||||||
|
|
@ -360,6 +360,74 @@ class LocalFolderBackend:
|
||||||
self.move, source_path, destination_path, overwrite
|
self.move, source_path, destination_path, overwrite
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_file(self, file_path: str) -> WriteResult:
|
||||||
|
"""Hard-delete a single file under root.
|
||||||
|
|
||||||
|
Refuses directories, root, and missing paths. Roughly mirrors POSIX
|
||||||
|
``rm path``; ``-r`` recursion and glob expansion are explicitly
|
||||||
|
out of scope.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
path = self._resolve_virtual(file_path)
|
||||||
|
except ValueError:
|
||||||
|
return WriteResult(error=f"Error: Invalid path '{file_path}'")
|
||||||
|
with self._lock_for(file_path):
|
||||||
|
if not path.exists():
|
||||||
|
return WriteResult(error=f"Error: File '{file_path}' not found")
|
||||||
|
if path.is_dir():
|
||||||
|
return WriteResult(
|
||||||
|
error=(
|
||||||
|
f"Error: '{file_path}' is a directory. "
|
||||||
|
"Use rmdir for empty directories."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
os.unlink(path)
|
||||||
|
except OSError as exc:
|
||||||
|
return WriteResult(
|
||||||
|
error=f"Error: failed to delete '{file_path}': {exc}"
|
||||||
|
)
|
||||||
|
return WriteResult(path=file_path, files_update=None)
|
||||||
|
|
||||||
|
async def adelete_file(self, file_path: str) -> WriteResult:
|
||||||
|
return await asyncio.to_thread(self.delete_file, file_path)
|
||||||
|
|
||||||
|
def rmdir(self, dir_path: str) -> WriteResult:
|
||||||
|
"""Hard-delete an empty directory under root.
|
||||||
|
|
||||||
|
Refuses files, root, missing paths, and non-empty directories.
|
||||||
|
``os.rmdir`` is naturally empty-only; we pre-check so the error is
|
||||||
|
clearer for the agent.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
path = self._resolve_virtual(dir_path)
|
||||||
|
except ValueError:
|
||||||
|
return WriteResult(error=f"Error: Invalid path '{dir_path}'")
|
||||||
|
with self._lock_for(dir_path):
|
||||||
|
if not path.exists():
|
||||||
|
return WriteResult(error=f"Error: Directory '{dir_path}' not found")
|
||||||
|
if not path.is_dir():
|
||||||
|
return WriteResult(error=f"Error: '{dir_path}' is not a directory")
|
||||||
|
try:
|
||||||
|
next(path.iterdir())
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
return WriteResult(
|
||||||
|
error=(
|
||||||
|
f"Error: directory '{dir_path}' is not empty. "
|
||||||
|
"Remove its contents first."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
os.rmdir(path)
|
||||||
|
except OSError as exc:
|
||||||
|
return WriteResult(error=f"Error: failed to rmdir '{dir_path}': {exc}")
|
||||||
|
return WriteResult(path=dir_path, files_update=None)
|
||||||
|
|
||||||
|
async def armdir(self, dir_path: str) -> WriteResult:
|
||||||
|
return await asyncio.to_thread(self.rmdir, dir_path)
|
||||||
|
|
||||||
def edit(
|
def edit(
|
||||||
self,
|
self,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
|
|
|
||||||
|
|
@ -285,6 +285,34 @@ class MultiRootLocalFolderBackend:
|
||||||
overwrite,
|
overwrite,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def delete_file(self, file_path: str) -> WriteResult:
|
||||||
|
try:
|
||||||
|
mount, local_path = self._split_mount_path(file_path)
|
||||||
|
except ValueError as exc:
|
||||||
|
return WriteResult(error=f"Error: {exc}")
|
||||||
|
result = self._mount_to_backend[mount].delete_file(local_path)
|
||||||
|
if result.path:
|
||||||
|
result.path = self._prefix_mount_path(mount, result.path)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def adelete_file(self, file_path: str) -> WriteResult:
|
||||||
|
return await asyncio.to_thread(self.delete_file, file_path)
|
||||||
|
|
||||||
|
def rmdir(self, dir_path: str) -> WriteResult:
|
||||||
|
try:
|
||||||
|
mount, local_path = self._split_mount_path(dir_path)
|
||||||
|
except ValueError as exc:
|
||||||
|
return WriteResult(error=f"Error: {exc}")
|
||||||
|
if local_path == "/":
|
||||||
|
return WriteResult(error=f"Error: cannot rmdir mount root '{dir_path}'")
|
||||||
|
result = self._mount_to_backend[mount].rmdir(local_path)
|
||||||
|
if result.path:
|
||||||
|
result.path = self._prefix_mount_path(mount, result.path)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def armdir(self, dir_path: str) -> WriteResult:
|
||||||
|
return await asyncio.to_thread(self.rmdir, dir_path)
|
||||||
|
|
||||||
def edit(
|
def edit(
|
||||||
self,
|
self,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
|
|
|
||||||
166
surfsense_backend/app/agents/new_chat/prompt_caching.py
Normal file
166
surfsense_backend/app/agents/new_chat/prompt_caching.py
Normal file
|
|
@ -0,0 +1,166 @@
|
||||||
|
"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||||
|
|
||||||
|
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
||||||
|
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
||||||
|
gate always failed) with LiteLLM's universal caching mechanism.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
|
||||||
|
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
|
||||||
|
performs automatically when ``cache_control_injection_points`` is set):
|
||||||
|
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
|
||||||
|
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
|
||||||
|
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
|
||||||
|
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
|
||||||
|
``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024
|
||||||
|
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
|
||||||
|
|
||||||
|
We inject **two** breakpoints per request:
|
||||||
|
|
||||||
|
- ``role: system`` — pins the SurfSense system prompt (provider variant,
|
||||||
|
citation rules, tool catalog, KB tree, skills metadata) into the cache.
|
||||||
|
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
||||||
|
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
||||||
|
N+1 still reads turn N's cache up to the shared prefix.
|
||||||
|
|
||||||
|
For OpenAI-family configs we additionally pass:
|
||||||
|
|
||||||
|
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
||||||
|
raises hit rate by sending requests with a shared prefix to the same
|
||||||
|
backend.
|
||||||
|
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
||||||
|
5-10 min in-memory cache.
|
||||||
|
|
||||||
|
Safety net: ``litellm.drop_params=True`` is set globally in
|
||||||
|
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
||||||
|
provider doesn't recognise is auto-stripped at the provider transformer
|
||||||
|
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
|
||||||
|
``prompt_cache_key`` etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Two-breakpoint policy: system + latest message. See module docstring for
|
||||||
|
# rationale. Anthropic limits requests to 4 ``cache_control`` blocks; we
|
||||||
|
# use 2 here, leaving headroom for Phase-2 tool caching.
|
||||||
|
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||||
|
{"location": "message", "role": "system"},
|
||||||
|
{"location": "message", "index": -1},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Providers (uppercase ``AgentConfig.provider`` values) that natively expose
|
||||||
|
# OpenAI-style automatic prompt caching with ``prompt_cache_key`` and
|
||||||
|
# ``prompt_cache_retention`` kwargs. Strict whitelist — many other providers
|
||||||
|
# in ``PROVIDER_MAP`` route through litellm's ``openai`` prefix without
|
||||||
|
# implementing the OpenAI prompt-cache surface (e.g. MOONSHOT, ZHIPU,
|
||||||
|
# MINIMAX), so we can't infer family from the litellm prefix alone.
|
||||||
|
_OPENAI_FAMILY_PROVIDERS: frozenset[str] = frozenset({"OPENAI", "DEEPSEEK", "XAI"})
|
||||||
|
|
||||||
|
|
||||||
|
def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||||
|
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
|
||||||
|
|
||||||
|
Importing ``app.services.llm_router_service`` at module-load time would
|
||||||
|
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
|
||||||
|
Class-name comparison is sufficient since the class is defined in a
|
||||||
|
single place.
|
||||||
|
"""
|
||||||
|
return type(llm).__name__ == "ChatLiteLLMRouter"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_openai_family_config(agent_config: AgentConfig | None) -> bool:
|
||||||
|
"""Whether the config targets an OpenAI-style prompt-cache surface.
|
||||||
|
|
||||||
|
Strict — only returns True when the user explicitly chose OPENAI,
|
||||||
|
DEEPSEEK, or XAI as the provider in their ``NewLLMConfig`` /
|
||||||
|
``YAMLConfig``. Auto-mode and custom providers return False because
|
||||||
|
we can't statically know the destination.
|
||||||
|
"""
|
||||||
|
if agent_config is None or not agent_config.provider:
|
||||||
|
return False
|
||||||
|
if agent_config.is_auto_mode:
|
||||||
|
return False
|
||||||
|
if agent_config.custom_provider:
|
||||||
|
return False
|
||||||
|
return agent_config.provider.upper() in _OPENAI_FAMILY_PROVIDERS
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
||||||
|
"""Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
|
||||||
|
|
||||||
|
Initialises the field to ``{}`` when present-but-None on a Pydantic v2
|
||||||
|
model. Returns ``None`` if the LLM type doesn't expose a writable
|
||||||
|
``model_kwargs`` attribute (caller should treat as no-op).
|
||||||
|
"""
|
||||||
|
model_kwargs = getattr(llm, "model_kwargs", None)
|
||||||
|
if isinstance(model_kwargs, dict):
|
||||||
|
return model_kwargs
|
||||||
|
try:
|
||||||
|
llm.model_kwargs = {} # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
refreshed = getattr(llm, "model_kwargs", None)
|
||||||
|
return refreshed if isinstance(refreshed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def apply_litellm_prompt_caching(
|
||||||
|
llm: BaseChatModel,
|
||||||
|
*,
|
||||||
|
agent_config: AgentConfig | None = None,
|
||||||
|
thread_id: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
|
||||||
|
|
||||||
|
Idempotent — values already present in ``llm.model_kwargs`` (e.g. from
|
||||||
|
``agent_config.litellm_params`` overrides) are preserved. Mutates
|
||||||
|
``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
|
||||||
|
via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
|
||||||
|
in our custom ``ChatLiteLLMRouter``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
|
||||||
|
agent_config: Optional ``AgentConfig`` driving provider-specific
|
||||||
|
behaviour. When omitted (or auto-mode), only the universal
|
||||||
|
``cache_control_injection_points`` are set.
|
||||||
|
thread_id: Optional thread id used to construct a per-thread
|
||||||
|
``prompt_cache_key`` for OpenAI-family providers. Caching still
|
||||||
|
works without it (server-side automatic), but the key improves
|
||||||
|
backend routing affinity and therefore hit rate.
|
||||||
|
"""
|
||||||
|
model_kwargs = _get_or_init_model_kwargs(llm)
|
||||||
|
if model_kwargs is None:
|
||||||
|
logger.debug(
|
||||||
|
"apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
|
||||||
|
type(llm).__name__,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if "cache_control_injection_points" not in model_kwargs:
|
||||||
|
model_kwargs["cache_control_injection_points"] = [
|
||||||
|
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
||||||
|
]
|
||||||
|
|
||||||
|
# OpenAI-family extras only when we statically know the destination is
|
||||||
|
# OpenAI / DeepSeek / xAI. Auto-mode router fans out across providers
|
||||||
|
# so we can't safely set OpenAI-only kwargs there (drop_params would
|
||||||
|
# strip them but it's wasteful to set them in the first place).
|
||||||
|
if _is_router_llm(llm):
|
||||||
|
return
|
||||||
|
if not _is_openai_family_config(agent_config):
|
||||||
|
return
|
||||||
|
|
||||||
|
if thread_id is not None and "prompt_cache_key" not in model_kwargs:
|
||||||
|
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
|
||||||
|
if "prompt_cache_retention" not in model_kwargs:
|
||||||
|
model_kwargs["prompt_cache_retention"] = "24h"
|
||||||
|
|
@ -181,9 +181,13 @@ def _initial_filesystem_state() -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"cwd": "/documents",
|
"cwd": "/documents",
|
||||||
"staged_dirs": [],
|
"staged_dirs": [],
|
||||||
|
"staged_dir_tool_calls": {},
|
||||||
"pending_moves": [],
|
"pending_moves": [],
|
||||||
|
"pending_deletes": [],
|
||||||
|
"pending_dir_deletes": [],
|
||||||
"doc_id_by_path": {},
|
"doc_id_by_path": {},
|
||||||
"dirty_paths": [],
|
"dirty_paths": [],
|
||||||
|
"dirty_path_tool_calls": {},
|
||||||
"kb_priority": [],
|
"kb_priority": [],
|
||||||
"kb_matched_chunk_ids": {},
|
"kb_matched_chunk_ids": {},
|
||||||
"kb_anon_doc": None,
|
"kb_anon_doc": None,
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,8 @@ WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
|
||||||
"write_file",
|
"write_file",
|
||||||
"move_file",
|
"move_file",
|
||||||
"mkdir",
|
"mkdir",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
"update_memory",
|
"update_memory",
|
||||||
"update_memory_team",
|
"update_memory_team",
|
||||||
"update_memory_private",
|
"update_memory_private",
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,35 @@ from langgraph.types import interrupt
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Tools that mirror the safety profile of ``write_file`` against the
|
||||||
|
# SurfSense KB: each call creates ONE artifact in the user's own workspace
|
||||||
|
# with no external visibility (drafts aren't sent; new files aren't shared
|
||||||
|
# unless the user shares them later). These are auto-approved by default
|
||||||
|
# so the agent can compose drafts and seed scratch files without a popup
|
||||||
|
# on every call.
|
||||||
|
#
|
||||||
|
# Members of this set still call ``request_approval`` exactly as before;
|
||||||
|
# the function returns immediately with ``decision_type="auto_approved"``
|
||||||
|
# and the original params untouched. This preserves the call-site shape
|
||||||
|
# (logging, metadata fetching, account fallbacks) so the only behavior
|
||||||
|
# change is "no interrupt fires".
|
||||||
|
#
|
||||||
|
# To re-enable prompting, the future per-search-space rules table
|
||||||
|
# (``agent_permission_rules``) takes precedence — see the ``# (future)``
|
||||||
|
# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`.
|
||||||
|
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"create_gmail_draft",
|
||||||
|
"update_gmail_draft",
|
||||||
|
"create_notion_page",
|
||||||
|
"create_confluence_page",
|
||||||
|
"create_google_drive_file",
|
||||||
|
"create_dropbox_file",
|
||||||
|
"create_onedrive_file",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
@dataclass(frozen=True, slots=True)
|
||||||
class HITLResult:
|
class HITLResult:
|
||||||
"""Outcome of a human-in-the-loop approval request."""
|
"""Outcome of a human-in-the-loop approval request."""
|
||||||
|
|
@ -119,6 +148,19 @@ def request_approval(
|
||||||
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
|
logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
|
||||||
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
return HITLResult(rejected=False, decision_type="trusted", params=dict(params))
|
||||||
|
|
||||||
|
if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
|
||||||
|
# Default policy: low-stakes creation tools (drafts + new-file
|
||||||
|
# creates) skip HITL because they're as recoverable as a local
|
||||||
|
# ``write_file`` against the SurfSense KB. The user can still
|
||||||
|
# delete the artifact in <30s if it's wrong.
|
||||||
|
logger.info(
|
||||||
|
"Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
|
||||||
|
tool_name,
|
||||||
|
)
|
||||||
|
return HITLResult(
|
||||||
|
rejected=False, decision_type="auto_approved", params=dict(params)
|
||||||
|
)
|
||||||
|
|
||||||
approval = interrupt(
|
approval = interrupt(
|
||||||
{
|
{
|
||||||
"type": action_type,
|
"type": action_type,
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,27 @@ def load_global_llm_configs():
|
||||||
else:
|
else:
|
||||||
seen_slugs[slug] = cfg.get("id", 0)
|
seen_slugs[slug] = cfg.get("id", 0)
|
||||||
|
|
||||||
|
# Stamp Auto (Fastest) ranking metadata. YAML configs are always
|
||||||
|
# Tier A — operator-curated, locked first when premium-eligible.
|
||||||
|
# The OpenRouter refresh tick later re-stamps health for any cfg
|
||||||
|
# whose provider == "OPENROUTER" via _enrich_health.
|
||||||
|
try:
|
||||||
|
from app.services.quality_score import static_score_yaml
|
||||||
|
|
||||||
|
for cfg in configs:
|
||||||
|
cfg["auto_pin_tier"] = "A"
|
||||||
|
static_q = static_score_yaml(cfg)
|
||||||
|
cfg["quality_score_static"] = static_q
|
||||||
|
cfg["quality_score"] = static_q
|
||||||
|
cfg["quality_score_health"] = None
|
||||||
|
# YAML cfgs whose provider is OPENROUTER are also subject
|
||||||
|
# to health gating against their own /endpoints data — a
|
||||||
|
# hand-picked dead OR model is still dead. _enrich_health
|
||||||
|
# re-stamps health_gated for them on the next refresh tick.
|
||||||
|
cfg["health_gated"] = False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Failed to score global LLM configs: {e}")
|
||||||
|
|
||||||
return configs
|
return configs
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to load global LLM configs: {e}")
|
print(f"Warning: Failed to load global LLM configs: {e}")
|
||||||
|
|
@ -194,6 +215,9 @@ def load_openrouter_integration_settings() -> dict | None:
|
||||||
"""
|
"""
|
||||||
Load OpenRouter integration settings from the YAML config.
|
Load OpenRouter integration settings from the YAML config.
|
||||||
|
|
||||||
|
Emits startup warnings for deprecated keys (``billing_tier``,
|
||||||
|
``anonymous_enabled``) and seeds their replacements for back-compat.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict with settings if present and enabled, None otherwise
|
dict with settings if present and enabled, None otherwise
|
||||||
"""
|
"""
|
||||||
|
|
@ -206,9 +230,31 @@ def load_openrouter_integration_settings() -> dict | None:
|
||||||
with open(global_config_file, encoding="utf-8") as f:
|
with open(global_config_file, encoding="utf-8") as f:
|
||||||
data = yaml.safe_load(f)
|
data = yaml.safe_load(f)
|
||||||
settings = data.get("openrouter_integration")
|
settings = data.get("openrouter_integration")
|
||||||
if settings and settings.get("enabled"):
|
if not settings or not settings.get("enabled"):
|
||||||
return settings
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if "billing_tier" in settings:
|
||||||
|
print(
|
||||||
|
"Warning: openrouter_integration.billing_tier is deprecated; "
|
||||||
|
"tier is now derived per model from OpenRouter data "
|
||||||
|
"(':free' suffix or zero pricing). Remove this key."
|
||||||
|
)
|
||||||
|
|
||||||
|
if "anonymous_enabled" in settings:
|
||||||
|
print(
|
||||||
|
"Warning: openrouter_integration.anonymous_enabled is "
|
||||||
|
"deprecated; use anonymous_enabled_paid and/or "
|
||||||
|
"anonymous_enabled_free instead. Both new flags have been "
|
||||||
|
"seeded from the legacy value for back-compat."
|
||||||
|
)
|
||||||
|
settings.setdefault(
|
||||||
|
"anonymous_enabled_paid", settings["anonymous_enabled"]
|
||||||
|
)
|
||||||
|
settings.setdefault(
|
||||||
|
"anonymous_enabled_free", settings["anonymous_enabled"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return settings
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
print(f"Warning: Failed to load OpenRouter integration settings: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
@ -217,9 +263,14 @@ def load_openrouter_integration_settings() -> dict | None:
|
||||||
def initialize_openrouter_integration():
|
def initialize_openrouter_integration():
|
||||||
"""
|
"""
|
||||||
If enabled, fetch all OpenRouter models and append them to
|
If enabled, fetch all OpenRouter models and append them to
|
||||||
config.GLOBAL_LLM_CONFIGS as dynamic premium entries.
|
config.GLOBAL_LLM_CONFIGS as dynamic entries. Each model's ``billing_tier``
|
||||||
Should be called BEFORE initialize_llm_router() so the router
|
is derived per-model from OpenRouter's API signals (``:free`` suffix or
|
||||||
correctly excludes premium models from Auto mode.
|
zero pricing), so free OpenRouter models correctly skip premium quota.
|
||||||
|
|
||||||
|
Should be called BEFORE initialize_llm_router(). Dynamic entries are
|
||||||
|
tagged ``router_pool_eligible=False`` so the LiteLLM Router pool (used
|
||||||
|
by title-gen / sub-agent flows) remains scoped to curated YAML configs,
|
||||||
|
while user-facing Auto-mode thread pinning still considers them.
|
||||||
"""
|
"""
|
||||||
settings = load_openrouter_integration_settings()
|
settings = load_openrouter_integration_settings()
|
||||||
if not settings:
|
if not settings:
|
||||||
|
|
@ -235,9 +286,13 @@ def initialize_openrouter_integration():
|
||||||
|
|
||||||
if new_configs:
|
if new_configs:
|
||||||
config.GLOBAL_LLM_CONFIGS.extend(new_configs)
|
config.GLOBAL_LLM_CONFIGS.extend(new_configs)
|
||||||
|
free_count = sum(1 for c in new_configs if c.get("billing_tier") == "free")
|
||||||
|
premium_count = sum(
|
||||||
|
1 for c in new_configs if c.get("billing_tier") == "premium"
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
f"Info: OpenRouter integration added {len(new_configs)} models "
|
f"Info: OpenRouter integration added {len(new_configs)} models "
|
||||||
f"(billing_tier={settings.get('billing_tier', 'premium')})"
|
f"(free={free_count}, premium={premium_count})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print("Info: OpenRouter integration enabled but no models fetched")
|
print("Info: OpenRouter integration enabled but no models fetched")
|
||||||
|
|
|
||||||
|
|
@ -245,31 +245,53 @@ global_llm_configs:
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# When enabled, dynamically fetches ALL available models from the OpenRouter API
|
# When enabled, dynamically fetches ALL available models from the OpenRouter API
|
||||||
# and injects them as global configs. This gives premium users access to any model
|
# and injects them as global configs. This gives premium users access to any model
|
||||||
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota.
|
# on OpenRouter (Claude, Gemini, Llama, Mistral, etc.) via their premium token quota,
|
||||||
|
# while free-tier OpenRouter models show up with a green Free badge and do NOT
|
||||||
|
# consume premium quota.
|
||||||
# Models are fetched at startup and refreshed periodically in the background.
|
# Models are fetched at startup and refreshed periodically in the background.
|
||||||
# All calls go through LiteLLM with the openrouter/ prefix.
|
# All calls go through LiteLLM with the openrouter/ prefix.
|
||||||
openrouter_integration:
|
openrouter_integration:
|
||||||
enabled: false
|
enabled: false
|
||||||
api_key: "sk-or-your-openrouter-api-key"
|
api_key: "sk-or-your-openrouter-api-key"
|
||||||
# billing_tier: "premium" or "free". Controls whether users need premium tokens.
|
|
||||||
billing_tier: "premium"
|
# Tier is derived PER MODEL from OpenRouter's own API signals:
|
||||||
# anonymous_enabled: set true to also show OpenRouter models to no-login users
|
# - id ends with ":free" -> billing_tier=free
|
||||||
anonymous_enabled: false
|
# - pricing.prompt AND pricing.completion == "0" -> billing_tier=free
|
||||||
|
# - otherwise -> billing_tier=premium
|
||||||
|
# No global billing_tier knob is honored; any legacy value emits a startup warning.
|
||||||
|
|
||||||
|
# Anonymous access is split by tier so operators can expose only free
|
||||||
|
# models to no-login users without leaking paid inference.
|
||||||
|
anonymous_enabled_paid: false
|
||||||
|
anonymous_enabled_free: false
|
||||||
|
|
||||||
seo_enabled: false
|
seo_enabled: false
|
||||||
# quota_reserve_tokens: tokens reserved per call for quota enforcement
|
# quota_reserve_tokens: tokens reserved per call for quota enforcement
|
||||||
quota_reserve_tokens: 4000
|
quota_reserve_tokens: 4000
|
||||||
# id_offset: starting negative ID for dynamically generated configs.
|
# id_offset: base negative ID for dynamically generated configs.
|
||||||
# Must not overlap with your static global_llm_configs IDs above.
|
# Model IDs are derived deterministically via BLAKE2b so they survive
|
||||||
|
# catalogue churn. Must not overlap with your static global_llm_configs IDs.
|
||||||
id_offset: -10000
|
id_offset: -10000
|
||||||
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
|
# refresh_interval_hours: how often to re-fetch models from OpenRouter (0 = startup only)
|
||||||
refresh_interval_hours: 24
|
refresh_interval_hours: 24
|
||||||
# rpm/tpm: Applied uniformly to all OpenRouter models for LiteLLM Router load balancing.
|
|
||||||
# OpenRouter doesn't expose per-model rate limits via API; actual throttling is handled
|
# Rate limits for PAID OpenRouter models. These are used by LiteLLM Router
|
||||||
# upstream by OpenRouter itself (your account limits are at https://openrouter.ai/settings/limits).
|
# for per-deployment accounting when OR premium models participate in the
|
||||||
# These values only matter if you set billing_tier to "free" (adding them to Auto mode).
|
# shared sub-agent "auto" pool. They do NOT cap OpenRouter itself — your
|
||||||
# For premium-only models they are cosmetic. Set conservatively or match your account tier.
|
# real account limits live at https://openrouter.ai/settings/limits.
|
||||||
rpm: 200
|
rpm: 200
|
||||||
tpm: 1000000
|
tpm: 1000000
|
||||||
|
|
||||||
|
# Rate limits for FREE OpenRouter models. Informational only: free OR
|
||||||
|
# models are intentionally kept OUT of the LiteLLM Router pool, because
|
||||||
|
# OpenRouter enforces free-tier limits globally per account (~20 RPM +
|
||||||
|
# 50-1000 daily requests across every ":free" model combined) —
|
||||||
|
# per-deployment router accounting can't represent a shared bucket
|
||||||
|
# correctly. Free OR models stay fully available in the model selector
|
||||||
|
# and for user-facing Auto thread pinning.
|
||||||
|
free_rpm: 20
|
||||||
|
free_tpm: 100000
|
||||||
|
|
||||||
litellm_params:
|
litellm_params:
|
||||||
max_tokens: 16384
|
max_tokens: 16384
|
||||||
system_instructions: ""
|
system_instructions: ""
|
||||||
|
|
|
||||||
|
|
@ -638,6 +638,12 @@ class NewChatThread(BaseModel, TimestampMixin):
|
||||||
default=False,
|
default=False,
|
||||||
server_default="false",
|
server_default="false",
|
||||||
)
|
)
|
||||||
|
# Auto (Fastest) model pin for this thread: concrete resolved global LLM
|
||||||
|
# config id. NULL means no pin; Auto will resolve on the next turn.
|
||||||
|
# Single-writer invariant: only app.services.auto_model_pin_service sets
|
||||||
|
# or clears this column (plus bulk clears when a search space's
|
||||||
|
# agent_llm_id changes). Unindexed: all reads are by primary key.
|
||||||
|
pinned_llm_config_id = Column(Integer, nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
|
||||||
|
|
@ -689,6 +695,12 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Per-turn correlation id sourced from ``configurable.turn_id`` at
|
||||||
|
# streaming time (``f"{chat_id}:{ms}"``). Nullable because legacy rows
|
||||||
|
# predate the column. Used by C1's edit-from-arbitrary-position to map
|
||||||
|
# a message back to the LangGraph checkpoint that produced its turn.
|
||||||
|
turn_id = Column(String(64), nullable=True, index=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
thread = relationship("NewChatThread", back_populates="messages")
|
thread = relationship("NewChatThread", back_populates="messages")
|
||||||
author = relationship("User")
|
author = relationship("User")
|
||||||
|
|
@ -2292,7 +2304,13 @@ class AgentActionLog(BaseModel):
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
|
# ``turn_id`` historically held the LangChain ``tool_call.id``. It has
|
||||||
|
# been renamed to ``tool_call_id`` (with a parallel column kept for one
|
||||||
|
# release for back-compat). The real chat-turn id lives in
|
||||||
|
# ``chat_turn_id`` and is sourced from ``configurable.turn_id``.
|
||||||
turn_id = Column(String(64), nullable=True, index=True)
|
turn_id = Column(String(64), nullable=True, index=True)
|
||||||
|
tool_call_id = Column(String(64), nullable=True, index=True)
|
||||||
|
chat_turn_id = Column(String(64), nullable=True, index=True)
|
||||||
message_id = Column(String(128), nullable=True, index=True)
|
message_id = Column(String(128), nullable=True, index=True)
|
||||||
tool_name = Column(String(255), nullable=False, index=True)
|
tool_name = Column(String(255), nullable=False, index=True)
|
||||||
args = Column(JSONB, nullable=True)
|
args = Column(JSONB, nullable=True)
|
||||||
|
|
@ -2318,6 +2336,16 @@ class AgentActionLog(BaseModel):
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
|
Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
|
||||||
|
# Partial unique index enforces "at most one revert per
|
||||||
|
# original action". Created in migration 137 with
|
||||||
|
# ``WHERE reverse_of IS NOT NULL`` so non-revert rows
|
||||||
|
# (the vast majority) are unaffected and NULLs don't collide.
|
||||||
|
Index(
|
||||||
|
"ux_agent_action_log_reverse_of",
|
||||||
|
"reverse_of",
|
||||||
|
unique=True,
|
||||||
|
postgresql_where=text("reverse_of IS NOT NULL"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2332,10 +2360,13 @@ class DocumentRevision(BaseModel):
|
||||||
|
|
||||||
__tablename__ = "document_revisions"
|
__tablename__ = "document_revisions"
|
||||||
|
|
||||||
|
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
|
||||||
|
# hard-delete it describes — without that, ``rm`` would wipe the row
|
||||||
|
# we'd need to undo it. See migration ``134_relax_revision_fks``.
|
||||||
document_id = Column(
|
document_id = Column(
|
||||||
Integer,
|
Integer,
|
||||||
ForeignKey("documents.id", ondelete="CASCADE"),
|
ForeignKey("documents.id", ondelete="SET NULL"),
|
||||||
nullable=False,
|
nullable=True,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
search_space_id = Column(
|
search_space_id = Column(
|
||||||
|
|
@ -2370,10 +2401,13 @@ class FolderRevision(BaseModel):
|
||||||
|
|
||||||
__tablename__ = "folder_revisions"
|
__tablename__ = "folder_revisions"
|
||||||
|
|
||||||
|
# ``ON DELETE SET NULL`` (not CASCADE) so the snapshot survives the
|
||||||
|
# hard-delete it describes — without that, ``rmdir`` would wipe the
|
||||||
|
# row we'd need to undo it. See migration ``134_relax_revision_fks``.
|
||||||
folder_id = Column(
|
folder_id = Column(
|
||||||
Integer,
|
Integer,
|
||||||
ForeignKey("folders.id", ondelete="CASCADE"),
|
ForeignKey("folders.id", ondelete="SET NULL"),
|
||||||
nullable=False,
|
nullable=True,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
search_space_id = Column(
|
search_space_id = Column(
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,13 @@ class AgentActionRead(BaseModel):
|
||||||
reverse_of: int | None
|
reverse_of: int | None
|
||||||
reverted_by_action_id: int | None
|
reverted_by_action_id: int | None
|
||||||
is_revert_action: bool
|
is_revert_action: bool
|
||||||
|
# Correlation ids added in migration 135. ``tool_call_id`` is the
|
||||||
|
# LangChain tool-call id (joinable to ``data-action-log`` SSE events
|
||||||
|
# via ``langchainToolCallId``). ``chat_turn_id`` is the per-turn id
|
||||||
|
# from ``configurable.turn_id`` (used by the
|
||||||
|
# ``revert-turn/{chat_turn_id}`` endpoint).
|
||||||
|
tool_call_id: str | None = None
|
||||||
|
chat_turn_id: str | None = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -172,6 +179,8 @@ async def list_thread_actions(
|
||||||
reverse_of=row.reverse_of,
|
reverse_of=row.reverse_of,
|
||||||
reverted_by_action_id=revert_map.get(row.id),
|
reverted_by_action_id=revert_map.get(row.id),
|
||||||
is_revert_action=row.reverse_of is not None,
|
is_revert_action=row.reverse_of is not None,
|
||||||
|
tool_call_id=row.tool_call_id,
|
||||||
|
chat_turn_id=row.chat_turn_id,
|
||||||
created_at=row.created_at,
|
created_at=row.created_at,
|
||||||
)
|
)
|
||||||
for row in rows
|
for row in rows
|
||||||
|
|
|
||||||
|
|
@ -11,14 +11,25 @@ flag flips. Once enabled, the route runs:
|
||||||
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
|
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
|
||||||
5. Idempotent on retries: if the same action is reverted twice the second
|
5. Idempotent on retries: if the same action is reverted twice the second
|
||||||
call returns 409 ``"already reverted"``.
|
call returns 409 ``"already reverted"``.
|
||||||
|
|
||||||
|
This module also hosts the per-turn batch endpoint
|
||||||
|
``POST /api/threads/{thread_id}/revert-turn/{chat_turn_id}``. It
|
||||||
|
walks every reversible action emitted during a chat turn in reverse
|
||||||
|
``created_at`` order and reverts each independently. Partial success is the
|
||||||
|
common case — the response always contains a per-action result list and a
|
||||||
|
``status`` of ``"ok"`` or ``"partial"``; we never collapse the batch into a
|
||||||
|
whole-batch 4xx.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.feature_flags import get_flags
|
from app.agents.new_chat.feature_flags import get_flags
|
||||||
|
|
@ -97,6 +108,16 @@ async def revert_agent_action(
|
||||||
action=action,
|
action=action,
|
||||||
requester_user_id=str(user.id) if user is not None else None,
|
requester_user_id=str(user.id) if user is not None else None,
|
||||||
)
|
)
|
||||||
|
except IntegrityError:
|
||||||
|
# Partial unique index ``ux_agent_action_log_reverse_of`` caught
|
||||||
|
# a concurrent revert. Translate to the existing 409 "already
|
||||||
|
# reverted" contract so racing clients see consistent
|
||||||
|
# behaviour with the pre-flight TOCTOU check above.
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="This action has already been reverted.",
|
||||||
|
) from None
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.exception("Revert dispatch raised for action_id=%s", action_id)
|
logger.exception("Revert dispatch raised for action_id=%s", action_id)
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
|
|
@ -105,7 +126,16 @@ async def revert_agent_action(
|
||||||
) from err
|
) from err
|
||||||
|
|
||||||
if outcome.status == "ok":
|
if outcome.status == "ok":
|
||||||
|
try:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
except IntegrityError:
|
||||||
|
# Race lost on commit (constraint enforced at flush in some
|
||||||
|
# configs but at commit in others — defensive).
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="This action has already been reverted.",
|
||||||
|
) from None
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"message": outcome.message,
|
"message": outcome.message,
|
||||||
|
|
@ -122,3 +152,357 @@ async def revert_agent_action(
|
||||||
raise HTTPException(status_code=501, detail=outcome.message)
|
raise HTTPException(status_code=501, detail=outcome.message)
|
||||||
# not_reversible
|
# not_reversible
|
||||||
raise HTTPException(status_code=409, detail=outcome.message)
|
raise HTTPException(status_code=409, detail=outcome.message)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-turn revert batch endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
PerActionStatus = Literal[
|
||||||
|
"reverted",
|
||||||
|
"already_reverted",
|
||||||
|
"not_reversible",
|
||||||
|
"permission_denied",
|
||||||
|
"failed",
|
||||||
|
"skipped",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class RevertTurnActionResult(BaseModel):
|
||||||
|
"""Per-action outcome inside a ``revert-turn`` batch response."""
|
||||||
|
|
||||||
|
action_id: int
|
||||||
|
tool_name: str
|
||||||
|
status: PerActionStatus
|
||||||
|
message: str | None = None
|
||||||
|
new_action_id: int | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RevertTurnResponse(BaseModel):
|
||||||
|
"""Top-level response for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||||
|
|
||||||
|
``status`` is ``"ok"`` only when every reversible row succeeded. Any
|
||||||
|
``failed`` / ``not_reversible`` / ``permission_denied`` entry downgrades
|
||||||
|
it to ``"partial"``. Empty turns (no rows) return ``"ok"`` with an empty
|
||||||
|
``results`` list — callers should treat that as a no-op.
|
||||||
|
|
||||||
|
Counter invariant:
|
||||||
|
``total == reverted + already_reverted + not_reversible
|
||||||
|
+ permission_denied + failed + skipped``
|
||||||
|
|
||||||
|
Frontend toasts and the ``RevertTurnButton`` summary rely on this
|
||||||
|
invariant to display "X of Y reverted, Z could not be undone" without
|
||||||
|
silently dropping ``permission_denied`` or ``skipped`` rows.
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: Literal["ok", "partial"]
|
||||||
|
chat_turn_id: str
|
||||||
|
total: int
|
||||||
|
reverted: int
|
||||||
|
already_reverted: int
|
||||||
|
not_reversible: int
|
||||||
|
permission_denied: int = 0
|
||||||
|
failed: int = 0
|
||||||
|
skipped: int = 0
|
||||||
|
results: list[RevertTurnActionResult]
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_outcome(outcome: RevertOutcome) -> PerActionStatus:
|
||||||
|
if outcome.status == "ok":
|
||||||
|
return "reverted"
|
||||||
|
if outcome.status == "permission_denied":
|
||||||
|
return "permission_denied"
|
||||||
|
# ``not_found`` / ``tool_unavailable`` / ``reverse_not_implemented`` /
|
||||||
|
# ``not_reversible`` are all surfaced to the caller as "not_reversible"
|
||||||
|
# — they share the same UX (this row cannot be undone) and only the
|
||||||
|
# ``message`` differs.
|
||||||
|
return "not_reversible"
|
||||||
|
|
||||||
|
|
||||||
|
async def _was_already_reverted(session: AsyncSession, *, action_id: int) -> int | None:
|
||||||
|
"""Return the id of an existing successful revert row, if any.
|
||||||
|
|
||||||
|
Single-action variant — kept for the post-IntegrityError lookup
|
||||||
|
path where we already know we lost a race for one specific id.
|
||||||
|
"""
|
||||||
|
stmt = select(AgentActionLog.id).where(AgentActionLog.reverse_of == action_id)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
async def _was_already_reverted_batch(
|
||||||
|
session: AsyncSession, *, action_ids: list[int]
|
||||||
|
) -> dict[int, int]:
|
||||||
|
"""Batch idempotency probe for the revert-turn loop.
|
||||||
|
|
||||||
|
Replaces N individual ``SELECT id WHERE reverse_of = :id`` queries
|
||||||
|
(one per row in the turn) with a single ``SELECT id, reverse_of
|
||||||
|
WHERE reverse_of IN (:ids)``. The route still iterates rows in
|
||||||
|
reverse-chronological order, but the membership check is O(1) per
|
||||||
|
iteration after this query. For a turn with 30 actions that's 30
|
||||||
|
fewer round-trips through asyncpg + a smaller transaction footprint.
|
||||||
|
|
||||||
|
Returns a ``{original_action_id -> revert_action_id}`` map. Missing
|
||||||
|
keys mean "not yet reverted" — callers should treat them as
|
||||||
|
eligible for revert.
|
||||||
|
"""
|
||||||
|
if not action_ids:
|
||||||
|
return {}
|
||||||
|
stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where(
|
||||||
|
AgentActionLog.reverse_of.in_(action_ids)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return {
|
||||||
|
original_id: revert_id
|
||||||
|
for revert_id, original_id in result.all()
|
||||||
|
if original_id is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/threads/{thread_id}/revert-turn/{chat_turn_id}",
|
||||||
|
response_model=RevertTurnResponse,
|
||||||
|
)
|
||||||
|
async def revert_agent_turn(
|
||||||
|
thread_id: int,
|
||||||
|
chat_turn_id: str,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
) -> RevertTurnResponse:
|
||||||
|
"""Revert every reversible action emitted during ``chat_turn_id``.
|
||||||
|
|
||||||
|
Walks ``AgentActionLog`` rows for the turn in reverse ``created_at``
|
||||||
|
order so dependencies (e.g. ``mkdir`` -> ``write_file`` inside the new
|
||||||
|
folder) unwind in the right sequence. Each action is reverted in its
|
||||||
|
own SAVEPOINT so a single failure does not poison the batch.
|
||||||
|
|
||||||
|
Partial success is intentional and returned with HTTP 200. Callers
|
||||||
|
must inspect ``results[*].status`` to find rows that need attention.
|
||||||
|
"""
|
||||||
|
|
||||||
|
flags = get_flags()
|
||||||
|
if flags.disable_new_agent_stack or not flags.enable_revert_route:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail=(
|
||||||
|
"Revert is not available on this deployment yet. The route "
|
||||||
|
"ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to "
|
||||||
|
"enable it."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
thread = await load_thread(session, thread_id=thread_id)
|
||||||
|
if thread is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found.")
|
||||||
|
|
||||||
|
# Reverse-chronological so the latest mutation in the turn unwinds
|
||||||
|
# first. ``id.desc()`` is the deterministic tiebreaker for actions
|
||||||
|
# written in the same millisecond.
|
||||||
|
rows_stmt = (
|
||||||
|
select(AgentActionLog)
|
||||||
|
.where(
|
||||||
|
AgentActionLog.thread_id == thread_id,
|
||||||
|
AgentActionLog.chat_turn_id == chat_turn_id,
|
||||||
|
)
|
||||||
|
.order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc())
|
||||||
|
)
|
||||||
|
rows = (await session.execute(rows_stmt)).scalars().all()
|
||||||
|
|
||||||
|
requester_user_id = str(user.id) if user is not None else None
|
||||||
|
results: list[RevertTurnActionResult] = []
|
||||||
|
# Counters MUST be exhaustive so the response invariant
|
||||||
|
# ``total == sum(counters)`` always holds. Frontend toasts and
|
||||||
|
# ``RevertTurnButton`` rely on this for "X of Y reverted" math.
|
||||||
|
counts: dict[str, int] = {
|
||||||
|
"reverted": 0,
|
||||||
|
"already_reverted": 0,
|
||||||
|
"not_reversible": 0,
|
||||||
|
"permission_denied": 0,
|
||||||
|
"failed": 0,
|
||||||
|
"skipped": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Single batched idempotency probe replaces the previous per-row
|
||||||
|
# SELECT. ``rows`` are filtered in the loop so we pre-collect only
|
||||||
|
# the original-action ids (skip rows that are themselves
|
||||||
|
# reverts).
|
||||||
|
eligible_ids = [r.id for r in rows if r.reverse_of is None]
|
||||||
|
already_reverted_map = await _was_already_reverted_batch(
|
||||||
|
session, action_ids=eligible_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
for action in rows:
|
||||||
|
# Skip rows that ARE reverts of an earlier action — reverting a
|
||||||
|
# revert is meaningless inside a batch (the user wants to wipe
|
||||||
|
# the original effects, not chase tail).
|
||||||
|
if action.reverse_of is not None:
|
||||||
|
counts["skipped"] += 1
|
||||||
|
results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="skipped",
|
||||||
|
message="Row is itself a revert action; skipped.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Idempotency: surface "already_reverted" instead of failing.
|
||||||
|
existing_revert_id = already_reverted_map.get(action.id)
|
||||||
|
if existing_revert_id is not None:
|
||||||
|
counts["already_reverted"] += 1
|
||||||
|
results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="already_reverted",
|
||||||
|
new_action_id=existing_revert_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not can_revert(
|
||||||
|
requester_user_id=requester_user_id,
|
||||||
|
action=action,
|
||||||
|
is_admin=False,
|
||||||
|
):
|
||||||
|
counts["permission_denied"] += 1
|
||||||
|
results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="permission_denied",
|
||||||
|
message="You are not allowed to revert this action.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Per-row SAVEPOINT so one failed revert never poisons later
|
||||||
|
# successful ones.
|
||||||
|
try:
|
||||||
|
async with session.begin_nested():
|
||||||
|
outcome = await revert_action(
|
||||||
|
session,
|
||||||
|
action=action,
|
||||||
|
requester_user_id=requester_user_id,
|
||||||
|
)
|
||||||
|
if outcome.status != "ok":
|
||||||
|
raise _OutcomeRollbackError(outcome)
|
||||||
|
except _OutcomeRollbackError as rollback:
|
||||||
|
outcome = rollback.outcome
|
||||||
|
classified = _classify_outcome(outcome)
|
||||||
|
if classified == "permission_denied":
|
||||||
|
counts["permission_denied"] += 1
|
||||||
|
else:
|
||||||
|
counts["not_reversible"] += 1
|
||||||
|
results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status=classified,
|
||||||
|
message=outcome.message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except IntegrityError:
|
||||||
|
# Partial unique index caught a concurrent revert that won
|
||||||
|
# the race against our pre-flight ``_was_already_reverted``
|
||||||
|
# SELECT. Look up the winner so
|
||||||
|
# we can surface its ``new_action_id`` to the client.
|
||||||
|
existing_revert_id = await _was_already_reverted(
|
||||||
|
session, action_id=action.id
|
||||||
|
)
|
||||||
|
counts["already_reverted"] += 1
|
||||||
|
results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="already_reverted",
|
||||||
|
new_action_id=existing_revert_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except Exception as err: # pragma: no cover — defensive, logged
|
||||||
|
logger.exception(
|
||||||
|
"Unexpected revert failure inside batch for action_id=%s",
|
||||||
|
action.id,
|
||||||
|
)
|
||||||
|
counts["failed"] += 1
|
||||||
|
results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="failed",
|
||||||
|
error=str(err) or err.__class__.__name__,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
counts["reverted"] += 1
|
||||||
|
results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="reverted",
|
||||||
|
message=outcome.message,
|
||||||
|
new_action_id=outcome.new_action_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Single commit at the end — successful SAVEPOINTs above already
|
||||||
|
# released; failed ones rolled back to their savepoint. No row leaks
|
||||||
|
# across the boundary.
|
||||||
|
try:
|
||||||
|
await session.commit()
|
||||||
|
except Exception as err: # pragma: no cover — defensive
|
||||||
|
logger.exception(
|
||||||
|
"Final commit for revert-turn failed (thread=%s turn=%s)",
|
||||||
|
thread_id,
|
||||||
|
chat_turn_id,
|
||||||
|
)
|
||||||
|
await session.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Internal error while finalising revert-turn batch.",
|
||||||
|
) from err
|
||||||
|
|
||||||
|
has_partial = (
|
||||||
|
counts["failed"] > 0
|
||||||
|
or counts["not_reversible"] > 0
|
||||||
|
or counts["permission_denied"] > 0
|
||||||
|
)
|
||||||
|
overall_status: Literal["ok", "partial"] = "partial" if has_partial else "ok"
|
||||||
|
|
||||||
|
return RevertTurnResponse(
|
||||||
|
status=overall_status,
|
||||||
|
chat_turn_id=chat_turn_id,
|
||||||
|
total=len(rows),
|
||||||
|
reverted=counts["reverted"],
|
||||||
|
already_reverted=counts["already_reverted"],
|
||||||
|
not_reversible=counts["not_reversible"],
|
||||||
|
permission_denied=counts["permission_denied"],
|
||||||
|
failed=counts["failed"],
|
||||||
|
skipped=counts["skipped"],
|
||||||
|
results=results,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _OutcomeRollbackError(Exception):
|
||||||
|
"""Sentinel raised inside the SAVEPOINT to roll back a non-OK outcome.
|
||||||
|
|
||||||
|
``revert_action`` writes a new ``agent_action_log`` row only on the
|
||||||
|
happy path, but on the failure paths it sometimes mutates the
|
||||||
|
``DocumentRevision``/``Document`` tables before deciding the action
|
||||||
|
is not reversible. Wrapping each call in ``begin_nested`` and raising
|
||||||
|
this from the failure branch ensures we always discard partial
|
||||||
|
writes for failed rows.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, outcome: RevertOutcome) -> None:
|
||||||
|
self.outcome = outcome
|
||||||
|
super().__init__(outcome.message)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["router"]
|
||||||
|
|
|
||||||
|
|
@ -745,6 +745,51 @@ async def search_document_titles(
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/documents/by-virtual-path", response_model=DocumentTitleRead)
|
||||||
|
async def get_document_by_virtual_path(
|
||||||
|
search_space_id: int,
|
||||||
|
virtual_path: str,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""Resolve a knowledge-base document id by exact virtual path."""
|
||||||
|
try:
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
search_space_id,
|
||||||
|
Permission.DOCUMENTS_READ.value,
|
||||||
|
"You don't have permission to read documents in this search space",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await session.execute(
|
||||||
|
select(
|
||||||
|
Document.id,
|
||||||
|
Document.title,
|
||||||
|
Document.document_type,
|
||||||
|
).filter(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
Document.document_metadata["virtual_path"].as_string() == virtual_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.first()
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Document not found")
|
||||||
|
|
||||||
|
return DocumentTitleRead(
|
||||||
|
id=row.id,
|
||||||
|
title=row.title,
|
||||||
|
document_type=row.document_type,
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Failed to resolve document by virtual path: {e!s}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
|
@router.get("/documents/status", response_model=DocumentStatusBatchResponse)
|
||||||
async def get_documents_status(
|
async def get_documents_status(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,11 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func, or_
|
||||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
|
@ -28,6 +29,12 @@ from app.agents.new_chat.filesystem_selection import (
|
||||||
FilesystemSelection,
|
FilesystemSelection,
|
||||||
LocalFilesystemMount,
|
LocalFilesystemMount,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
|
get_cancel_state,
|
||||||
|
is_cancel_requested,
|
||||||
|
manager,
|
||||||
|
request_cancel,
|
||||||
|
)
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatComment,
|
ChatComment,
|
||||||
|
|
@ -43,6 +50,7 @@ from app.db import (
|
||||||
)
|
)
|
||||||
from app.schemas.new_chat import (
|
from app.schemas.new_chat import (
|
||||||
AgentToolInfo,
|
AgentToolInfo,
|
||||||
|
CancelActiveTurnResponse,
|
||||||
LocalFilesystemMountPayload,
|
LocalFilesystemMountPayload,
|
||||||
NewChatMessageRead,
|
NewChatMessageRead,
|
||||||
NewChatRequest,
|
NewChatRequest,
|
||||||
|
|
@ -59,6 +67,7 @@ from app.schemas.new_chat import (
|
||||||
ThreadListItem,
|
ThreadListItem,
|
||||||
ThreadListResponse,
|
ThreadListResponse,
|
||||||
TokenUsageSummary,
|
TokenUsageSummary,
|
||||||
|
TurnStatusResponse,
|
||||||
)
|
)
|
||||||
from app.services.token_tracking_service import record_token_usage
|
from app.services.token_tracking_service import record_token_usage
|
||||||
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
|
||||||
|
|
@ -71,6 +80,9 @@ from app.utils.user_message_multimodal import (
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
_background_tasks: set[asyncio.Task] = set()
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
|
TURN_CANCELLING_INITIAL_DELAY_MS = 200
|
||||||
|
TURN_CANCELLING_BACKOFF_FACTOR = 2
|
||||||
|
TURN_CANCELLING_MAX_DELAY_MS = 1500
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -136,6 +148,326 @@ def _resolve_filesystem_selection(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
|
||||||
|
"""Bounded exponential delay for TURN_CANCELLING retry hints."""
|
||||||
|
if attempt < 1:
|
||||||
|
attempt = 1
|
||||||
|
delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
|
||||||
|
TURN_CANCELLING_BACKOFF_FACTOR ** (attempt - 1)
|
||||||
|
)
|
||||||
|
return min(delay, TURN_CANCELLING_MAX_DELAY_MS)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
|
||||||
|
lock = manager.lock_for(str(thread_id))
|
||||||
|
if not lock.locked():
|
||||||
|
return {"status": "idle"}
|
||||||
|
|
||||||
|
if is_cancel_requested(str(thread_id)):
|
||||||
|
cancel_state = get_cancel_state(str(thread_id))
|
||||||
|
attempt = cancel_state[0] if cancel_state else 1
|
||||||
|
retry_after_ms = _compute_turn_cancelling_retry_delay(attempt)
|
||||||
|
retry_after_at = int(datetime.now(UTC).timestamp() * 1000) + retry_after_ms
|
||||||
|
return {
|
||||||
|
"status": "cancelling",
|
||||||
|
"retry_after_ms": retry_after_ms,
|
||||||
|
"retry_after_at": retry_after_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"status": "busy"}
|
||||||
|
|
||||||
|
|
||||||
|
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
|
||||||
|
response.headers["retry-after-ms"] = str(retry_after_ms)
|
||||||
|
response.headers["Retry-After"] = str(max(1, (retry_after_ms + 999) // 1000))
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
|
||||||
|
status_payload = _build_turn_status_payload(thread_id)
|
||||||
|
status = status_payload["status"]
|
||||||
|
if status == "idle":
|
||||||
|
return
|
||||||
|
if status == "cancelling":
|
||||||
|
retry_after_ms = int(status_payload.get("retry_after_ms") or 0)
|
||||||
|
detail = {
|
||||||
|
"errorCode": "TURN_CANCELLING",
|
||||||
|
"message": "A previous response is still stopping. Please try again in a moment.",
|
||||||
|
"retry_after_ms": retry_after_ms if retry_after_ms > 0 else None,
|
||||||
|
"retry_after_at": status_payload.get("retry_after_at"),
|
||||||
|
}
|
||||||
|
headers = (
|
||||||
|
{
|
||||||
|
"retry-after-ms": str(retry_after_ms),
|
||||||
|
"Retry-After": str(max(1, (retry_after_ms + 999) // 1000)),
|
||||||
|
}
|
||||||
|
if retry_after_ms > 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=409, detail=detail, headers=headers)
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail={
|
||||||
|
"errorCode": "THREAD_BUSY",
|
||||||
|
"message": "Another response is still finishing for this thread. Please try again in a moment.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_pre_turn_checkpoint_id(
|
||||||
|
checkpoint_tuples: list,
|
||||||
|
*,
|
||||||
|
turn_id: str,
|
||||||
|
) -> str | None:
|
||||||
|
"""Locate the LangGraph checkpoint immediately before ``turn_id`` started.
|
||||||
|
|
||||||
|
``checkpoint_tuples`` arrives newest-first from
|
||||||
|
``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``)
|
||||||
|
and remember the most recent checkpoint that does NOT belong to the
|
||||||
|
edited turn. As soon as we cross into the edited turn (a checkpoint
|
||||||
|
whose ``turn_id`` matches), we return the previously-tracked
|
||||||
|
checkpoint — that's the state immediately before ``turn_id`` began.
|
||||||
|
|
||||||
|
The naive "newest-first, return first non-matching" approach is
|
||||||
|
INCORRECT when later turns exist after ``turn_id``: their
|
||||||
|
checkpoints also satisfy ``cp_turn_id != turn_id`` and would be
|
||||||
|
returned before the real pre-turn boundary is reached.
|
||||||
|
|
||||||
|
Reads from ``cp_tuple.metadata`` (the durable surface promoted from
|
||||||
|
``configurable`` at write time) rather than ``config["configurable"]``
|
||||||
|
so the lookup is portable across checkpointer implementations.
|
||||||
|
|
||||||
|
Returns ``None`` when no eligible pre-turn checkpoint exists (e.g.
|
||||||
|
the edited turn is the very first turn of the thread). Callers fall
|
||||||
|
back to the oldest available checkpoint in that case.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_pre_turn_target: str | None = None
|
||||||
|
for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest
|
||||||
|
metadata = getattr(cp_tuple, "metadata", None) or {}
|
||||||
|
cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None
|
||||||
|
if cp_turn_id == turn_id:
|
||||||
|
# Crossed into the edited turn; the previous tracked
|
||||||
|
# checkpoint is the rewind target. May be ``None`` if we hit
|
||||||
|
# the edited turn on the very first iteration.
|
||||||
|
return last_pre_turn_target
|
||||||
|
try:
|
||||||
|
last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"]
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
continue
|
||||||
|
return last_pre_turn_target
|
||||||
|
|
||||||
|
|
||||||
|
async def _revert_turns_for_regenerate(
|
||||||
|
*,
|
||||||
|
thread_id: int,
|
||||||
|
chat_turn_ids: list[str],
|
||||||
|
requester_user_id: str,
|
||||||
|
) -> dict:
|
||||||
|
"""Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``.
|
||||||
|
|
||||||
|
Runs BEFORE the regenerate stream so the frontend can surface
|
||||||
|
partial-rollback feedback alongside the new assistant turn. Each
|
||||||
|
turn's actions are reverted in their own SAVEPOINTs (handled
|
||||||
|
inside :mod:`app.routes.agent_revert_route`'s helpers) so a single
|
||||||
|
failure never poisons the batch.
|
||||||
|
|
||||||
|
Sequencing inside the request: revert THEN regenerate. The
|
||||||
|
operation is NOT atomic and partial state IS surfaced — see the
|
||||||
|
plan's "Sequencing inside the request" note.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.routes.agent_revert_route import (
|
||||||
|
RevertTurnActionResult,
|
||||||
|
_classify_outcome,
|
||||||
|
_OutcomeRollbackError,
|
||||||
|
_was_already_reverted,
|
||||||
|
_was_already_reverted_batch,
|
||||||
|
)
|
||||||
|
from app.services.revert_service import (
|
||||||
|
can_revert,
|
||||||
|
revert_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
aggregated_results: list[dict] = []
|
||||||
|
# Exhaustive counters keep the response invariant
|
||||||
|
# ``total == sum(counters)`` true for ``data-revert-results``.
|
||||||
|
counts = {
|
||||||
|
"reverted": 0,
|
||||||
|
"already_reverted": 0,
|
||||||
|
"not_reversible": 0,
|
||||||
|
"permission_denied": 0,
|
||||||
|
"failed": 0,
|
||||||
|
"skipped": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Local import keeps the route module's existing imports tidy and
|
||||||
|
# avoids a circular dependency at module-load time.
|
||||||
|
from app.db import AgentActionLog as _AgentActionLog
|
||||||
|
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
for chat_turn_id in chat_turn_ids:
|
||||||
|
rows_stmt = (
|
||||||
|
select(_AgentActionLog)
|
||||||
|
.where(
|
||||||
|
_AgentActionLog.thread_id == thread_id,
|
||||||
|
_AgentActionLog.chat_turn_id == chat_turn_id,
|
||||||
|
)
|
||||||
|
.order_by(
|
||||||
|
_AgentActionLog.created_at.desc(),
|
||||||
|
_AgentActionLog.id.desc(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows = (await session.execute(rows_stmt)).scalars().all()
|
||||||
|
|
||||||
|
# Batch idempotency probe across the turn (single SELECT
|
||||||
|
# instead of one per row).
|
||||||
|
eligible_ids = [r.id for r in rows if r.reverse_of is None]
|
||||||
|
already_reverted_map = await _was_already_reverted_batch(
|
||||||
|
session, action_ids=eligible_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
for action in rows:
|
||||||
|
if action.reverse_of is not None:
|
||||||
|
counts["skipped"] += 1
|
||||||
|
aggregated_results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="skipped",
|
||||||
|
message="Row is itself a revert action; skipped.",
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
existing_revert_id = already_reverted_map.get(action.id)
|
||||||
|
if existing_revert_id is not None:
|
||||||
|
counts["already_reverted"] += 1
|
||||||
|
aggregated_results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="already_reverted",
|
||||||
|
new_action_id=existing_revert_id,
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not can_revert(
|
||||||
|
requester_user_id=requester_user_id,
|
||||||
|
action=action,
|
||||||
|
is_admin=False,
|
||||||
|
):
|
||||||
|
counts["permission_denied"] += 1
|
||||||
|
aggregated_results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="permission_denied",
|
||||||
|
message="You are not allowed to revert this action.",
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with session.begin_nested():
|
||||||
|
outcome = await revert_action(
|
||||||
|
session,
|
||||||
|
action=action,
|
||||||
|
requester_user_id=requester_user_id,
|
||||||
|
)
|
||||||
|
if outcome.status != "ok":
|
||||||
|
raise _OutcomeRollbackError(outcome)
|
||||||
|
except _OutcomeRollbackError as rollback:
|
||||||
|
outcome = rollback.outcome
|
||||||
|
classified = _classify_outcome(outcome)
|
||||||
|
if classified == "permission_denied":
|
||||||
|
counts["permission_denied"] += 1
|
||||||
|
else:
|
||||||
|
counts["not_reversible"] += 1
|
||||||
|
aggregated_results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status=classified,
|
||||||
|
message=outcome.message,
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except IntegrityError:
|
||||||
|
# Concurrent revert won the race against the
|
||||||
|
# pre-flight ``_was_already_reverted`` SELECT.
|
||||||
|
# Surface the winning revert id so the client can
|
||||||
|
# treat this as a successful idempotent op.
|
||||||
|
existing_revert_id = await _was_already_reverted(
|
||||||
|
session, action_id=action.id
|
||||||
|
)
|
||||||
|
counts["already_reverted"] += 1
|
||||||
|
aggregated_results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="already_reverted",
|
||||||
|
new_action_id=existing_revert_id,
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
except Exception as err: # pragma: no cover — defensive
|
||||||
|
_logger.exception(
|
||||||
|
"Unexpected revert failure during regenerate batch "
|
||||||
|
"for action_id=%s",
|
||||||
|
action.id,
|
||||||
|
)
|
||||||
|
counts["failed"] += 1
|
||||||
|
aggregated_results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="failed",
|
||||||
|
error=str(err) or err.__class__.__name__,
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
counts["reverted"] += 1
|
||||||
|
aggregated_results.append(
|
||||||
|
RevertTurnActionResult(
|
||||||
|
action_id=action.id,
|
||||||
|
tool_name=action.tool_name,
|
||||||
|
status="reverted",
|
||||||
|
message=outcome.message,
|
||||||
|
new_action_id=outcome.new_action_id,
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
_logger.exception(
|
||||||
|
"[regenerate-revert] Final commit failed; rolling back batch."
|
||||||
|
)
|
||||||
|
await session.rollback()
|
||||||
|
|
||||||
|
has_partial = (
|
||||||
|
counts["failed"] > 0
|
||||||
|
or counts["not_reversible"] > 0
|
||||||
|
or counts["permission_denied"] > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "partial" if has_partial else "ok",
|
||||||
|
"chat_turn_ids": chat_turn_ids,
|
||||||
|
"total": len(aggregated_results),
|
||||||
|
"reverted": counts["reverted"],
|
||||||
|
"already_reverted": counts["already_reverted"],
|
||||||
|
"not_reversible": counts["not_reversible"],
|
||||||
|
"permission_denied": counts["permission_denied"],
|
||||||
|
"failed": counts["failed"],
|
||||||
|
"skipped": counts["skipped"],
|
||||||
|
"results": aggregated_results,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _try_delete_sandbox(thread_id: int) -> None:
|
def _try_delete_sandbox(thread_id: int) -> None:
|
||||||
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
"""Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
|
||||||
from app.agents.new_chat.sandbox import (
|
from app.agents.new_chat.sandbox import (
|
||||||
|
|
@ -574,6 +906,7 @@ async def get_thread_messages(
|
||||||
token_usage=TokenUsageSummary.model_validate(msg.token_usage)
|
token_usage=TokenUsageSummary.model_validate(msg.token_usage)
|
||||||
if msg.token_usage
|
if msg.token_usage
|
||||||
else None,
|
else None,
|
||||||
|
turn_id=msg.turn_id,
|
||||||
)
|
)
|
||||||
for msg in db_messages
|
for msg in db_messages
|
||||||
]
|
]
|
||||||
|
|
@ -1006,12 +1339,24 @@ async def append_message(
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
|
||||||
# Create message
|
# Create message. ``turn_id`` is the per-turn correlation id from
|
||||||
|
# ``configurable.turn_id`` (added in migration 136) — when the
|
||||||
|
# client streams it back to ``appendMessage``, we persist it so
|
||||||
|
# C1's edit-from-arbitrary-position can later map this message
|
||||||
|
# back to the LangGraph checkpoint that produced its turn.
|
||||||
|
raw_turn_id = raw_body.get("turn_id")
|
||||||
|
turn_id_value = (
|
||||||
|
str(raw_turn_id).strip()
|
||||||
|
if isinstance(raw_turn_id, str) and raw_turn_id.strip()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
db_message = NewChatMessage(
|
db_message = NewChatMessage(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
role=message_role,
|
role=message_role,
|
||||||
content=content,
|
content=content,
|
||||||
author_id=user.id,
|
author_id=user.id,
|
||||||
|
turn_id=turn_id_value,
|
||||||
)
|
)
|
||||||
session.add(db_message)
|
session.add(db_message)
|
||||||
|
|
||||||
|
|
@ -1050,6 +1395,7 @@ async def append_message(
|
||||||
created_at=db_message.created_at,
|
created_at=db_message.created_at,
|
||||||
author_id=db_message.author_id,
|
author_id=db_message.author_id,
|
||||||
token_usage=None,
|
token_usage=None,
|
||||||
|
turn_id=db_message.turn_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|
@ -1207,6 +1553,7 @@ async def handle_new_chat(
|
||||||
|
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(request.chat_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
@ -1281,6 +1628,93 @@ async def handle_new_chat(
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/threads/{thread_id}/cancel-active-turn",
|
||||||
|
response_model=CancelActiveTurnResponse,
|
||||||
|
)
|
||||||
|
async def cancel_active_turn(
|
||||||
|
thread_id: int,
|
||||||
|
response: Response,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""Signal cancellation for the currently running turn on ``thread_id``."""
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_UPDATE.value,
|
||||||
|
"You don't have permission to update chats in this search space",
|
||||||
|
)
|
||||||
|
await check_thread_access(session, thread, user)
|
||||||
|
|
||||||
|
status_payload = _build_turn_status_payload(thread_id)
|
||||||
|
if status_payload["status"] == "idle":
|
||||||
|
return CancelActiveTurnResponse(
|
||||||
|
status="idle",
|
||||||
|
error_code="NO_ACTIVE_TURN",
|
||||||
|
)
|
||||||
|
|
||||||
|
request_cancel(str(thread_id))
|
||||||
|
response.status_code = 202
|
||||||
|
updated_payload = _build_turn_status_payload(thread_id)
|
||||||
|
retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
|
||||||
|
retry_after_at = (
|
||||||
|
int(updated_payload["retry_after_at"])
|
||||||
|
if "retry_after_at" in updated_payload
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if retry_after_ms > 0:
|
||||||
|
_set_retry_after_headers(response, retry_after_ms)
|
||||||
|
return CancelActiveTurnResponse(
|
||||||
|
status="cancelling",
|
||||||
|
error_code="TURN_CANCELLING",
|
||||||
|
retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
|
||||||
|
retry_after_at=retry_after_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/threads/{thread_id}/turn-status",
|
||||||
|
response_model=TurnStatusResponse,
|
||||||
|
)
|
||||||
|
async def get_turn_status(
|
||||||
|
thread_id: int,
|
||||||
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
result = await session.execute(
|
||||||
|
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||||
|
)
|
||||||
|
thread = result.scalars().first()
|
||||||
|
if not thread:
|
||||||
|
raise HTTPException(status_code=404, detail="Thread not found")
|
||||||
|
|
||||||
|
await check_permission(
|
||||||
|
session,
|
||||||
|
user,
|
||||||
|
thread.search_space_id,
|
||||||
|
Permission.CHATS_READ.value,
|
||||||
|
"You don't have permission to view chats in this search space",
|
||||||
|
)
|
||||||
|
await check_thread_access(session, thread, user)
|
||||||
|
|
||||||
|
status_payload = _build_turn_status_payload(thread_id)
|
||||||
|
return TurnStatusResponse(
|
||||||
|
status=status_payload["status"], # type: ignore[arg-type]
|
||||||
|
active_turn_id=None,
|
||||||
|
retry_after_ms=status_payload.get("retry_after_ms"), # type: ignore[arg-type]
|
||||||
|
retry_after_at=status_payload.get("retry_after_at"), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Chat Regeneration Endpoint (Edit/Reload)
|
# Chat Regeneration Endpoint (Edit/Reload)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -1336,6 +1770,7 @@ async def regenerate_response(
|
||||||
|
|
||||||
# Check thread-level access based on visibility
|
# Check thread-level access based on visibility
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(thread_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
@ -1373,8 +1808,85 @@ async def regenerate_response(
|
||||||
user_query_to_use = request.user_query
|
user_query_to_use = request.user_query
|
||||||
regenerate_image_urls: list[str] = []
|
regenerate_image_urls: list[str] = []
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
# Edit-from-arbitrary-position. When the client passes
|
||||||
|
# ``from_message_id`` we look up its persisted ``turn_id`` (added
|
||||||
|
# in migration 136) and pick the checkpoint immediately before
|
||||||
|
# that turn started.
|
||||||
|
#
|
||||||
|
# Legacy graceful-degradation contract:
|
||||||
|
# * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``.
|
||||||
|
# Returning 400 in that case is the wrong UX — the user is
|
||||||
|
# editing an old message in an existing thread and just wants
|
||||||
|
# it to work. We instead skip the checkpoint rewind (the
|
||||||
|
# stream falls back to the latest state) and skip the revert
|
||||||
|
# pass (no chat_turn_id available to walk). Deletion still
|
||||||
|
# uses ``created_at``, so the messages-after-cursor slice is
|
||||||
|
# correct on both legacy and post-136 rows.
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
from_message_turn_id: str | None = None
|
||||||
|
from_message_created_at: datetime | None = None
|
||||||
|
legacy_from_message: bool = False
|
||||||
|
if request.from_message_id is not None:
|
||||||
|
from_msg_row = await session.execute(
|
||||||
|
select(NewChatMessage).filter(
|
||||||
|
NewChatMessage.id == request.from_message_id,
|
||||||
|
NewChatMessage.thread_id == thread_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
from_msg = from_msg_row.scalars().first()
|
||||||
|
if from_msg is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="from_message_id not found in this thread.",
|
||||||
|
)
|
||||||
|
from_message_created_at = from_msg.created_at
|
||||||
|
if not from_msg.turn_id:
|
||||||
|
# Legacy row — surface the degradation in logs but let
|
||||||
|
# the request proceed with the slice-based delete and a
|
||||||
|
# cold-start checkpoint.
|
||||||
|
legacy_from_message = True
|
||||||
|
_logger.warning(
|
||||||
|
"[regenerate] from_message_id=%s on thread=%s has no "
|
||||||
|
"turn_id (legacy row pre-migration-136). Falling back "
|
||||||
|
"to slice-based delete without checkpoint rewind. "
|
||||||
|
"revert_actions=%s will be ignored.",
|
||||||
|
request.from_message_id,
|
||||||
|
thread_id,
|
||||||
|
request.revert_actions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from_message_turn_id = from_msg.turn_id
|
||||||
|
|
||||||
|
# Walk oldest-to-newest and pick the LAST checkpoint whose
|
||||||
|
# ``turn_id`` differs from the edited turn — that's the state
|
||||||
|
# immediately before this turn started running. We read from
|
||||||
|
# ``metadata`` (the durable surface) rather than
|
||||||
|
# ``config["configurable"]`` so the lookup works across
|
||||||
|
# checkpointer implementations.
|
||||||
|
target_checkpoint_id = _find_pre_turn_checkpoint_id(
|
||||||
|
checkpoint_tuples,
|
||||||
|
turn_id=from_message_turn_id,
|
||||||
|
)
|
||||||
|
if target_checkpoint_id is None and len(checkpoint_tuples) > 0:
|
||||||
|
# Fall back to the oldest checkpoint — better than
|
||||||
|
# 400ing when the agent didn't checkpoint pre-turn
|
||||||
|
# (e.g. very first turn of the thread).
|
||||||
|
target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
|
||||||
|
"checkpoint_id"
|
||||||
|
]
|
||||||
|
|
||||||
# Look through checkpoints to find the right one
|
# Look through checkpoints to find the right one
|
||||||
# We want to find the checkpoint just before the last HumanMessage
|
# We want to find the checkpoint just before the last HumanMessage.
|
||||||
|
# We enter this branch when:
|
||||||
|
# * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR
|
||||||
|
# * the client pinned ``from_message_id`` but the row is a
|
||||||
|
# legacy pre-migration-136 row with no ``turn_id`` (we
|
||||||
|
# downgraded to the same heuristic as a regular reload).
|
||||||
|
# We DO skip it when a real turn_id pinned ``target_checkpoint_id``
|
||||||
|
# — that's the C1 happy path and the heuristic below would just
|
||||||
|
# re-derive a worse target.
|
||||||
|
if request.from_message_id is None or legacy_from_message:
|
||||||
for i, cp_tuple in enumerate(checkpoint_tuples):
|
for i, cp_tuple in enumerate(checkpoint_tuples):
|
||||||
# Access the checkpoint's channel_values which contains "messages"
|
# Access the checkpoint's channel_values which contains "messages"
|
||||||
checkpoint_data = cp_tuple.checkpoint
|
checkpoint_data = cp_tuple.checkpoint
|
||||||
|
|
@ -1397,12 +1909,15 @@ async def regenerate_response(
|
||||||
prev_messages = prev_channel_values.get("messages", [])
|
prev_messages = prev_channel_values.get("messages", [])
|
||||||
for msg in reversed(prev_messages):
|
for msg in reversed(prev_messages):
|
||||||
if isinstance(msg, HumanMessage):
|
if isinstance(msg, HumanMessage):
|
||||||
q, imgs = split_langchain_human_content(msg.content)
|
q, imgs = split_langchain_human_content(
|
||||||
|
msg.content
|
||||||
|
)
|
||||||
user_query_to_use = q
|
user_query_to_use = q
|
||||||
regenerate_image_urls = imgs
|
regenerate_image_urls = imgs
|
||||||
break
|
break
|
||||||
if user_query_to_use is not None and (
|
if user_query_to_use is not None and (
|
||||||
str(user_query_to_use).strip() or regenerate_image_urls
|
str(user_query_to_use).strip()
|
||||||
|
or regenerate_image_urls
|
||||||
):
|
):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
@ -1472,8 +1987,23 @@ async def regenerate_response(
|
||||||
detail="Could not determine user query for regeneration. Please provide a user_query.",
|
detail="Could not determine user query for regeneration. Please provide a user_query.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the last two messages to delete AFTER streaming succeeds
|
# Get the messages to delete AFTER streaming succeeds.
|
||||||
# This prevents data loss if streaming fails
|
# This prevents data loss if streaming fails.
|
||||||
|
#
|
||||||
|
# When ``from_message_id`` is set we slice from that message
|
||||||
|
# forward (using ``created_at`` so we also catch any tool/system
|
||||||
|
# messages persisted into the same turn). Otherwise
|
||||||
|
# we keep the legacy "last 2 messages" rewind.
|
||||||
|
if request.from_message_id is not None and from_message_created_at is not None:
|
||||||
|
last_messages_result = await session.execute(
|
||||||
|
select(NewChatMessage)
|
||||||
|
.filter(
|
||||||
|
NewChatMessage.thread_id == thread_id,
|
||||||
|
NewChatMessage.created_at >= from_message_created_at,
|
||||||
|
)
|
||||||
|
.order_by(NewChatMessage.created_at.desc())
|
||||||
|
)
|
||||||
|
else:
|
||||||
last_messages_result = await session.execute(
|
last_messages_result = await session.execute(
|
||||||
select(NewChatMessage)
|
select(NewChatMessage)
|
||||||
.filter(NewChatMessage.thread_id == thread_id)
|
.filter(NewChatMessage.thread_id == thread_id)
|
||||||
|
|
@ -1484,6 +2014,24 @@ async def regenerate_response(
|
||||||
|
|
||||||
message_ids_to_delete = [msg.id for msg in messages_to_delete]
|
message_ids_to_delete = [msg.id for msg in messages_to_delete]
|
||||||
|
|
||||||
|
# When revert_actions is requested, collect the set of
|
||||||
|
# ``chat_turn_id``s present in the slice we're about to delete.
|
||||||
|
# Each one will be reverted (best-effort) BEFORE the regenerate
|
||||||
|
# stream begins. Legacy rows have ``turn_id=None`` and silently
|
||||||
|
# contribute nothing — we already logged the degradation above.
|
||||||
|
revert_turn_ids: list[str] = []
|
||||||
|
if (
|
||||||
|
request.revert_actions
|
||||||
|
and request.from_message_id is not None
|
||||||
|
and not legacy_from_message
|
||||||
|
):
|
||||||
|
seen_turns: set[str] = set()
|
||||||
|
for msg in messages_to_delete:
|
||||||
|
tid = msg.turn_id
|
||||||
|
if tid and tid not in seen_turns:
|
||||||
|
seen_turns.add(tid)
|
||||||
|
revert_turn_ids.append(tid)
|
||||||
|
|
||||||
# Get search space for LLM config
|
# Get search space for LLM config
|
||||||
search_space_result = await session.execute(
|
search_space_result = await session.execute(
|
||||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||||
|
|
@ -1507,6 +2055,24 @@ async def regenerate_response(
|
||||||
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
||||||
async def stream_with_cleanup():
|
async def stream_with_cleanup():
|
||||||
streaming_completed = False
|
streaming_completed = False
|
||||||
|
# Best-effort revert pass BEFORE the regenerate stream begins.
|
||||||
|
# Each turn is reverted independently (per-row SAVEPOINTs
|
||||||
|
# inside the route helper) and the per-action results are surfaced
|
||||||
|
# on a single ``data-revert-results`` SSE event so the frontend
|
||||||
|
# can render any failed rows alongside the new turn. Failures here
|
||||||
|
# do NOT abort the regeneration — partial rollback is documented
|
||||||
|
# behaviour.
|
||||||
|
if revert_turn_ids:
|
||||||
|
revert_results = await _revert_turns_for_regenerate(
|
||||||
|
thread_id=thread_id,
|
||||||
|
chat_turn_ids=revert_turn_ids,
|
||||||
|
requester_user_id=str(user.id),
|
||||||
|
)
|
||||||
|
envelope = {
|
||||||
|
"type": "data-revert-results",
|
||||||
|
"data": revert_results,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
|
||||||
try:
|
try:
|
||||||
async for chunk in stream_new_chat(
|
async for chunk in stream_new_chat(
|
||||||
user_query=str(user_query_to_use),
|
user_query=str(user_query_to_use),
|
||||||
|
|
@ -1524,6 +2090,7 @@ async def regenerate_response(
|
||||||
filesystem_selection=filesystem_selection,
|
filesystem_selection=filesystem_selection,
|
||||||
request_id=getattr(http_request.state, "request_id", "unknown"),
|
request_id=getattr(http_request.state, "request_id", "unknown"),
|
||||||
user_image_data_urls=regenerate_image_urls or None,
|
user_image_data_urls=regenerate_image_urls or None,
|
||||||
|
flow="regenerate",
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
streaming_completed = True
|
streaming_completed = True
|
||||||
|
|
@ -1611,6 +2178,7 @@ async def resume_chat(
|
||||||
)
|
)
|
||||||
|
|
||||||
await check_thread_access(session, thread, user)
|
await check_thread_access(session, thread, user)
|
||||||
|
_raise_if_thread_busy_for_start(thread_id)
|
||||||
filesystem_selection = _resolve_filesystem_selection(
|
filesystem_selection = _resolve_filesystem_selection(
|
||||||
mode=request.filesystem_mode,
|
mode=request.filesystem_mode,
|
||||||
client_platform=request.client_platform,
|
client_platform=request.client_platform,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
|
@ -15,6 +15,7 @@ from app.agents.new_chat.tools.update_memory import MEMORY_HARD_LIMIT, _save_mem
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ImageGenerationConfig,
|
ImageGenerationConfig,
|
||||||
|
NewChatThread,
|
||||||
NewLLMConfig,
|
NewLLMConfig,
|
||||||
Permission,
|
Permission,
|
||||||
SearchSpace,
|
SearchSpace,
|
||||||
|
|
@ -790,9 +791,27 @@ async def update_llm_preferences(
|
||||||
|
|
||||||
# Update preferences
|
# Update preferences
|
||||||
update_data = preferences.model_dump(exclude_unset=True)
|
update_data = preferences.model_dump(exclude_unset=True)
|
||||||
|
previous_agent_llm_id = search_space.agent_llm_id
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
setattr(search_space, key, value)
|
setattr(search_space, key, value)
|
||||||
|
|
||||||
|
agent_llm_changed = (
|
||||||
|
"agent_llm_id" in update_data
|
||||||
|
and update_data["agent_llm_id"] != previous_agent_llm_id
|
||||||
|
)
|
||||||
|
if agent_llm_changed:
|
||||||
|
await session.execute(
|
||||||
|
update(NewChatThread)
|
||||||
|
.where(NewChatThread.search_space_id == search_space_id)
|
||||||
|
.values(pinned_llm_config_id=None)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Cleared auto model pins for search_space_id=%s after agent_llm_id change (%s -> %s)",
|
||||||
|
search_space_id,
|
||||||
|
previous_agent_llm_id,
|
||||||
|
update_data["agent_llm_id"],
|
||||||
|
)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(search_space)
|
await session.refresh(search_space)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,11 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
||||||
author_display_name: str | None = None
|
author_display_name: str | None = None
|
||||||
author_avatar_url: str | None = None
|
author_avatar_url: str | None = None
|
||||||
token_usage: TokenUsageSummary | None = None
|
token_usage: TokenUsageSummary | None = None
|
||||||
|
# Per-turn correlation id (``f"{chat_id}:{ms}"``) from
|
||||||
|
# ``configurable.turn_id`` at streaming time. Nullable because
|
||||||
|
# legacy rows predate the column; clients should treat NULL as
|
||||||
|
# "edit-from-this-message is unavailable".
|
||||||
|
turn_id: str | None = None
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -241,6 +246,15 @@ class RegenerateRequest(BaseModel):
|
||||||
|
|
||||||
For edit, optional user_images (when not None) replaces image URLs resolved from
|
For edit, optional user_images (when not None) replaces image URLs resolved from
|
||||||
checkpoint/DB so the client can send the full user turn (text and/or images).
|
checkpoint/DB so the client can send the full user turn (text and/or images).
|
||||||
|
|
||||||
|
Edit-from-arbitrary-position. When ``from_message_id`` is provided
|
||||||
|
the route slices conversation history starting at that message (instead of
|
||||||
|
the legacy "last 2 messages" rewind), rewinds the LangGraph checkpoint by
|
||||||
|
matching ``configurable.turn_id`` stored on the message (added in migration 136), and
|
||||||
|
optionally reverts every reversible action emitted in turns at or after
|
||||||
|
``from_message_id``. The revert step is best-effort and runs BEFORE the
|
||||||
|
regenerate stream — partial failures are surfaced via SSE
|
||||||
|
``data-revert-results`` and do not abort the regeneration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
search_space_id: int
|
search_space_id: int
|
||||||
|
|
@ -257,6 +271,28 @@ class RegenerateRequest(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB",
|
description="If set, use these images for the regenerated turn (edit); overrides checkpoint/DB",
|
||||||
)
|
)
|
||||||
|
from_message_id: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Message id to rewind to. When set, history is sliced "
|
||||||
|
"from this message forward and the LangGraph checkpoint is "
|
||||||
|
"rewound to the state immediately preceding this turn. Legacy "
|
||||||
|
"rows that predate migration 136 have ``turn_id=None`` and "
|
||||||
|
"still process — the route logs a warning, skips the "
|
||||||
|
"checkpoint rewind, and ignores ``revert_actions`` (no "
|
||||||
|
"chat_turn_id available to walk)."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
revert_actions: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"When true, every reversible action emitted at or "
|
||||||
|
"after ``from_message_id`` is reverted before the regenerate "
|
||||||
|
"stream begins. Per-action results are surfaced via the "
|
||||||
|
"``data-revert-results`` SSE event. Partial failures DO NOT "
|
||||||
|
"abort the regeneration."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_regenerate_user_images(self) -> Self:
|
def _validate_regenerate_user_images(self) -> Self:
|
||||||
|
|
@ -264,6 +300,14 @@ class RegenerateRequest(BaseModel):
|
||||||
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
|
raise ValueError(f"At most {MAX_NEW_CHAT_IMAGES} images allowed")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _validate_revert_actions_requires_from_message(self) -> Self:
|
||||||
|
if self.revert_actions and self.from_message_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"revert_actions requires from_message_id; specify which message to rewind to"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Agent Tools Schemas
|
# Agent Tools Schemas
|
||||||
|
|
@ -291,6 +335,24 @@ class ResumeRequest(BaseModel):
|
||||||
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
local_filesystem_mounts: list[LocalFilesystemMountPayload] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CancelActiveTurnResponse(BaseModel):
|
||||||
|
"""Response for canceling an active turn on a chat thread."""
|
||||||
|
|
||||||
|
status: Literal["cancelling", "idle"]
|
||||||
|
error_code: Literal["TURN_CANCELLING", "NO_ACTIVE_TURN"]
|
||||||
|
retry_after_ms: int | None = None
|
||||||
|
retry_after_at: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TurnStatusResponse(BaseModel):
|
||||||
|
"""Current turn execution status for a thread."""
|
||||||
|
|
||||||
|
status: Literal["idle", "busy", "cancelling"]
|
||||||
|
active_turn_id: str | None = None
|
||||||
|
retry_after_ms: int | None = None
|
||||||
|
retry_after_at: int | None = None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Public Chat Snapshot Schemas
|
# Public Chat Snapshot Schemas
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
385
surfsense_backend/app/services/auto_model_pin_service.py
Normal file
385
surfsense_backend/app/services/auto_model_pin_service.py
Normal file
|
|
@ -0,0 +1,385 @@
|
||||||
|
"""Resolve and persist Auto (Fastest) model pins per chat thread.
|
||||||
|
|
||||||
|
Auto (Fastest) is represented by ``agent_llm_id == 0``. For chat threads we
|
||||||
|
resolve that virtual mode to one concrete global LLM config exactly once and
|
||||||
|
persist the chosen config id on ``new_chat_threads.pinned_llm_config_id`` so
|
||||||
|
subsequent turns are stable.
|
||||||
|
|
||||||
|
Single-writer invariant: this module is the only writer of
|
||||||
|
``NewChatThread.pinned_llm_config_id`` (aside from the bulk clear in
|
||||||
|
``search_spaces_routes`` when a search space's ``agent_llm_id`` changes).
|
||||||
|
Therefore a non-NULL value unambiguously means "this thread has an
|
||||||
|
Auto-resolved pin"; no separate source/policy column is needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import NewChatThread
|
||||||
|
from app.services.quality_score import _QUALITY_TOP_K
|
||||||
|
from app.services.token_quota_service import TokenQuotaService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
AUTO_FASTEST_ID = 0
|
||||||
|
AUTO_FASTEST_MODE = "auto_fastest"
|
||||||
|
_RUNTIME_COOLDOWN_SECONDS = 600
|
||||||
|
_HEALTHY_TTL_SECONDS = 45
|
||||||
|
|
||||||
|
# In-memory runtime cooldown map for configs that recently hard-failed at
|
||||||
|
# provider runtime (e.g. OpenRouter 429 on a pinned free model). This keeps
|
||||||
|
# the same unhealthy config from being reselected immediately during repair.
|
||||||
|
_runtime_cooldown_until: dict[int, float] = {}
|
||||||
|
_runtime_cooldown_lock = threading.Lock()
|
||||||
|
|
||||||
|
# Short-TTL "recently healthy" cache for configs that just passed a runtime
|
||||||
|
# preflight ping. Lets back-to-back turns on the same model skip the probe
|
||||||
|
# without eroding correctness — entries auto-expire and are wiped any time
|
||||||
|
# the same config is cooled down or the OR catalogue is refreshed.
|
||||||
|
_healthy_until: dict[int, float] = {}
|
||||||
|
_healthy_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AutoPinResolution:
|
||||||
|
resolved_llm_config_id: int
|
||||||
|
resolved_tier: str
|
||||||
|
from_existing_pin: bool
|
||||||
|
|
||||||
|
|
||||||
|
def _is_usable_global_config(cfg: dict) -> bool:
|
||||||
|
return bool(
|
||||||
|
cfg.get("id") is not None
|
||||||
|
and cfg.get("model_name")
|
||||||
|
and cfg.get("provider")
|
||||||
|
and cfg.get("api_key")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_runtime_cooldowns(now_ts: float | None = None) -> None:
|
||||||
|
now = time.time() if now_ts is None else now_ts
|
||||||
|
stale = [cid for cid, until in _runtime_cooldown_until.items() if until <= now]
|
||||||
|
for cid in stale:
|
||||||
|
_runtime_cooldown_until.pop(cid, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_runtime_cooled_down(config_id: int) -> bool:
|
||||||
|
with _runtime_cooldown_lock:
|
||||||
|
_prune_runtime_cooldowns()
|
||||||
|
return config_id in _runtime_cooldown_until
|
||||||
|
|
||||||
|
|
||||||
|
def mark_runtime_cooldown(
|
||||||
|
config_id: int,
|
||||||
|
*,
|
||||||
|
reason: str = "rate_limited",
|
||||||
|
cooldown_seconds: int = _RUNTIME_COOLDOWN_SECONDS,
|
||||||
|
) -> None:
|
||||||
|
"""Temporarily suppress a config from Auto selection.
|
||||||
|
|
||||||
|
Used by runtime error handlers (e.g. OpenRouter 429) so an already pinned
|
||||||
|
config that is currently unhealthy does not get immediately reused on the
|
||||||
|
same thread during repair.
|
||||||
|
"""
|
||||||
|
if cooldown_seconds <= 0:
|
||||||
|
cooldown_seconds = _RUNTIME_COOLDOWN_SECONDS
|
||||||
|
until = time.time() + int(cooldown_seconds)
|
||||||
|
with _runtime_cooldown_lock:
|
||||||
|
_runtime_cooldown_until[int(config_id)] = until
|
||||||
|
_prune_runtime_cooldowns()
|
||||||
|
# A cooled cfg can never be "recently healthy"; drop any stale credit so
|
||||||
|
# the next turn that resolves to it (after cooldown) re-runs preflight.
|
||||||
|
clear_healthy(int(config_id))
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_runtime_cooled_down config_id=%s reason=%s cooldown_seconds=%s",
|
||||||
|
config_id,
|
||||||
|
reason,
|
||||||
|
cooldown_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_runtime_cooldown(config_id: int | None = None) -> None:
|
||||||
|
"""Test/ops helper to clear runtime cooldown entries."""
|
||||||
|
with _runtime_cooldown_lock:
|
||||||
|
if config_id is None:
|
||||||
|
_runtime_cooldown_until.clear()
|
||||||
|
return
|
||||||
|
_runtime_cooldown_until.pop(int(config_id), None)
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_healthy(now_ts: float | None = None) -> None:
|
||||||
|
now = time.time() if now_ts is None else now_ts
|
||||||
|
stale = [cid for cid, until in _healthy_until.items() if until <= now]
|
||||||
|
for cid in stale:
|
||||||
|
_healthy_until.pop(cid, None)
|
||||||
|
|
||||||
|
|
||||||
|
def is_recently_healthy(config_id: int) -> bool:
|
||||||
|
"""Return True if ``config_id`` passed preflight within the TTL window."""
|
||||||
|
with _healthy_lock:
|
||||||
|
_prune_healthy()
|
||||||
|
return int(config_id) in _healthy_until
|
||||||
|
|
||||||
|
|
||||||
|
def mark_healthy(
|
||||||
|
config_id: int,
|
||||||
|
*,
|
||||||
|
ttl_seconds: int = _HEALTHY_TTL_SECONDS,
|
||||||
|
) -> None:
|
||||||
|
"""Record that ``config_id`` just passed a preflight probe.
|
||||||
|
|
||||||
|
Subsequent calls within ``ttl_seconds`` can skip the preflight ping. The
|
||||||
|
healthy state is intentionally process-local — it's a latency hint, not a
|
||||||
|
correctness primitive — so multi-worker drift is acceptable.
|
||||||
|
"""
|
||||||
|
if ttl_seconds <= 0:
|
||||||
|
ttl_seconds = _HEALTHY_TTL_SECONDS
|
||||||
|
until = time.time() + int(ttl_seconds)
|
||||||
|
with _healthy_lock:
|
||||||
|
_healthy_until[int(config_id)] = until
|
||||||
|
_prune_healthy()
|
||||||
|
|
||||||
|
|
||||||
|
def clear_healthy(config_id: int | None = None) -> None:
|
||||||
|
"""Drop one (or all) healthy-cache entries.
|
||||||
|
|
||||||
|
Called from runtime cooldown and OR catalogue refresh so a freshly cooled
|
||||||
|
or replaced config never carries stale "healthy" credit.
|
||||||
|
"""
|
||||||
|
with _healthy_lock:
|
||||||
|
if config_id is None:
|
||||||
|
_healthy_until.clear()
|
||||||
|
return
|
||||||
|
_healthy_until.pop(int(config_id), None)
|
||||||
|
|
||||||
|
|
||||||
|
def _global_candidates() -> list[dict]:
|
||||||
|
"""Return Auto-eligible global cfgs.
|
||||||
|
|
||||||
|
Drops cfgs flagged ``health_gated`` (best non-null OpenRouter uptime
|
||||||
|
below ``_HEALTH_GATE_UPTIME_PCT``) so chronically broken providers
|
||||||
|
can't be picked as the thread's pin. Also excludes configs currently
|
||||||
|
in runtime cooldown (e.g. temporary 429 bursts).
|
||||||
|
"""
|
||||||
|
candidates = [
|
||||||
|
cfg
|
||||||
|
for cfg in config.GLOBAL_LLM_CONFIGS
|
||||||
|
if _is_usable_global_config(cfg)
|
||||||
|
and not cfg.get("health_gated")
|
||||||
|
and not _is_runtime_cooled_down(int(cfg.get("id", 0)))
|
||||||
|
]
|
||||||
|
return sorted(candidates, key=lambda c: int(c.get("id", 0)))
|
||||||
|
|
||||||
|
|
||||||
|
def _tier_of(cfg: dict) -> str:
|
||||||
|
return str(cfg.get("billing_tier", "free")).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _select_pin(eligible: list[dict], thread_id: int) -> tuple[dict, int]:
|
||||||
|
"""Pick a config with quality-first ranking + deterministic spread.
|
||||||
|
|
||||||
|
Tier policy is lock-first: prefer Tier A (operator-curated YAML)
|
||||||
|
cfgs and only fall through to Tier B/C (dynamic OpenRouter) if no
|
||||||
|
Tier A cfg is eligible after upstream filters. Within the locked
|
||||||
|
pool, sort by ``quality_score`` and pick from the top-K via
|
||||||
|
``SHA256(thread_id)`` so different new threads spread across the
|
||||||
|
best models without ever picking a low-ranked one.
|
||||||
|
|
||||||
|
Returns ``(chosen_cfg, top_k_size)``. ``top_k_size`` is exposed for
|
||||||
|
structured logging in the caller.
|
||||||
|
"""
|
||||||
|
tier_a = [c for c in eligible if c.get("auto_pin_tier") in (None, "A")]
|
||||||
|
pool = tier_a if tier_a else eligible
|
||||||
|
pool = sorted(pool, key=lambda c: -int(c.get("quality_score") or 0))
|
||||||
|
top_k = pool[:_QUALITY_TOP_K]
|
||||||
|
digest = hashlib.sha256(f"{AUTO_FASTEST_MODE}:{thread_id}".encode()).digest()
|
||||||
|
idx = int.from_bytes(digest[:8], "big") % len(top_k)
|
||||||
|
return top_k[idx], len(top_k)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_uuid(user_id: str | UUID | None) -> UUID | None:
|
||||||
|
if user_id is None:
|
||||||
|
return None
|
||||||
|
if isinstance(user_id, UUID):
|
||||||
|
return user_id
|
||||||
|
try:
|
||||||
|
return UUID(str(user_id))
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _is_premium_eligible(
|
||||||
|
session: AsyncSession, user_id: str | UUID | None
|
||||||
|
) -> bool:
|
||||||
|
parsed = _to_uuid(user_id)
|
||||||
|
if parsed is None:
|
||||||
|
return False
|
||||||
|
usage = await TokenQuotaService.premium_get_usage(session, parsed)
|
||||||
|
return bool(usage.allowed)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_or_get_pinned_llm_config_id(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
thread_id: int,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str | UUID | None,
|
||||||
|
selected_llm_config_id: int,
|
||||||
|
force_repin_free: bool = False,
|
||||||
|
exclude_config_ids: set[int] | None = None,
|
||||||
|
) -> AutoPinResolution:
|
||||||
|
"""Resolve Auto (Fastest) to one concrete config id and persist the pin.
|
||||||
|
|
||||||
|
For non-auto selections, this function clears any existing pin and returns
|
||||||
|
the selected id as-is.
|
||||||
|
"""
|
||||||
|
thread = (
|
||||||
|
(
|
||||||
|
await session.execute(
|
||||||
|
select(NewChatThread)
|
||||||
|
.where(NewChatThread.id == thread_id)
|
||||||
|
.with_for_update(of=NewChatThread)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.unique()
|
||||||
|
.scalar_one_or_none()
|
||||||
|
)
|
||||||
|
if thread is None:
|
||||||
|
raise ValueError(f"Thread {thread_id} not found")
|
||||||
|
if thread.search_space_id != search_space_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"Thread {thread_id} does not belong to search space {search_space_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Explicit model selected: clear any stale pin.
|
||||||
|
if selected_llm_config_id != AUTO_FASTEST_ID:
|
||||||
|
if thread.pinned_llm_config_id is not None:
|
||||||
|
thread.pinned_llm_config_id = None
|
||||||
|
await session.commit()
|
||||||
|
return AutoPinResolution(
|
||||||
|
resolved_llm_config_id=selected_llm_config_id,
|
||||||
|
resolved_tier="explicit",
|
||||||
|
from_existing_pin=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
excluded_ids = {int(cid) for cid in (exclude_config_ids or set())}
|
||||||
|
candidates = [
|
||||||
|
c for c in _global_candidates() if int(c.get("id", 0)) not in excluded_ids
|
||||||
|
]
|
||||||
|
if not candidates:
|
||||||
|
raise ValueError("No usable global LLM configs are available for Auto mode")
|
||||||
|
candidate_by_id = {int(c["id"]): c for c in candidates}
|
||||||
|
|
||||||
|
# Reuse an existing valid pin without re-checking current quota (no silent
|
||||||
|
# tier switch), unless the caller explicitly requests a forced repin to free.
|
||||||
|
pinned_id = thread.pinned_llm_config_id
|
||||||
|
if (
|
||||||
|
not force_repin_free
|
||||||
|
and pinned_id is not None
|
||||||
|
and int(pinned_id) in candidate_by_id
|
||||||
|
):
|
||||||
|
pinned_cfg = candidate_by_id[int(pinned_id)]
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_reused thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s",
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
pinned_id,
|
||||||
|
_tier_of(pinned_cfg),
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
|
||||||
|
"auto_pin_tier=%s score=%s top_k_size=0 from_existing_pin=True",
|
||||||
|
thread_id,
|
||||||
|
pinned_id,
|
||||||
|
_tier_of(pinned_cfg),
|
||||||
|
pinned_cfg.get("auto_pin_tier", "?"),
|
||||||
|
int(pinned_cfg.get("quality_score") or 0),
|
||||||
|
)
|
||||||
|
return AutoPinResolution(
|
||||||
|
resolved_llm_config_id=int(pinned_id),
|
||||||
|
resolved_tier=_tier_of(pinned_cfg),
|
||||||
|
from_existing_pin=True,
|
||||||
|
)
|
||||||
|
if pinned_id is not None:
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_invalid thread_id=%s search_space_id=%s pinned_config_id=%s",
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
pinned_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
premium_eligible = (
|
||||||
|
False if force_repin_free else await _is_premium_eligible(session, user_id)
|
||||||
|
)
|
||||||
|
if premium_eligible:
|
||||||
|
eligible = candidates
|
||||||
|
else:
|
||||||
|
eligible = [c for c in candidates if _tier_of(c) != "premium"]
|
||||||
|
|
||||||
|
if not eligible:
|
||||||
|
raise ValueError(
|
||||||
|
"Auto mode could not find an eligible LLM config for this user and quota state"
|
||||||
|
)
|
||||||
|
|
||||||
|
selected_cfg, top_k_size = _select_pin(eligible, thread_id)
|
||||||
|
selected_id = int(selected_cfg["id"])
|
||||||
|
selected_tier = _tier_of(selected_cfg)
|
||||||
|
|
||||||
|
thread.pinned_llm_config_id = selected_id
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
if force_repin_free:
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_forced_free_repin thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s",
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
pinned_id,
|
||||||
|
selected_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if pinned_id is None:
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_created thread_id=%s search_space_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
selected_id,
|
||||||
|
selected_tier,
|
||||||
|
premium_eligible,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_repaired thread_id=%s search_space_id=%s previous_config_id=%s resolved_config_id=%s tier=%s premium_eligible=%s",
|
||||||
|
thread_id,
|
||||||
|
search_space_id,
|
||||||
|
pinned_id,
|
||||||
|
selected_id,
|
||||||
|
selected_tier,
|
||||||
|
premium_eligible,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_resolved thread_id=%s config_id=%s tier=%s "
|
||||||
|
"auto_pin_tier=%s score=%s top_k_size=%d from_existing_pin=False",
|
||||||
|
thread_id,
|
||||||
|
selected_id,
|
||||||
|
selected_tier,
|
||||||
|
selected_cfg.get("auto_pin_tier", "?"),
|
||||||
|
int(selected_cfg.get("quality_score") or 0),
|
||||||
|
top_k_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return AutoPinResolution(
|
||||||
|
resolved_llm_config_id=selected_id,
|
||||||
|
resolved_tier=selected_tier,
|
||||||
|
from_existing_pin=False,
|
||||||
|
)
|
||||||
|
|
@ -28,6 +28,7 @@ from litellm.exceptions import (
|
||||||
BadRequestError as LiteLLMBadRequestError,
|
BadRequestError as LiteLLMBadRequestError,
|
||||||
ContextWindowExceededError,
|
ContextWindowExceededError,
|
||||||
)
|
)
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
|
@ -207,6 +208,12 @@ class LLMRouterService:
|
||||||
"""
|
"""
|
||||||
Initialize the router with global LLM configurations.
|
Initialize the router with global LLM configurations.
|
||||||
|
|
||||||
|
Configs with ``router_pool_eligible=False`` are skipped so that
|
||||||
|
dynamic OpenRouter entries stay out of the shared router pool used
|
||||||
|
by title-gen / sub-agent ``model="auto"`` flows. Those dynamic
|
||||||
|
entries are still available for user-facing Auto-mode thread pinning
|
||||||
|
via ``auto_model_pin_service``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
global_configs: List of global LLM config dictionaries from YAML
|
global_configs: List of global LLM config dictionaries from YAML
|
||||||
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
|
router_settings: Optional router settings (routing_strategy, num_retries, etc.)
|
||||||
|
|
@ -220,6 +227,8 @@ class LLMRouterService:
|
||||||
model_list = []
|
model_list = []
|
||||||
premium_models: set[str] = set()
|
premium_models: set[str] = set()
|
||||||
for config in global_configs:
|
for config in global_configs:
|
||||||
|
if config.get("router_pool_eligible") is False:
|
||||||
|
continue
|
||||||
deployment = cls._config_to_deployment(config)
|
deployment = cls._config_to_deployment(config)
|
||||||
if deployment:
|
if deployment:
|
||||||
model_list.append(deployment)
|
model_list.append(deployment)
|
||||||
|
|
@ -308,10 +317,45 @@ class LLMRouterService:
|
||||||
logger.error(f"Failed to initialize LLM Router: {e}")
|
logger.error(f"Failed to initialize LLM Router: {e}")
|
||||||
instance._router = None
|
instance._router = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def rebuild(
|
||||||
|
cls,
|
||||||
|
global_configs: list[dict],
|
||||||
|
router_settings: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Reset the router and re-run ``initialize`` with fresh configs.
|
||||||
|
|
||||||
|
``initialize`` short-circuits once it has run to avoid re-creating the
|
||||||
|
LiteLLM Router on every request; ``rebuild`` deliberately clears
|
||||||
|
``_initialized`` so a caller (e.g. background OpenRouter refresh)
|
||||||
|
can force the pool to be rebuilt after catalogue changes.
|
||||||
|
"""
|
||||||
|
instance = cls.get_instance()
|
||||||
|
instance._initialized = False
|
||||||
|
instance._router = None
|
||||||
|
instance._model_list = []
|
||||||
|
instance._premium_model_strings = set()
|
||||||
|
cls.initialize(global_configs, router_settings)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_premium_model(cls, model_string: str) -> bool:
|
def is_premium_model(cls, model_string: str) -> bool:
|
||||||
"""Return True if *model_string* (as reported by LiteLLM) belongs to a
|
"""Return True if *model_string* belongs to a premium-tier deployment
|
||||||
premium-tier deployment in the router pool."""
|
in the LiteLLM router pool.
|
||||||
|
|
||||||
|
Scope: only covers configs with ``router_pool_eligible`` truthy. That
|
||||||
|
includes static YAML premium configs AND dynamic OpenRouter *premium*
|
||||||
|
entries (which opt in at generation time). Dynamic OpenRouter *free*
|
||||||
|
entries are deliberately kept out of the router pool — OpenRouter
|
||||||
|
enforces free-tier limits globally per account, so per-deployment
|
||||||
|
router accounting can't represent them correctly — and therefore
|
||||||
|
return ``False`` here, which matches their ``billing_tier="free"``
|
||||||
|
(no premium quota).
|
||||||
|
|
||||||
|
For per-request premium checks on an arbitrary config (static or
|
||||||
|
dynamic, pool or non-pool), read ``agent_config.is_premium`` instead;
|
||||||
|
that reflects the per-config ``billing_tier`` directly and is what
|
||||||
|
user-facing Auto-mode thread pinning uses to bill correctly.
|
||||||
|
"""
|
||||||
instance = cls.get_instance()
|
instance = cls.get_instance()
|
||||||
return model_string in instance._premium_model_strings
|
return model_string in instance._premium_model_strings
|
||||||
|
|
||||||
|
|
@ -573,6 +617,11 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
# Public attributes that Pydantic will manage
|
# Public attributes that Pydantic will manage
|
||||||
model: str = "auto"
|
model: str = "auto"
|
||||||
streaming: bool = True
|
streaming: bool = True
|
||||||
|
# Static kwargs that flow through to ``litellm.completion(...)`` on every
|
||||||
|
# invocation (e.g. ``cache_control_injection_points`` set by
|
||||||
|
# ``apply_litellm_prompt_caching``). Per-call ``**kwargs`` from
|
||||||
|
# ``invoke()`` still take precedence — see ``_generate``/``_astream``.
|
||||||
|
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
# Bound tools and tool choice for tool calling
|
# Bound tools and tool choice for tool calling
|
||||||
_bound_tools: list[dict] | None = None
|
_bound_tools: list[dict] | None = None
|
||||||
|
|
@ -898,13 +947,16 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
logger.warning(f"Failed to convert tool {tool}: {e}")
|
logger.warning(f"Failed to convert tool {tool}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Create a new instance with tools bound
|
# Create a new instance with tools bound. Carry through ``model_kwargs``
|
||||||
|
# so static settings (e.g. cache_control_injection_points) survive the
|
||||||
|
# bind_tools rebuild.
|
||||||
return ChatLiteLLMRouter(
|
return ChatLiteLLMRouter(
|
||||||
router=self._router,
|
router=self._router,
|
||||||
bound_tools=formatted_tools if formatted_tools else None,
|
bound_tools=formatted_tools if formatted_tools else None,
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
streaming=self.streaming,
|
streaming=self.streaming,
|
||||||
|
model_kwargs=dict(self.model_kwargs),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -929,8 +981,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
formatted_messages = self._convert_messages(messages)
|
formatted_messages = self._convert_messages(messages)
|
||||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||||
|
|
||||||
# Add tools if bound
|
# Merge static model_kwargs (e.g. cache_control_injection_points) under
|
||||||
call_kwargs = {**kwargs}
|
# per-call kwargs so callers can still override per invocation. Then add
|
||||||
|
# bound tools.
|
||||||
|
call_kwargs = {**self.model_kwargs, **kwargs}
|
||||||
if self._bound_tools:
|
if self._bound_tools:
|
||||||
call_kwargs["tools"] = self._bound_tools
|
call_kwargs["tools"] = self._bound_tools
|
||||||
if self._tool_choice is not None:
|
if self._tool_choice is not None:
|
||||||
|
|
@ -997,8 +1051,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
formatted_messages = self._convert_messages(messages)
|
formatted_messages = self._convert_messages(messages)
|
||||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||||
|
|
||||||
# Add tools if bound
|
# Merge static model_kwargs (e.g. cache_control_injection_points) under
|
||||||
call_kwargs = {**kwargs}
|
# per-call kwargs so callers can still override per invocation. Then add
|
||||||
|
# bound tools.
|
||||||
|
call_kwargs = {**self.model_kwargs, **kwargs}
|
||||||
if self._bound_tools:
|
if self._bound_tools:
|
||||||
call_kwargs["tools"] = self._bound_tools
|
call_kwargs["tools"] = self._bound_tools
|
||||||
if self._tool_choice is not None:
|
if self._tool_choice is not None:
|
||||||
|
|
@ -1060,8 +1116,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
formatted_messages = self._convert_messages(messages)
|
formatted_messages = self._convert_messages(messages)
|
||||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||||
|
|
||||||
# Add tools if bound
|
# Merge static model_kwargs (e.g. cache_control_injection_points) under
|
||||||
call_kwargs = {**kwargs}
|
# per-call kwargs so callers can still override per invocation. Then add
|
||||||
|
# bound tools.
|
||||||
|
call_kwargs = {**self.model_kwargs, **kwargs}
|
||||||
if self._bound_tools:
|
if self._bound_tools:
|
||||||
call_kwargs["tools"] = self._bound_tools
|
call_kwargs["tools"] = self._bound_tools
|
||||||
if self._tool_choice is not None:
|
if self._tool_choice is not None:
|
||||||
|
|
@ -1110,8 +1168,10 @@ class ChatLiteLLMRouter(BaseChatModel):
|
||||||
formatted_messages = self._convert_messages(messages)
|
formatted_messages = self._convert_messages(messages)
|
||||||
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
formatted_messages = self._trim_messages_to_fit_context(formatted_messages)
|
||||||
|
|
||||||
# Add tools if bound
|
# Merge static model_kwargs (e.g. cache_control_injection_points) under
|
||||||
call_kwargs = {**kwargs}
|
# per-call kwargs so callers can still override per invocation. Then add
|
||||||
|
# bound tools.
|
||||||
|
call_kwargs = {**self.model_kwargs, **kwargs}
|
||||||
if self._bound_tools:
|
if self._bound_tools:
|
||||||
call_kwargs["tools"] = self._bound_tools
|
call_kwargs["tools"] = self._bound_tools
|
||||||
if self._tool_choice is not None:
|
if self._tool_choice is not None:
|
||||||
|
|
|
||||||
|
|
@ -565,32 +565,63 @@ class VercelStreamingService:
|
||||||
# Error Part
|
# Error Part
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
||||||
def format_error(self, error_text: str) -> str:
|
def format_error(
|
||||||
|
self,
|
||||||
|
error_text: str,
|
||||||
|
error_code: str | None = None,
|
||||||
|
extra: dict[str, object] | None = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format an error message.
|
Format an error message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
error_text: The error message text
|
error_text: The error message text
|
||||||
|
error_code: Optional machine-readable error code for frontend branching
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: SSE formatted error part
|
str: SSE formatted error part
|
||||||
|
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"error","errorText":"Something went wrong"}
|
data: {"type":"error","errorText":"Something went wrong","errorCode":"SOME_CODE"}
|
||||||
"""
|
"""
|
||||||
return self._format_sse({"type": "error", "errorText": error_text})
|
payload: dict[str, object] = {"type": "error", "errorText": error_text}
|
||||||
|
if error_code:
|
||||||
|
payload["errorCode"] = error_code
|
||||||
|
if extra:
|
||||||
|
payload.update(extra)
|
||||||
|
return self._format_sse(payload)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Tool Parts
|
# Tool Parts
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
||||||
def format_tool_input_start(self, tool_call_id: str, tool_name: str) -> str:
|
def format_tool_input_start(
|
||||||
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format the start of tool input streaming.
|
Format the start of tool input streaming.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_call_id: The unique tool call identifier
|
tool_call_id: The unique tool call identifier. May be EITHER the
|
||||||
tool_name: The name of the tool being called
|
synthetic ``call_<run_id>`` id derived from LangGraph
|
||||||
|
``run_id`` (legacy / ``SURFSENSE_ENABLE_STREAM_PARITY_V2``
|
||||||
|
OFF, or the unmatched-fallback path under parity_v2) OR
|
||||||
|
the authoritative LangChain ``tool_call.id`` (parity_v2
|
||||||
|
path: when the provider streams ``tool_call_chunks`` we
|
||||||
|
register the ``index`` and reuse the lc-id as the card
|
||||||
|
id so live ``tool-input-delta`` events can be routed
|
||||||
|
without a downstream join). Either way, the same id is
|
||||||
|
preserved across ``tool-input-start`` / ``-delta`` /
|
||||||
|
``-available`` / ``tool-output-available`` for one call.
|
||||||
|
tool_name: The name of the tool being called.
|
||||||
|
langchain_tool_call_id: Optional authoritative LangChain
|
||||||
|
``tool_call.id``. When set, surfaces as
|
||||||
|
``langchainToolCallId`` so the frontend can join this card
|
||||||
|
to the action-log row written by ``ActionLogMiddleware``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: SSE formatted tool input start part
|
str: SSE formatted tool input start part
|
||||||
|
|
@ -598,13 +629,14 @@ class VercelStreamingService:
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}
|
data: {"type":"tool-input-start","toolCallId":"call_abc123","toolName":"getWeather"}
|
||||||
"""
|
"""
|
||||||
return self._format_sse(
|
payload: dict[str, Any] = {
|
||||||
{
|
|
||||||
"type": "tool-input-start",
|
"type": "tool-input-start",
|
||||||
"toolCallId": tool_call_id,
|
"toolCallId": tool_call_id,
|
||||||
"toolName": tool_name,
|
"toolName": tool_name,
|
||||||
}
|
}
|
||||||
)
|
if langchain_tool_call_id:
|
||||||
|
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
return self._format_sse(payload)
|
||||||
|
|
||||||
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
|
def format_tool_input_delta(self, tool_call_id: str, input_text_delta: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -629,7 +661,12 @@ class VercelStreamingService:
|
||||||
)
|
)
|
||||||
|
|
||||||
def format_tool_input_available(
|
def format_tool_input_available(
|
||||||
self, tool_call_id: str, tool_name: str, input_data: dict[str, Any]
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format the completion of tool input.
|
Format the completion of tool input.
|
||||||
|
|
@ -638,6 +675,8 @@ class VercelStreamingService:
|
||||||
tool_call_id: The tool call identifier
|
tool_call_id: The tool call identifier
|
||||||
tool_name: The name of the tool
|
tool_name: The name of the tool
|
||||||
input_data: The complete tool input parameters
|
input_data: The complete tool input parameters
|
||||||
|
langchain_tool_call_id: Optional authoritative LangChain
|
||||||
|
``tool_call.id`` (see ``format_tool_input_start``).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: SSE formatted tool input available part
|
str: SSE formatted tool input available part
|
||||||
|
|
@ -645,22 +684,34 @@ class VercelStreamingService:
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}
|
data: {"type":"tool-input-available","toolCallId":"call_abc123","toolName":"getWeather","input":{"city":"SF"}}
|
||||||
"""
|
"""
|
||||||
return self._format_sse(
|
payload: dict[str, Any] = {
|
||||||
{
|
|
||||||
"type": "tool-input-available",
|
"type": "tool-input-available",
|
||||||
"toolCallId": tool_call_id,
|
"toolCallId": tool_call_id,
|
||||||
"toolName": tool_name,
|
"toolName": tool_name,
|
||||||
"input": input_data,
|
"input": input_data,
|
||||||
}
|
}
|
||||||
)
|
if langchain_tool_call_id:
|
||||||
|
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
return self._format_sse(payload)
|
||||||
|
|
||||||
def format_tool_output_available(self, tool_call_id: str, output: Any) -> str:
|
def format_tool_output_available(
|
||||||
|
self,
|
||||||
|
tool_call_id: str,
|
||||||
|
output: Any,
|
||||||
|
*,
|
||||||
|
langchain_tool_call_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Format tool execution output.
|
Format tool execution output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool_call_id: The tool call identifier
|
tool_call_id: The tool call identifier
|
||||||
output: The tool execution result
|
output: The tool execution result
|
||||||
|
langchain_tool_call_id: Optional authoritative LangChain
|
||||||
|
``tool_call.id`` extracted from ``ToolMessage.tool_call_id``.
|
||||||
|
When set, the frontend can backfill any card whose
|
||||||
|
``langchainToolCallId`` was not yet known at
|
||||||
|
``tool-input-start`` time.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: SSE formatted tool output available part
|
str: SSE formatted tool output available part
|
||||||
|
|
@ -668,13 +719,14 @@ class VercelStreamingService:
|
||||||
Example output:
|
Example output:
|
||||||
data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}
|
data: {"type":"tool-output-available","toolCallId":"call_abc123","output":{"weather":"sunny"}}
|
||||||
"""
|
"""
|
||||||
return self._format_sse(
|
payload: dict[str, Any] = {
|
||||||
{
|
|
||||||
"type": "tool-output-available",
|
"type": "tool-output-available",
|
||||||
"toolCallId": tool_call_id,
|
"toolCallId": tool_call_id,
|
||||||
"output": output,
|
"output": output,
|
||||||
}
|
}
|
||||||
)
|
if langchain_tool_call_id:
|
||||||
|
payload["langchainToolCallId"] = langchain_tool_call_id
|
||||||
|
return self._format_sse(payload)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Step Parts
|
# Step Parts
|
||||||
|
|
|
||||||
|
|
@ -11,20 +11,81 @@ this service only manages the catalogue, not the inference path.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from app.services.quality_score import (
|
||||||
|
_HEALTH_BLEND_WEIGHT,
|
||||||
|
_HEALTH_ENRICH_CONCURRENCY,
|
||||||
|
_HEALTH_ENRICH_TOP_N_FREE,
|
||||||
|
_HEALTH_ENRICH_TOP_N_PREMIUM,
|
||||||
|
_HEALTH_FAIL_RATIO_FALLBACK,
|
||||||
|
_HEALTH_FETCH_TIMEOUT_SEC,
|
||||||
|
aggregate_health,
|
||||||
|
static_score_or,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/models"
|
||||||
|
OPENROUTER_ENDPOINTS_URL_TEMPLATE = (
|
||||||
|
"https://openrouter.ai/api/v1/models/{model_id}/endpoints"
|
||||||
|
)
|
||||||
|
|
||||||
# Sentinel value stored on each generated config so we can distinguish
|
# Sentinel value stored on each generated config so we can distinguish
|
||||||
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
|
# dynamic OpenRouter entries from hand-written YAML entries during refresh.
|
||||||
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
|
_OPENROUTER_DYNAMIC_MARKER = "__openrouter_dynamic__"
|
||||||
|
|
||||||
|
# Width of the hash space used by ``_stable_config_id``. 9_000_000 provides
|
||||||
|
# enough headroom to avoid frequent collisions for OpenRouter's catalogue
|
||||||
|
# (~300 models) while keeping IDs comfortably within Postgres INTEGER range.
|
||||||
|
_STABLE_ID_HASH_WIDTH = 9_000_000
|
||||||
|
|
||||||
|
|
||||||
|
def _stable_config_id(model_id: str, offset: int, taken: set[int]) -> int:
|
||||||
|
"""Derive a deterministic negative config ID from ``model_id``.
|
||||||
|
|
||||||
|
The same ``model_id`` always hashes to the same base value so thread pins
|
||||||
|
survive catalogue churn (models appearing/disappearing/reordering between
|
||||||
|
refreshes). On collision we decrement until we find an unused slot; this
|
||||||
|
keeps the mapping stable for the first config that claimed a slot and
|
||||||
|
only shifts collisions, which is much less disruptive than the legacy
|
||||||
|
index-based scheme that reshuffled every ID when the catalogue changed.
|
||||||
|
"""
|
||||||
|
digest = hashlib.blake2b(model_id.encode("utf-8"), digest_size=6).digest()
|
||||||
|
base = offset - (int.from_bytes(digest, "big") % _STABLE_ID_HASH_WIDTH)
|
||||||
|
cid = base
|
||||||
|
while cid in taken:
|
||||||
|
cid -= 1
|
||||||
|
taken.add(cid)
|
||||||
|
return cid
|
||||||
|
|
||||||
|
|
||||||
|
def _openrouter_tier(model: dict) -> str:
|
||||||
|
"""Classify an OpenRouter model as ``"free"`` or ``"premium"``.
|
||||||
|
|
||||||
|
Per OpenRouter's API contract, a model is free if:
|
||||||
|
- Its id ends with ``:free`` (OpenRouter's own free-variant convention), or
|
||||||
|
- Both ``pricing.prompt`` and ``pricing.completion`` are zero strings.
|
||||||
|
|
||||||
|
Anything else (missing pricing, non-zero pricing) falls through to
|
||||||
|
``"premium"`` so we never under-charge users. This derivation runs off the
|
||||||
|
already-cached /api/v1/models payload, so it adds no network cost.
|
||||||
|
"""
|
||||||
|
if model.get("id", "").endswith(":free"):
|
||||||
|
return "free"
|
||||||
|
pricing = model.get("pricing") or {}
|
||||||
|
prompt = str(pricing.get("prompt", "")).strip()
|
||||||
|
completion = str(pricing.get("completion", "")).strip()
|
||||||
|
if prompt == "0" and completion == "0":
|
||||||
|
return "free"
|
||||||
|
return "premium"
|
||||||
|
|
||||||
|
|
||||||
def _is_text_output_model(model: dict) -> bool:
|
def _is_text_output_model(model: dict) -> bool:
|
||||||
"""Return True if the model produces text output only (skip image/audio generators)."""
|
"""Return True if the model produces text output only (skip image/audio generators)."""
|
||||||
|
|
@ -56,6 +117,11 @@ _EXCLUDED_MODEL_IDS: set[str] = {
|
||||||
# Deep-research models reject standard params (temperature, etc.)
|
# Deep-research models reject standard params (temperature, etc.)
|
||||||
"openai/o3-deep-research",
|
"openai/o3-deep-research",
|
||||||
"openai/o4-mini-deep-research",
|
"openai/o4-mini-deep-research",
|
||||||
|
# OpenRouter's own meta-router over free models. We already enumerate every
|
||||||
|
# concrete ``:free`` model into GLOBAL_LLM_CONFIGS and Auto-mode thread
|
||||||
|
# pinning handles churn via the repair path, so exposing an additional
|
||||||
|
# indirection layer would only duplicate the capability with an opaque slug.
|
||||||
|
"openrouter/free",
|
||||||
}
|
}
|
||||||
|
|
||||||
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
|
_EXCLUDED_MODEL_SUFFIXES: tuple[str, ...] = ("-deep-research",)
|
||||||
|
|
@ -113,20 +179,41 @@ def _generate_configs(
|
||||||
raw_models: list[dict],
|
raw_models: list[dict],
|
||||||
settings: dict[str, Any],
|
settings: dict[str, Any],
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""Convert raw OpenRouter model entries into global LLM config dicts.
|
||||||
Convert raw OpenRouter model entries into global LLM config dicts.
|
|
||||||
|
|
||||||
Models are sorted by ID for deterministic, stable ID assignment across
|
Tier (``billing_tier``) is derived per-model from OpenRouter's own API
|
||||||
restarts and refreshes.
|
signals via ``_openrouter_tier`` — there is no longer a uniform YAML
|
||||||
|
override. Config IDs are derived via ``_stable_config_id`` so they
|
||||||
|
survive catalogue churn across refreshes.
|
||||||
|
|
||||||
|
Router-pool membership is tier-aware:
|
||||||
|
|
||||||
|
- Premium OR models join the LiteLLM router pool (``router_pool_eligible=True``)
|
||||||
|
so sub-agent ``model="auto"`` flows benefit from load balancing and
|
||||||
|
failover across the curated YAML configs and the OR premium passthrough.
|
||||||
|
- Free OR models stay excluded (``router_pool_eligible=False``). LiteLLM
|
||||||
|
Router tracks rate limits per deployment, but OpenRouter enforces a
|
||||||
|
single global free-tier quota (~20 RPM + 50-1000 daily requests
|
||||||
|
account-wide across every ``:free`` model), so rotating across many
|
||||||
|
free deployments would only burn the shared bucket faster. Free OR
|
||||||
|
models remain fully available for user-facing Auto-mode thread pinning
|
||||||
|
via ``auto_model_pin_service``.
|
||||||
|
|
||||||
|
OpenRouter's own ``openrouter/free`` meta-router is filtered out upstream
|
||||||
|
via ``_EXCLUDED_MODEL_IDS``; we don't expose a redundant auto-select layer
|
||||||
|
because our own Auto (Fastest) pin + 24 h refresh + repair logic already
|
||||||
|
cover the catalogue-churn case.
|
||||||
"""
|
"""
|
||||||
id_offset: int = settings.get("id_offset", -10000)
|
id_offset: int = settings.get("id_offset", -10000)
|
||||||
api_key: str = settings.get("api_key", "")
|
api_key: str = settings.get("api_key", "")
|
||||||
billing_tier: str = settings.get("billing_tier", "premium")
|
|
||||||
anonymous_enabled: bool = settings.get("anonymous_enabled", False)
|
|
||||||
seo_enabled: bool = settings.get("seo_enabled", False)
|
seo_enabled: bool = settings.get("seo_enabled", False)
|
||||||
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
quota_reserve_tokens: int = settings.get("quota_reserve_tokens", 4000)
|
||||||
rpm: int = settings.get("rpm", 200)
|
rpm: int = settings.get("rpm", 200)
|
||||||
tpm: int = settings.get("tpm", 1000000)
|
tpm: int = settings.get("tpm", 1_000_000)
|
||||||
|
free_rpm: int = settings.get("free_rpm", 20)
|
||||||
|
free_tpm: int = settings.get("free_tpm", 100_000)
|
||||||
|
anon_paid: bool = settings.get("anonymous_enabled_paid", False)
|
||||||
|
anon_free: bool = settings.get("anonymous_enabled_free", False)
|
||||||
litellm_params: dict = settings.get("litellm_params") or {}
|
litellm_params: dict = settings.get("litellm_params") or {}
|
||||||
system_instructions: str = settings.get("system_instructions", "")
|
system_instructions: str = settings.get("system_instructions", "")
|
||||||
use_default: bool = settings.get("use_default_system_instructions", True)
|
use_default: bool = settings.get("use_default_system_instructions", True)
|
||||||
|
|
@ -142,19 +229,24 @@ def _generate_configs(
|
||||||
and _is_allowed_model(m)
|
and _is_allowed_model(m)
|
||||||
and "/" in m.get("id", "")
|
and "/" in m.get("id", "")
|
||||||
]
|
]
|
||||||
text_models.sort(key=lambda m: m["id"])
|
|
||||||
|
|
||||||
configs: list[dict] = []
|
configs: list[dict] = []
|
||||||
for idx, model in enumerate(text_models):
|
taken: set[int] = set()
|
||||||
|
now_ts = int(time.time())
|
||||||
|
|
||||||
|
for model in text_models:
|
||||||
model_id: str = model["id"]
|
model_id: str = model["id"]
|
||||||
name: str = model.get("name", model_id)
|
name: str = model.get("name", model_id)
|
||||||
|
tier = _openrouter_tier(model)
|
||||||
|
|
||||||
|
static_q = static_score_or(model, now_ts=now_ts)
|
||||||
|
|
||||||
cfg: dict[str, Any] = {
|
cfg: dict[str, Any] = {
|
||||||
"id": id_offset - idx,
|
"id": _stable_config_id(model_id, id_offset, taken),
|
||||||
"name": name,
|
"name": name,
|
||||||
"description": f"{name} via OpenRouter",
|
"description": f"{name} via OpenRouter",
|
||||||
"billing_tier": billing_tier,
|
"billing_tier": tier,
|
||||||
"anonymous_enabled": anonymous_enabled,
|
"anonymous_enabled": anon_free if tier == "free" else anon_paid,
|
||||||
"seo_enabled": seo_enabled,
|
"seo_enabled": seo_enabled,
|
||||||
"seo_slug": None,
|
"seo_slug": None,
|
||||||
"quota_reserve_tokens": quota_reserve_tokens,
|
"quota_reserve_tokens": quota_reserve_tokens,
|
||||||
|
|
@ -162,13 +254,28 @@ def _generate_configs(
|
||||||
"model_name": model_id,
|
"model_name": model_id,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"api_base": "",
|
"api_base": "",
|
||||||
"rpm": rpm,
|
"rpm": free_rpm if tier == "free" else rpm,
|
||||||
"tpm": tpm,
|
"tpm": free_tpm if tier == "free" else tpm,
|
||||||
"litellm_params": dict(litellm_params),
|
"litellm_params": dict(litellm_params),
|
||||||
"system_instructions": system_instructions,
|
"system_instructions": system_instructions,
|
||||||
"use_default_system_instructions": use_default,
|
"use_default_system_instructions": use_default,
|
||||||
"citations_enabled": citations_enabled,
|
"citations_enabled": citations_enabled,
|
||||||
|
# Premium OR deployments join the LiteLLM router pool so sub-agent
|
||||||
|
# model="auto" flows can load-balance / fail over across them.
|
||||||
|
# Free OR deployments stay out: OpenRouter's free tier is a single
|
||||||
|
# account-wide quota, so per-deployment routing can't spread load
|
||||||
|
# there — it just drains the shared bucket faster.
|
||||||
|
"router_pool_eligible": tier == "premium",
|
||||||
_OPENROUTER_DYNAMIC_MARKER: True,
|
_OPENROUTER_DYNAMIC_MARKER: True,
|
||||||
|
# Auto (Fastest) ranking metadata. ``quality_score`` is initialised
|
||||||
|
# to the static score and gets re-blended with health on the next
|
||||||
|
# ``_enrich_health`` pass (synchronous on refresh, deferred on cold
|
||||||
|
# start so startup latency is unchanged).
|
||||||
|
"auto_pin_tier": "B" if tier == "premium" else "C",
|
||||||
|
"quality_score_static": static_q,
|
||||||
|
"quality_score_health": None,
|
||||||
|
"quality_score": static_q,
|
||||||
|
"health_gated": False,
|
||||||
}
|
}
|
||||||
configs.append(cfg)
|
configs.append(cfg)
|
||||||
|
|
||||||
|
|
@ -187,6 +294,12 @@ class OpenRouterIntegrationService:
|
||||||
self._configs_by_id: dict[int, dict] = {}
|
self._configs_by_id: dict[int, dict] = {}
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
self._refresh_task: asyncio.Task | None = None
|
self._refresh_task: asyncio.Task | None = None
|
||||||
|
# Last-good per-model health snapshot. Survives across refresh
|
||||||
|
# cycles so a transient OpenRouter /endpoints outage doesn't drop
|
||||||
|
# every cfg back to static-only scoring.
|
||||||
|
# Shape: {model_name: {"gated": bool, "score": float | None}}
|
||||||
|
self._health_cache: dict[str, dict[str, Any]] = {}
|
||||||
|
self._enrich_task: asyncio.Task | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls) -> "OpenRouterIntegrationService":
|
def get_instance(cls) -> "OpenRouterIntegrationService":
|
||||||
|
|
@ -220,12 +333,27 @@ class OpenRouterIntegrationService:
|
||||||
self._configs_by_id = {c["id"]: c for c in self._configs}
|
self._configs_by_id = {c["id"]: c for c in self._configs}
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
|
tier_counts = self._tier_counts(self._configs)
|
||||||
logger.info(
|
logger.info(
|
||||||
"OpenRouter integration: loaded %d models (IDs %d to %d)",
|
"OpenRouter integration: loaded %d models (free=%d, premium=%d)",
|
||||||
len(self._configs),
|
len(self._configs),
|
||||||
self._configs[0]["id"] if self._configs else 0,
|
tier_counts["free"],
|
||||||
self._configs[-1]["id"] if self._configs else 0,
|
tier_counts["premium"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Schedule the first health-enrichment pass as a deferred task so
|
||||||
|
# cold-start latency is unchanged. Only valid when an event loop is
|
||||||
|
# already running (e.g. FastAPI lifespan); Celery worker init is
|
||||||
|
# fully sync so we silently skip — its first refresh tick (or the
|
||||||
|
# next refresh from the web process) will populate health data.
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
self._enrich_task = loop.create_task(
|
||||||
|
self._enrich_health_safely(self._configs)
|
||||||
|
)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
return self._configs
|
return self._configs
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -254,7 +382,225 @@ class OpenRouterIntegrationService:
|
||||||
self._configs = new_configs
|
self._configs = new_configs
|
||||||
self._configs_by_id = new_by_id
|
self._configs_by_id = new_by_id
|
||||||
|
|
||||||
logger.info("OpenRouter refresh: updated to %d models", len(new_configs))
|
# Catalogue churn invalidates per-config "recently healthy" credit
|
||||||
|
# earned by the previous turn's preflight. Drop the whole table so
|
||||||
|
# the next turn re-probes against the freshly loaded configs.
|
||||||
|
try:
|
||||||
|
from app.services.auto_model_pin_service import clear_healthy
|
||||||
|
|
||||||
|
clear_healthy()
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"OpenRouter refresh: clear_healthy import skipped", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
tier_counts = self._tier_counts(new_configs)
|
||||||
|
logger.info(
|
||||||
|
"OpenRouter refresh: updated to %d models (free=%d, premium=%d)",
|
||||||
|
len(new_configs),
|
||||||
|
tier_counts["free"],
|
||||||
|
tier_counts["premium"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Re-blend health scores against the freshly fetched catalogue. Also
|
||||||
|
# re-stamps health for any YAML-curated cfg with provider==OPENROUTER
|
||||||
|
# so a hand-picked dead OR model is gated like a dynamic one.
|
||||||
|
await self._enrich_health_safely(static_configs + new_configs, log_summary=True)
|
||||||
|
|
||||||
|
# Rebuild the LiteLLM router so freshly fetched configs flow through
|
||||||
|
# (dynamic OR premium entries now opt into the pool, free ones stay
|
||||||
|
# out; a refresh also needs to pick up any static-config edits and
|
||||||
|
# reset cached context-window profiles).
|
||||||
|
try:
|
||||||
|
from app.config import config as _app_config
|
||||||
|
from app.services.llm_router_service import (
|
||||||
|
LLMRouterService,
|
||||||
|
_router_instance_cache as _chat_router_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
LLMRouterService.rebuild(
|
||||||
|
_app_config.GLOBAL_LLM_CONFIGS,
|
||||||
|
getattr(_app_config, "ROUTER_SETTINGS", None),
|
||||||
|
)
|
||||||
|
_chat_router_cache.clear()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("OpenRouter refresh: router rebuild skipped (%s)", exc)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tier_counts(configs: list[dict]) -> dict[str, int]:
|
||||||
|
counts = {"free": 0, "premium": 0}
|
||||||
|
for cfg in configs:
|
||||||
|
tier = str(cfg.get("billing_tier", "")).lower()
|
||||||
|
if tier in counts:
|
||||||
|
counts[tier] += 1
|
||||||
|
return counts
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Auto (Fastest) health enrichment
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _enrich_health_safely(
|
||||||
|
self, configs: list[dict], *, log_summary: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""Wrapper around ``_enrich_health`` that swallows all errors.
|
||||||
|
|
||||||
|
Health enrichment is best-effort: any failure must leave cfgs in
|
||||||
|
their static-only state and never break refresh / startup.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self._enrich_health(configs, log_summary=log_summary)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("OpenRouter health enrichment failed")
|
||||||
|
|
||||||
|
async def _enrich_health(
|
||||||
|
self, configs: list[dict], *, log_summary: bool = True
|
||||||
|
) -> None:
|
||||||
|
"""Fetch per-model ``/endpoints`` data for the top OR cfgs and blend
|
||||||
|
the resulting health score into ``cfg["quality_score"]``.
|
||||||
|
|
||||||
|
Bounded fan-out: top-N per tier by ``quality_score_static`` only,
|
||||||
|
with ``asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)`` guarding the
|
||||||
|
outbound HTTP. Misses fall back to a per-model last-good cache; if
|
||||||
|
the failure ratio crosses ``_HEALTH_FAIL_RATIO_FALLBACK`` we keep
|
||||||
|
the entire previous cycle's cache for this run.
|
||||||
|
"""
|
||||||
|
or_cfgs = [
|
||||||
|
c for c in configs if str(c.get("provider", "")).upper() == "OPENROUTER"
|
||||||
|
]
|
||||||
|
if not or_cfgs:
|
||||||
|
return
|
||||||
|
|
||||||
|
premium_pool = sorted(
|
||||||
|
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "premium"],
|
||||||
|
key=lambda c: -int(c.get("quality_score_static") or 0),
|
||||||
|
)[:_HEALTH_ENRICH_TOP_N_PREMIUM]
|
||||||
|
free_pool = sorted(
|
||||||
|
[c for c in or_cfgs if str(c.get("billing_tier", "")).lower() == "free"],
|
||||||
|
key=lambda c: -int(c.get("quality_score_static") or 0),
|
||||||
|
)[:_HEALTH_ENRICH_TOP_N_FREE]
|
||||||
|
# De-duplicate while preserving order: a cfg shouldn't fall in both
|
||||||
|
# tiers, but defensive code is cheap here.
|
||||||
|
seen_ids: set[int] = set()
|
||||||
|
selected: list[dict] = []
|
||||||
|
for cfg in premium_pool + free_pool:
|
||||||
|
cid = int(cfg.get("id", 0))
|
||||||
|
if cid in seen_ids:
|
||||||
|
continue
|
||||||
|
seen_ids.add(cid)
|
||||||
|
selected.append(cfg)
|
||||||
|
|
||||||
|
if not selected:
|
||||||
|
return
|
||||||
|
|
||||||
|
api_key = str(self._settings.get("api_key") or "")
|
||||||
|
semaphore = asyncio.Semaphore(_HEALTH_ENRICH_CONCURRENCY)
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=_HEALTH_FETCH_TIMEOUT_SEC) as client:
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(
|
||||||
|
self._fetch_endpoints(client, semaphore, api_key, cfg)
|
||||||
|
for cfg in selected
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fail_count = sum(1 for _, _, err in results if err is not None)
|
||||||
|
fail_ratio = fail_count / len(results) if results else 0.0
|
||||||
|
degraded = fail_ratio >= _HEALTH_FAIL_RATIO_FALLBACK
|
||||||
|
if degraded:
|
||||||
|
logger.warning(
|
||||||
|
"auto_pin_health_enrich_degraded fail_ratio=%.2f total=%d "
|
||||||
|
"using_last_good_cache=true",
|
||||||
|
fail_ratio,
|
||||||
|
len(results),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Per-cfg health update.
|
||||||
|
for cfg, endpoints, err in results:
|
||||||
|
model_name = str(cfg.get("model_name", ""))
|
||||||
|
if not degraded and err is None and endpoints is not None:
|
||||||
|
gated, h_score = aggregate_health(endpoints)
|
||||||
|
cfg["health_gated"] = bool(gated)
|
||||||
|
cfg["quality_score_health"] = h_score
|
||||||
|
self._health_cache[model_name] = {
|
||||||
|
"gated": bool(gated),
|
||||||
|
"score": h_score,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
cached = self._health_cache.get(model_name)
|
||||||
|
if cached is not None:
|
||||||
|
cfg["health_gated"] = bool(cached.get("gated", False))
|
||||||
|
cfg["quality_score_health"] = cached.get("score")
|
||||||
|
# else: keep current values (initial defaults from
|
||||||
|
# _generate_configs / load_global_llm_configs).
|
||||||
|
|
||||||
|
# Blend health into the final score for every OR cfg, including
|
||||||
|
# those outside the enriched top-N (they fall through to static).
|
||||||
|
gated_count = 0
|
||||||
|
by_provider: dict[str, int] = {}
|
||||||
|
for cfg in or_cfgs:
|
||||||
|
static_q = int(cfg.get("quality_score_static") or 0)
|
||||||
|
h = cfg.get("quality_score_health")
|
||||||
|
if h is not None and not cfg.get("health_gated"):
|
||||||
|
blended = (
|
||||||
|
_HEALTH_BLEND_WEIGHT * float(h)
|
||||||
|
+ (1 - _HEALTH_BLEND_WEIGHT) * static_q
|
||||||
|
)
|
||||||
|
cfg["quality_score"] = round(blended)
|
||||||
|
else:
|
||||||
|
cfg["quality_score"] = static_q
|
||||||
|
|
||||||
|
if cfg.get("health_gated"):
|
||||||
|
gated_count += 1
|
||||||
|
model_id = str(cfg.get("model_name", ""))
|
||||||
|
provider_slug = (
|
||||||
|
model_id.split("/", 1)[0] if "/" in model_id else "unknown"
|
||||||
|
)
|
||||||
|
by_provider[provider_slug] = by_provider.get(provider_slug, 0) + 1
|
||||||
|
|
||||||
|
if log_summary:
|
||||||
|
logger.info(
|
||||||
|
"auto_pin_health_gated count=%d by_provider=%s fail_ratio=%.2f "
|
||||||
|
"total_enriched=%d",
|
||||||
|
gated_count,
|
||||||
|
dict(sorted(by_provider.items(), key=lambda kv: -kv[1])),
|
||||||
|
fail_ratio,
|
||||||
|
len(selected),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _fetch_endpoints(
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
semaphore: asyncio.Semaphore,
|
||||||
|
api_key: str,
|
||||||
|
cfg: dict,
|
||||||
|
) -> tuple[dict, list[dict] | None, Exception | None]:
|
||||||
|
"""Fetch ``/api/v1/models/{id}/endpoints`` for one cfg.
|
||||||
|
|
||||||
|
Returns ``(cfg, endpoints, err)`` so the caller can keep batched
|
||||||
|
results aligned with their cfgs without raising.
|
||||||
|
"""
|
||||||
|
model_id = str(cfg.get("model_name", ""))
|
||||||
|
if not model_id:
|
||||||
|
return cfg, None, ValueError("missing model_name")
|
||||||
|
|
||||||
|
url = OPENROUTER_ENDPOINTS_URL_TEMPLATE.format(model_id=model_id)
|
||||||
|
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||||
|
|
||||||
|
async with semaphore:
|
||||||
|
try:
|
||||||
|
resp = await client.get(url, headers=headers)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
except Exception as exc:
|
||||||
|
return cfg, None, exc
|
||||||
|
|
||||||
|
payload = data.get("data") if isinstance(data, dict) else None
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return cfg, None, ValueError("malformed endpoints payload")
|
||||||
|
endpoints = payload.get("endpoints")
|
||||||
|
if not isinstance(endpoints, list):
|
||||||
|
return cfg, [], None
|
||||||
|
return cfg, endpoints, None
|
||||||
|
|
||||||
async def _refresh_loop(self, interval_hours: float) -> None:
|
async def _refresh_loop(self, interval_hours: float) -> None:
|
||||||
interval_sec = interval_hours * 3600
|
interval_sec = interval_hours * 3600
|
||||||
|
|
|
||||||
380
surfsense_backend/app/services/quality_score.py
Normal file
380
surfsense_backend/app/services/quality_score.py
Normal file
|
|
@ -0,0 +1,380 @@
|
||||||
|
"""Pure-function quality scoring for Auto (Fastest) model selection.
|
||||||
|
|
||||||
|
This module is import-free of any service / request-path dependencies. All
|
||||||
|
numbers are computed once during the OpenRouter refresh tick (or YAML load)
|
||||||
|
and cached on the cfg dict, so the chat hot path only does a precomputed
|
||||||
|
sort and a SHA256 pick.
|
||||||
|
|
||||||
|
Score components (0-100 scale, higher is better):
|
||||||
|
|
||||||
|
* ``static_score_or`` - derived from the bulk ``/api/v1/models`` payload
|
||||||
|
(provider prestige + ``created`` recency + pricing band + context window
|
||||||
|
+ capabilities + narrow tiny/legacy slug penalty).
|
||||||
|
* ``static_score_yaml`` - same shape for hand-curated YAML configs, plus
|
||||||
|
an operator-trust bonus (the operator deliberately picked this model).
|
||||||
|
* ``aggregate_health`` - run on per-model ``/api/v1/models/{id}/endpoints``
|
||||||
|
responses; returns ``(gated, score_or_none)``.
|
||||||
|
|
||||||
|
The blended ``quality_score`` (0.5 * static + 0.5 * health) is computed in
|
||||||
|
:mod:`app.services.openrouter_integration_service` because that's the only
|
||||||
|
caller that sees both halves.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tunables (constants, not flags)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Top-K size for deterministic spread inside the locked tier.
|
||||||
|
_QUALITY_TOP_K: int = 5
|
||||||
|
|
||||||
|
# Hard health gate: any cfg whose best non-null uptime is below this %
|
||||||
|
# is excluded from Auto-mode selection entirely.
|
||||||
|
_HEALTH_GATE_UPTIME_PCT: float = 90.0
|
||||||
|
|
||||||
|
# Health/static blend weight when a cfg has fresh /endpoints data.
|
||||||
|
_HEALTH_BLEND_WEIGHT: float = 0.5
|
||||||
|
|
||||||
|
# Static bonus applied to YAML cfgs because the operator hand-picked them.
|
||||||
|
_OPERATOR_TRUST_BONUS: int = 20
|
||||||
|
|
||||||
|
# /endpoints fan-out is bounded per refresh tick.
|
||||||
|
_HEALTH_ENRICH_TOP_N_PREMIUM: int = 50
|
||||||
|
_HEALTH_ENRICH_TOP_N_FREE: int = 30
|
||||||
|
_HEALTH_ENRICH_CONCURRENCY: int = 15
|
||||||
|
_HEALTH_FETCH_TIMEOUT_SEC: float = 5.0
|
||||||
|
|
||||||
|
# If at least this fraction of /endpoints fetches fail in a refresh cycle,
|
||||||
|
# fall back to the previous cycle's last-good cache instead of writing
|
||||||
|
# partial / stale health values.
|
||||||
|
_HEALTH_FAIL_RATIO_FALLBACK: float = 0.25
|
||||||
|
|
||||||
|
# Narrow tiny/legacy slug penalties only. We deliberately do NOT penalise
|
||||||
|
# ``-nano`` / ``-mini`` / ``-lite`` because modern frontier models ship with
|
||||||
|
# those naming patterns (``gpt-5-mini``, ``gemini-2.5-flash-lite`` etc.) and
|
||||||
|
# blanket-penalising them suppresses high-quality picks.
|
||||||
|
_TINY_LEGACY_PENALTY_PATTERNS: tuple[str, ...] = (
|
||||||
|
"-1b-",
|
||||||
|
"-1.2b-",
|
||||||
|
"-1.5b-",
|
||||||
|
"-2b-",
|
||||||
|
"-3b-",
|
||||||
|
"gemma-3n",
|
||||||
|
"lfm-",
|
||||||
|
"-base",
|
||||||
|
"-distill",
|
||||||
|
":nitro",
|
||||||
|
"-preview",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Provider prestige tables
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# OpenRouter-side provider slug (the prefix before ``/`` in the model id).
|
||||||
|
# Tiers are coarse: frontier labs > strong open / fast-moving labs >
|
||||||
|
# specialist labs > everything else.
|
||||||
|
PROVIDER_PRESTIGE_OR: dict[str, int] = {
|
||||||
|
# Frontier labs
|
||||||
|
"openai": 50,
|
||||||
|
"anthropic": 50,
|
||||||
|
"google": 50,
|
||||||
|
"x-ai": 50,
|
||||||
|
# Strong open / fast-moving labs
|
||||||
|
"deepseek": 38,
|
||||||
|
"qwen": 38,
|
||||||
|
"meta-llama": 38,
|
||||||
|
"mistralai": 38,
|
||||||
|
"cohere": 38,
|
||||||
|
"nvidia": 38,
|
||||||
|
"alibaba": 38,
|
||||||
|
# Specialist / regional / strong second-tier
|
||||||
|
"microsoft": 28,
|
||||||
|
"01-ai": 28,
|
||||||
|
"minimax": 28,
|
||||||
|
"moonshot": 28,
|
||||||
|
"z-ai": 28,
|
||||||
|
"nousresearch": 28,
|
||||||
|
"ai21": 28,
|
||||||
|
"perplexity": 28,
|
||||||
|
# Smaller / niche providers
|
||||||
|
"liquid": 18,
|
||||||
|
"cognitivecomputations": 18,
|
||||||
|
"venice": 18,
|
||||||
|
"inflection": 18,
|
||||||
|
}
|
||||||
|
|
||||||
|
# YAML provider field (the upstream API shape the operator selected).
|
||||||
|
PROVIDER_PRESTIGE_YAML: dict[str, int] = {
|
||||||
|
"AZURE_OPENAI": 50,
|
||||||
|
"OPENAI": 50,
|
||||||
|
"ANTHROPIC": 50,
|
||||||
|
"GOOGLE": 50,
|
||||||
|
"VERTEX_AI": 50,
|
||||||
|
"GEMINI": 50,
|
||||||
|
"XAI": 50,
|
||||||
|
"MISTRAL": 38,
|
||||||
|
"DEEPSEEK": 38,
|
||||||
|
"COHERE": 38,
|
||||||
|
"GROQ": 30,
|
||||||
|
"TOGETHER_AI": 28,
|
||||||
|
"FIREWORKS_AI": 28,
|
||||||
|
"PERPLEXITY": 28,
|
||||||
|
"MINIMAX": 28,
|
||||||
|
"BEDROCK": 28,
|
||||||
|
"OPENROUTER": 25,
|
||||||
|
"OLLAMA": 12,
|
||||||
|
"CUSTOM": 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pure scoring helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Calibrated against the live /api/v1/models bulk dump. Frontier models
|
||||||
|
# released in the last ~6 months (GPT-5 family, Claude 4.x, Gemini 2.5,
|
||||||
|
# Grok 4) score in the 18-20 band; mid-2024 models in the 8-12 band;
|
||||||
|
# anything older trails off.
|
||||||
|
_RECENCY_BANDS_DAYS: tuple[tuple[int, int], ...] = (
|
||||||
|
(60, 20),
|
||||||
|
(180, 16),
|
||||||
|
(365, 12),
|
||||||
|
(540, 9),
|
||||||
|
(730, 6),
|
||||||
|
(1095, 3),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def created_recency_signal(created_ts: int | None, now_ts: int) -> int:
|
||||||
|
"""Return 0-20 based on how recently the model was published.
|
||||||
|
|
||||||
|
Uses the OpenRouter ``created`` Unix timestamp (or any equivalent for
|
||||||
|
YAML cfgs). Models without a usable timestamp get 0 (we don't penalise,
|
||||||
|
we just don't reward).
|
||||||
|
"""
|
||||||
|
if created_ts is None or created_ts <= 0 or now_ts <= 0:
|
||||||
|
return 0
|
||||||
|
age_days = max(0, (now_ts - int(created_ts)) // 86_400)
|
||||||
|
for cutoff, score in _RECENCY_BANDS_DAYS:
|
||||||
|
if age_days <= cutoff:
|
||||||
|
return score
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def pricing_band(
|
||||||
|
prompt: str | float | int | None,
|
||||||
|
completion: str | float | int | None,
|
||||||
|
) -> int:
|
||||||
|
"""Return 0-15 based on combined prompt+completion cost per 1M tokens.
|
||||||
|
|
||||||
|
Higher-priced models tend to be the larger / more capable ones. A free
|
||||||
|
model returns 0 (we use other signals to rank free-vs-free instead).
|
||||||
|
Uncoercible inputs are treated as 0 rather than raising.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _to_float(value) -> float:
|
||||||
|
if value is None:
|
||||||
|
return 0.0
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
p = _to_float(prompt)
|
||||||
|
c = _to_float(completion)
|
||||||
|
total_per_million = (p + c) * 1_000_000
|
||||||
|
|
||||||
|
if total_per_million >= 20.0:
|
||||||
|
return 15
|
||||||
|
if total_per_million >= 5.0:
|
||||||
|
return 12
|
||||||
|
if total_per_million >= 1.0:
|
||||||
|
return 9
|
||||||
|
if total_per_million >= 0.3:
|
||||||
|
return 6
|
||||||
|
if total_per_million >= 0.05:
|
||||||
|
return 4
|
||||||
|
if total_per_million > 0.0:
|
||||||
|
return 2
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def context_signal(ctx: int | None) -> int:
|
||||||
|
"""Return 0-10 based on the model's context window."""
|
||||||
|
if not ctx or ctx <= 0:
|
||||||
|
return 0
|
||||||
|
if ctx >= 1_000_000:
|
||||||
|
return 10
|
||||||
|
if ctx >= 400_000:
|
||||||
|
return 8
|
||||||
|
if ctx >= 200_000:
|
||||||
|
return 6
|
||||||
|
if ctx >= 128_000:
|
||||||
|
return 4
|
||||||
|
if ctx >= 100_000:
|
||||||
|
return 2
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def capabilities_signal(supported_parameters: list[str] | None) -> int:
|
||||||
|
"""Return 0-5 for capabilities that matter for our agent flows."""
|
||||||
|
if not supported_parameters:
|
||||||
|
return 0
|
||||||
|
params = set(supported_parameters)
|
||||||
|
score = 0
|
||||||
|
if "tools" in params:
|
||||||
|
score += 2
|
||||||
|
if "structured_outputs" in params or "response_format" in params:
|
||||||
|
score += 2
|
||||||
|
if "reasoning" in params or "include_reasoning" in params:
|
||||||
|
score += 1
|
||||||
|
return min(score, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def slug_penalty(model_id: str) -> int:
|
||||||
|
"""Return a non-positive number; matches the narrow tiny/legacy patterns."""
|
||||||
|
if not model_id:
|
||||||
|
return 0
|
||||||
|
needle = model_id.lower()
|
||||||
|
for pattern in _TINY_LEGACY_PENALTY_PATTERNS:
|
||||||
|
if pattern in needle:
|
||||||
|
return -10
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_prestige_or(model_id: str) -> int:
|
||||||
|
if "/" not in model_id:
|
||||||
|
return 0
|
||||||
|
slug = model_id.split("/", 1)[0].lower()
|
||||||
|
return PROVIDER_PRESTIGE_OR.get(slug, 15)
|
||||||
|
|
||||||
|
|
||||||
|
def static_score_or(or_model: dict, *, now_ts: int) -> int:
|
||||||
|
"""Score a raw OpenRouter ``/api/v1/models`` entry on a 0-100 scale."""
|
||||||
|
model_id = str(or_model.get("id", ""))
|
||||||
|
pricing = or_model.get("pricing") or {}
|
||||||
|
|
||||||
|
score = (
|
||||||
|
_provider_prestige_or(model_id)
|
||||||
|
+ created_recency_signal(or_model.get("created"), now_ts)
|
||||||
|
+ pricing_band(pricing.get("prompt"), pricing.get("completion"))
|
||||||
|
+ context_signal(or_model.get("context_length"))
|
||||||
|
+ capabilities_signal(or_model.get("supported_parameters"))
|
||||||
|
+ slug_penalty(model_id)
|
||||||
|
)
|
||||||
|
return max(0, min(100, int(score)))
|
||||||
|
|
||||||
|
|
||||||
|
def static_score_yaml(cfg: dict) -> int:
|
||||||
|
"""Score a YAML-curated cfg on a 0-100 scale.
|
||||||
|
|
||||||
|
Includes ``_OPERATOR_TRUST_BONUS`` because the operator deliberately
|
||||||
|
listed this model. Pricing / context fall through to lazy ``litellm``
|
||||||
|
lookups; failures are silent (we just lose those sub-points).
|
||||||
|
"""
|
||||||
|
provider = str(cfg.get("provider", "")).upper()
|
||||||
|
base = PROVIDER_PRESTIGE_YAML.get(provider, 15)
|
||||||
|
|
||||||
|
model_name = cfg.get("model_name") or ""
|
||||||
|
litellm_params = cfg.get("litellm_params") or {}
|
||||||
|
lookup_name = (
|
||||||
|
litellm_params.get("base_model") or litellm_params.get("model") or model_name
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = 0
|
||||||
|
p_cost: float = 0.0
|
||||||
|
c_cost: float = 0.0
|
||||||
|
try:
|
||||||
|
from litellm import get_model_info # lazy: avoid cold-import cost
|
||||||
|
|
||||||
|
info = get_model_info(lookup_name) or {}
|
||||||
|
ctx = int(info.get("max_input_tokens") or info.get("max_tokens") or 0)
|
||||||
|
p_cost = float(info.get("input_cost_per_token") or 0.0)
|
||||||
|
c_cost = float(info.get("output_cost_per_token") or 0.0)
|
||||||
|
except Exception:
|
||||||
|
# Unknown to litellm — that's fine for prestige+operator-bonus weighting.
|
||||||
|
pass
|
||||||
|
|
||||||
|
score = (
|
||||||
|
base
|
||||||
|
+ _OPERATOR_TRUST_BONUS
|
||||||
|
+ pricing_band(p_cost, c_cost)
|
||||||
|
+ context_signal(ctx)
|
||||||
|
+ slug_penalty(str(model_name))
|
||||||
|
)
|
||||||
|
return max(0, min(100, int(score)))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Health aggregation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_pct(value) -> float | None:
|
||||||
|
try:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
f = float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
if f < 0:
|
||||||
|
return None
|
||||||
|
# OpenRouter reports uptime as a 0-1 fraction; some endpoints surface it
|
||||||
|
# as a 0-100 percentage. Normalise.
|
||||||
|
return f * 100.0 if f <= 1.0 else f
|
||||||
|
|
||||||
|
|
||||||
|
def _best_uptime(endpoints: list[dict]) -> tuple[float | None, str | None]:
|
||||||
|
"""Pick the best (highest) non-null uptime across all endpoints.
|
||||||
|
|
||||||
|
Window preference: ``uptime_last_30m`` > ``uptime_last_1d`` >
|
||||||
|
``uptime_last_5m``. Returns ``(uptime_pct, window_used)``.
|
||||||
|
"""
|
||||||
|
for window in ("uptime_last_30m", "uptime_last_1d", "uptime_last_5m"):
|
||||||
|
values = [_coerce_pct(ep.get(window)) for ep in endpoints]
|
||||||
|
values = [v for v in values if v is not None]
|
||||||
|
if values:
|
||||||
|
return max(values), window
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_health(endpoints: list[dict]) -> tuple[bool, float | None]:
|
||||||
|
"""Aggregate a model's per-endpoint health into ``(gated, score_or_none)``.
|
||||||
|
|
||||||
|
Hard gate (returns ``(True, None)``):
|
||||||
|
* ``endpoints`` empty,
|
||||||
|
* no endpoint reports ``status == 0`` (OK), or
|
||||||
|
* best non-null uptime below ``_HEALTH_GATE_UPTIME_PCT``.
|
||||||
|
|
||||||
|
On a pass, returns a 0-100 health score blending uptime, status, and a
|
||||||
|
freshness-weighted recent uptime sample.
|
||||||
|
"""
|
||||||
|
if not endpoints:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
any_ok = any(int(ep.get("status", 1)) == 0 for ep in endpoints)
|
||||||
|
if not any_ok:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
best_uptime, _ = _best_uptime(endpoints)
|
||||||
|
if best_uptime is None or best_uptime < _HEALTH_GATE_UPTIME_PCT:
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
# Freshness term: prefer 5m, fall through to 30m / 1d if 5m is missing.
|
||||||
|
freshness = None
|
||||||
|
for window in ("uptime_last_5m", "uptime_last_30m", "uptime_last_1d"):
|
||||||
|
values = [_coerce_pct(ep.get(window)) for ep in endpoints]
|
||||||
|
values = [v for v in values if v is not None]
|
||||||
|
if values:
|
||||||
|
freshness = max(values)
|
||||||
|
break
|
||||||
|
|
||||||
|
uptime_term = best_uptime
|
||||||
|
status_term = 100.0 if any_ok else 0.0
|
||||||
|
freshness_term = freshness if freshness is not None else best_uptime
|
||||||
|
|
||||||
|
score = 0.50 * uptime_term + 0.30 * status_term + 0.20 * freshness_term
|
||||||
|
return False, max(0.0, min(100.0, score))
|
||||||
|
|
@ -8,7 +8,9 @@ Operation outcomes mirror the plan:
|
||||||
|
|
||||||
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
|
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
|
||||||
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
|
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
|
||||||
written before the original mutation.
|
written before the original mutation. ``rm``/``rmdir`` re-INSERT a fresh
|
||||||
|
row from the snapshot; ``write_file`` create / ``mkdir`` DELETE the row
|
||||||
|
that was created; everything else is an in-place restore.
|
||||||
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
|
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
|
||||||
the inverse tool through the agent's normal permission stack (NOT
|
the inverse tool through the agent's normal permission stack (NOT
|
||||||
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
|
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
|
||||||
|
|
@ -18,6 +20,11 @@ Operation outcomes mirror the plan:
|
||||||
A successful revert appends a NEW row to ``agent_action_log`` with
|
A successful revert appends a NEW row to ``agent_action_log`` with
|
||||||
``reverse_of=<original_action_id>`` and the requesting user's
|
``reverse_of=<original_action_id>`` and the requesting user's
|
||||||
``user_id``, preserving an auditable chain.
|
``user_id``, preserving an auditable chain.
|
||||||
|
|
||||||
|
Dispatch must be exact-match (``tool_name == name``), NOT prefix matching.
|
||||||
|
``"rmdir".startswith("rm")`` would otherwise mis-route directory revert
|
||||||
|
to the document branch (and ``delete_note`` vs ``delete_folder`` is the
|
||||||
|
same trap waiting to happen).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -25,17 +32,31 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.agents.new_chat.path_resolver import (
|
||||||
|
DOCUMENTS_ROOT,
|
||||||
|
safe_filename,
|
||||||
|
safe_folder_segment,
|
||||||
|
)
|
||||||
from app.db import (
|
from app.db import (
|
||||||
AgentActionLog,
|
AgentActionLog,
|
||||||
|
Chunk,
|
||||||
|
Document,
|
||||||
DocumentRevision,
|
DocumentRevision,
|
||||||
|
DocumentType,
|
||||||
|
Folder,
|
||||||
FolderRevision,
|
FolderRevision,
|
||||||
NewChatThread,
|
NewChatThread,
|
||||||
)
|
)
|
||||||
|
from app.utils.document_converters import (
|
||||||
|
embed_texts,
|
||||||
|
generate_content_hash,
|
||||||
|
generate_unique_identifier_hash,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -110,14 +131,244 @@ def can_revert(
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Revert paths
|
# Helper: reconstruct virtual path from a snapshot
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _virtual_path_from_snapshot(
|
||||||
|
session: AsyncSession,
|
||||||
|
revision: DocumentRevision,
|
||||||
|
) -> str | None:
|
||||||
|
"""Reconstruct the virtual_path the document was at before mutation.
|
||||||
|
|
||||||
|
Preference order:
|
||||||
|
1. ``metadata_before["virtual_path"]`` — written by every snapshot
|
||||||
|
helper since this PR.
|
||||||
|
2. Compose ``"<folder_path>/<title_before>"`` from
|
||||||
|
``folder_id_before`` + ``title_before``. Walks the folder chain via
|
||||||
|
``parent_id``.
|
||||||
|
"""
|
||||||
|
metadata = revision.metadata_before or {}
|
||||||
|
candidate = metadata.get("virtual_path") if isinstance(metadata, dict) else None
|
||||||
|
if isinstance(candidate, str) and candidate.startswith(DOCUMENTS_ROOT):
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
title = revision.title_before
|
||||||
|
if not isinstance(title, str) or not title:
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts: list[str] = []
|
||||||
|
cursor: int | None = revision.folder_id_before
|
||||||
|
visited: set[int] = set()
|
||||||
|
while cursor is not None and cursor not in visited:
|
||||||
|
visited.add(cursor)
|
||||||
|
folder = await session.get(Folder, cursor)
|
||||||
|
if folder is None:
|
||||||
|
return None
|
||||||
|
parts.append(safe_folder_segment(str(folder.name or "")))
|
||||||
|
cursor = folder.parent_id
|
||||||
|
parts.reverse()
|
||||||
|
|
||||||
|
base = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT
|
||||||
|
filename = safe_filename(title)
|
||||||
|
return f"{base}/{filename}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Document revision restore (write/edit/move/rm)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _set_field(target: Any, field: str, value: Any) -> None:
|
||||||
|
if value is not None:
|
||||||
|
setattr(target, field, value)
|
||||||
|
|
||||||
|
|
||||||
|
async def _restore_in_place_document(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
revision: DocumentRevision,
|
||||||
|
) -> RevertOutcome:
|
||||||
|
"""Apply an in-place restore to an existing :class:`Document`."""
|
||||||
|
if revision.document_id is None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message=(
|
||||||
|
"Original document was hard-deleted; in-place restore is not possible."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
doc = await session.get(Document, revision.document_id)
|
||||||
|
if doc is None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message="Original document has been deleted; revert cannot proceed.",
|
||||||
|
)
|
||||||
|
|
||||||
|
_set_field(doc, "content", revision.content_before)
|
||||||
|
_set_field(doc, "source_markdown", revision.content_before)
|
||||||
|
_set_field(doc, "title", revision.title_before)
|
||||||
|
_set_field(doc, "folder_id", revision.folder_id_before)
|
||||||
|
metadata_before = revision.metadata_before or {}
|
||||||
|
if isinstance(metadata_before, dict) and metadata_before:
|
||||||
|
doc.document_metadata = dict(metadata_before)
|
||||||
|
|
||||||
|
if isinstance(revision.content_before, str):
|
||||||
|
doc.content_hash = generate_content_hash(
|
||||||
|
revision.content_before, doc.search_space_id
|
||||||
|
)
|
||||||
|
|
||||||
|
virtual_path = await _virtual_path_from_snapshot(session, revision)
|
||||||
|
if virtual_path:
|
||||||
|
doc.unique_identifier_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
virtual_path,
|
||||||
|
doc.search_space_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks_before = revision.chunks_before
|
||||||
|
if isinstance(chunks_before, list):
|
||||||
|
await session.execute(delete(Chunk).where(Chunk.document_id == doc.id))
|
||||||
|
chunk_texts = [
|
||||||
|
str(c.get("content"))
|
||||||
|
for c in chunks_before
|
||||||
|
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
||||||
|
]
|
||||||
|
if chunk_texts:
|
||||||
|
chunk_embeddings = embed_texts(chunk_texts)
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||||
|
for text, embedding in zip(
|
||||||
|
chunk_texts, chunk_embeddings, strict=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if isinstance(revision.content_before, str):
|
||||||
|
doc.embedding = embed_texts([revision.content_before])[0]
|
||||||
|
|
||||||
|
doc.updated_at = datetime.now(UTC)
|
||||||
|
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
||||||
|
|
||||||
|
|
||||||
|
async def _reinsert_document_from_revision(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
revision: DocumentRevision,
|
||||||
|
) -> RevertOutcome:
|
||||||
|
"""Re-INSERT a deleted :class:`Document` from a snapshot row (``rm`` revert)."""
|
||||||
|
if not isinstance(revision.title_before, str) or not revision.title_before:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="not_reversible",
|
||||||
|
message="Snapshot lacks title_before; cannot recreate document.",
|
||||||
|
)
|
||||||
|
if not isinstance(revision.content_before, str):
|
||||||
|
return RevertOutcome(
|
||||||
|
status="not_reversible",
|
||||||
|
message="Snapshot lacks content_before; cannot recreate document.",
|
||||||
|
)
|
||||||
|
|
||||||
|
virtual_path = await _virtual_path_from_snapshot(session, revision)
|
||||||
|
if not virtual_path:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="not_reversible",
|
||||||
|
message=(
|
||||||
|
"Snapshot is missing both metadata_before['virtual_path'] AND "
|
||||||
|
"a resolvable (folder_id_before, title_before) pair."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
search_space_id = revision.search_space_id
|
||||||
|
unique_identifier_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
virtual_path,
|
||||||
|
search_space_id,
|
||||||
|
)
|
||||||
|
collision = await session.execute(
|
||||||
|
select(Document.id).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
Document.unique_identifier_hash == unique_identifier_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if collision.scalar_one_or_none() is not None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message=(
|
||||||
|
f"A document already exists at '{virtual_path}'; revert would "
|
||||||
|
"collide. Move the live doc out of the way first."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = revision.metadata_before or {}
|
||||||
|
if not isinstance(metadata, dict):
|
||||||
|
metadata = {}
|
||||||
|
metadata = dict(metadata)
|
||||||
|
metadata["virtual_path"] = virtual_path
|
||||||
|
|
||||||
|
content = revision.content_before
|
||||||
|
new_doc = Document(
|
||||||
|
title=revision.title_before,
|
||||||
|
document_type=DocumentType.NOTE,
|
||||||
|
document_metadata=metadata,
|
||||||
|
content=content,
|
||||||
|
content_hash=generate_content_hash(content, search_space_id),
|
||||||
|
unique_identifier_hash=unique_identifier_hash,
|
||||||
|
source_markdown=content,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
folder_id=revision.folder_id_before,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(new_doc)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
new_doc.embedding = embed_texts([content])[0]
|
||||||
|
chunk_texts = []
|
||||||
|
chunks_before = revision.chunks_before
|
||||||
|
if isinstance(chunks_before, list):
|
||||||
|
chunk_texts = [
|
||||||
|
str(c.get("content"))
|
||||||
|
for c in chunks_before
|
||||||
|
if isinstance(c, dict) and isinstance(c.get("content"), str)
|
||||||
|
]
|
||||||
|
if chunk_texts:
|
||||||
|
chunk_embeddings = embed_texts(chunk_texts)
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
Chunk(document_id=new_doc.id, content=text, embedding=embedding)
|
||||||
|
for text, embedding in zip(chunk_texts, chunk_embeddings, strict=True)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Repoint the snapshot at the recreated row so a follow-up revert of
|
||||||
|
# the same row works as expected.
|
||||||
|
revision.document_id = new_doc.id
|
||||||
|
return RevertOutcome(
|
||||||
|
status="ok",
|
||||||
|
message=f"Re-inserted document '{revision.title_before}' from snapshot.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _delete_created_document(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
revision: DocumentRevision,
|
||||||
|
) -> RevertOutcome:
|
||||||
|
"""Delete the document that ``write_file`` created (``content_before IS NULL``)."""
|
||||||
|
if revision.document_id is None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="ok",
|
||||||
|
message="No live row to delete (already removed elsewhere).",
|
||||||
|
)
|
||||||
|
await session.execute(delete(Document).where(Document.id == revision.document_id))
|
||||||
|
return RevertOutcome(
|
||||||
|
status="ok",
|
||||||
|
message="Deleted the document that was created by this action.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _restore_document_revision(
|
async def _restore_document_revision(
|
||||||
session: AsyncSession, *, action: AgentActionLog
|
session: AsyncSession, *, action: AgentActionLog
|
||||||
) -> RevertOutcome:
|
) -> RevertOutcome:
|
||||||
"""Restore the most recent :class:`DocumentRevision` for ``action``."""
|
"""Dispatch document-level revert based on ``action.tool_name``."""
|
||||||
stmt = (
|
stmt = (
|
||||||
select(DocumentRevision)
|
select(DocumentRevision)
|
||||||
.where(DocumentRevision.agent_action_id == action.id)
|
.where(DocumentRevision.agent_action_id == action.id)
|
||||||
|
|
@ -132,23 +383,111 @@ async def _restore_document_revision(
|
||||||
message="No document_revisions row tied to this action.",
|
message="No document_revisions row tied to this action.",
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.db import Document # late import to avoid cycles at module load
|
tool_name = (action.tool_name or "").lower()
|
||||||
|
|
||||||
doc = await session.get(Document, revision.document_id)
|
if tool_name == "rm":
|
||||||
if doc is None:
|
return await _reinsert_document_from_revision(session, revision=revision)
|
||||||
|
|
||||||
|
if tool_name == "write_file" and revision.content_before is None:
|
||||||
|
return await _delete_created_document(session, revision=revision)
|
||||||
|
|
||||||
|
return await _restore_in_place_document(session, revision=revision)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Folder revision restore (mkdir/rmdir/rename/move)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _restore_in_place_folder(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
revision: FolderRevision,
|
||||||
|
) -> RevertOutcome:
|
||||||
|
if revision.folder_id is None:
|
||||||
return RevertOutcome(
|
return RevertOutcome(
|
||||||
status="tool_unavailable",
|
status="tool_unavailable",
|
||||||
message="Original document has been deleted; revert cannot proceed.",
|
message="Original folder was hard-deleted; in-place restore is impossible.",
|
||||||
|
)
|
||||||
|
folder = await session.get(Folder, revision.folder_id)
|
||||||
|
if folder is None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message="Original folder has been deleted; revert cannot proceed.",
|
||||||
|
)
|
||||||
|
_set_field(folder, "name", revision.name_before)
|
||||||
|
_set_field(folder, "parent_id", revision.parent_id_before)
|
||||||
|
_set_field(folder, "position", revision.position_before)
|
||||||
|
folder.updated_at = datetime.now(UTC)
|
||||||
|
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
||||||
|
|
||||||
|
|
||||||
|
async def _reinsert_folder_from_revision(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
revision: FolderRevision,
|
||||||
|
) -> RevertOutcome:
|
||||||
|
if not isinstance(revision.name_before, str) or not revision.name_before:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="not_reversible",
|
||||||
|
message="Snapshot lacks name_before; cannot recreate folder.",
|
||||||
|
)
|
||||||
|
new_folder = Folder(
|
||||||
|
name=revision.name_before,
|
||||||
|
parent_id=revision.parent_id_before,
|
||||||
|
position=revision.position_before,
|
||||||
|
search_space_id=revision.search_space_id,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(new_folder)
|
||||||
|
await session.flush()
|
||||||
|
revision.folder_id = new_folder.id
|
||||||
|
return RevertOutcome(
|
||||||
|
status="ok",
|
||||||
|
message=f"Re-inserted folder '{revision.name_before}' from snapshot.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if revision.content_before is not None:
|
|
||||||
doc.content = revision.content_before
|
async def _delete_created_folder(
|
||||||
if revision.title_before is not None:
|
session: AsyncSession,
|
||||||
doc.title = revision.title_before
|
*,
|
||||||
if revision.folder_id_before is not None:
|
revision: FolderRevision,
|
||||||
doc.folder_id = revision.folder_id_before
|
) -> RevertOutcome:
|
||||||
doc.updated_at = datetime.now(UTC)
|
if revision.folder_id is None:
|
||||||
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
return RevertOutcome(
|
||||||
|
status="ok",
|
||||||
|
message="No live folder row to delete (already removed elsewhere).",
|
||||||
|
)
|
||||||
|
folder_id = revision.folder_id
|
||||||
|
|
||||||
|
has_doc = await session.execute(
|
||||||
|
select(Document.id).where(Document.folder_id == folder_id).limit(1)
|
||||||
|
)
|
||||||
|
if has_doc.scalar_one_or_none() is not None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message=(
|
||||||
|
"Folder is no longer empty (documents have been added since "
|
||||||
|
"mkdir); cannot revert."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
has_child = await session.execute(
|
||||||
|
select(Folder.id).where(Folder.parent_id == folder_id).limit(1)
|
||||||
|
)
|
||||||
|
if has_child.scalar_one_or_none() is not None:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="tool_unavailable",
|
||||||
|
message=(
|
||||||
|
"Folder is no longer empty (sub-folders have been added "
|
||||||
|
"since mkdir); cannot revert."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.execute(delete(Folder).where(Folder.id == folder_id))
|
||||||
|
return RevertOutcome(
|
||||||
|
status="ok",
|
||||||
|
message="Deleted the folder that was created by this action.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _restore_folder_revision(
|
async def _restore_folder_revision(
|
||||||
|
|
@ -168,41 +507,44 @@ async def _restore_folder_revision(
|
||||||
message="No folder_revisions row tied to this action.",
|
message="No folder_revisions row tied to this action.",
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.db import Folder
|
tool_name = (action.tool_name or "").lower()
|
||||||
|
|
||||||
folder = await session.get(Folder, revision.folder_id)
|
if tool_name == "rmdir":
|
||||||
if folder is None:
|
return await _reinsert_folder_from_revision(session, revision=revision)
|
||||||
return RevertOutcome(
|
|
||||||
status="tool_unavailable",
|
|
||||||
message="Original folder has been deleted; revert cannot proceed.",
|
|
||||||
)
|
|
||||||
|
|
||||||
if revision.name_before is not None:
|
if tool_name == "mkdir":
|
||||||
folder.name = revision.name_before
|
return await _delete_created_folder(session, revision=revision)
|
||||||
if revision.parent_id_before is not None:
|
|
||||||
folder.parent_id = revision.parent_id_before
|
return await _restore_in_place_folder(session, revision=revision)
|
||||||
if revision.position_before is not None:
|
|
||||||
folder.position = revision.position_before
|
|
||||||
folder.updated_at = datetime.now(UTC)
|
|
||||||
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
|
||||||
|
|
||||||
|
|
||||||
# Tool-name prefixes that route to KB document / folder revert paths. Kept
|
# ---------------------------------------------------------------------------
|
||||||
# as data so a future PR adding new KB-owned tools doesn't have to touch
|
# Dispatch
|
||||||
# this module's control flow.
|
# ---------------------------------------------------------------------------
|
||||||
_DOC_TOOL_PREFIXES: tuple[str, ...] = (
|
#
|
||||||
|
# Exact-name dispatch: ``tool_name == name``, NOT ``startswith(...)``.
|
||||||
|
# Prefix-matching mis-routes pairs like ``rm``/``rmdir`` and
|
||||||
|
# ``delete_note``/``delete_folder``.
|
||||||
|
|
||||||
|
_DOC_TOOLS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
"edit_file",
|
"edit_file",
|
||||||
"write_file",
|
"write_file",
|
||||||
|
"move_file",
|
||||||
|
"rm",
|
||||||
"update_memory",
|
"update_memory",
|
||||||
"create_note",
|
"create_note",
|
||||||
"update_note",
|
"update_note",
|
||||||
"delete_note",
|
"delete_note",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
_FOLDER_TOOL_PREFIXES: tuple[str, ...] = (
|
_FOLDER_TOOLS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
"mkdir",
|
"mkdir",
|
||||||
"move_file",
|
"rmdir",
|
||||||
"rename_folder",
|
"rename_folder",
|
||||||
"delete_folder",
|
"delete_folder",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -220,9 +562,9 @@ async def revert_action(
|
||||||
"""
|
"""
|
||||||
tool_name = (action.tool_name or "").lower()
|
tool_name = (action.tool_name or "").lower()
|
||||||
|
|
||||||
if tool_name.startswith(_DOC_TOOL_PREFIXES):
|
if tool_name in _DOC_TOOLS:
|
||||||
outcome = await _restore_document_revision(session, action=action)
|
outcome = await _restore_document_revision(session, action=action)
|
||||||
elif tool_name.startswith(_FOLDER_TOOL_PREFIXES):
|
elif tool_name in _FOLDER_TOOLS:
|
||||||
outcome = await _restore_folder_revision(session, action=action)
|
outcome = await _restore_folder_revision(session, action=action)
|
||||||
elif action.reverse_descriptor:
|
elif action.reverse_descriptor:
|
||||||
# Connector-owned reversibles run through the normal permission
|
# Connector-owned reversibles run through the normal permission
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -74,7 +74,7 @@ dependencies = [
|
||||||
"deepagents>=0.4.12",
|
"deepagents>=0.4.12",
|
||||||
"stripe>=15.0.0",
|
"stripe>=15.0.0",
|
||||||
"azure-ai-documentintelligence>=1.0.2",
|
"azure-ai-documentintelligence>=1.0.2",
|
||||||
"litellm>=1.83.4",
|
"litellm>=1.83.7",
|
||||||
"langchain-litellm>=0.6.4",
|
"langchain-litellm>=0.6.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -226,6 +226,31 @@ class TestCompose:
|
||||||
# Default block should NOT be present
|
# Default block should NOT be present
|
||||||
assert "<knowledge_base_only_policy>" not in prompt
|
assert "<knowledge_base_only_policy>" not in prompt
|
||||||
|
|
||||||
|
def test_provider_hints_render_with_custom_system_instructions(
|
||||||
|
self, fixed_today: datetime
|
||||||
|
) -> None:
|
||||||
|
"""Regression guard for the always-append decision: provider hints
|
||||||
|
append AFTER a custom system prompt.
|
||||||
|
|
||||||
|
Provider hints are stylistic nudges (parallel tool-call rules,
|
||||||
|
formatting guidance, etc.) that help the model regardless of
|
||||||
|
what the system instructions say. Suppressing them when a
|
||||||
|
custom prompt is set would partially defeat the per-family
|
||||||
|
prompt machinery.
|
||||||
|
"""
|
||||||
|
prompt = compose_system_prompt(
|
||||||
|
today=fixed_today,
|
||||||
|
custom_system_instructions="You are a custom assistant.",
|
||||||
|
model_name="anthropic/claude-3-5-sonnet",
|
||||||
|
)
|
||||||
|
assert "You are a custom assistant." in prompt
|
||||||
|
assert "<provider_hints>" in prompt
|
||||||
|
# The custom prompt must come BEFORE the provider hints so the
|
||||||
|
# user's framing isn't drowned out by the stylistic nudges.
|
||||||
|
assert prompt.index("You are a custom assistant.") < prompt.index(
|
||||||
|
"<provider_hints>"
|
||||||
|
)
|
||||||
|
|
||||||
def test_use_default_false_with_no_custom_yields_no_system_block(
|
def test_use_default_false_with_no_custom_yields_no_system_block(
|
||||||
self, fixed_today: datetime
|
self, fixed_today: datetime
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,17 @@ from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
|
||||||
from app.agents.new_chat.tools.registry import ToolDefinition
|
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeRuntime:
|
||||||
|
"""Minimal stand-in for ``ToolRuntime`` used in unit tests.
|
||||||
|
|
||||||
|
``ActionLogMiddleware`` reads ``runtime.config['configurable']['turn_id']``
|
||||||
|
to populate the new ``chat_turn_id`` column (see migration 135).
|
||||||
|
"""
|
||||||
|
|
||||||
|
config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _FakeRequest:
|
class _FakeRequest:
|
||||||
"""Minimal stand-in for ToolCallRequest used in unit tests."""
|
"""Minimal stand-in for ToolCallRequest used in unit tests."""
|
||||||
|
|
@ -120,6 +131,9 @@ class TestActionLogMiddlewarePersistence:
|
||||||
"args": {"color": "red", "size": 3},
|
"args": {"color": "red", "size": 3},
|
||||||
"id": "tc-abc",
|
"id": "tc-abc",
|
||||||
},
|
},
|
||||||
|
runtime=_FakeRuntime(
|
||||||
|
config={"configurable": {"turn_id": "42:1700000000000"}}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
|
result_msg = ToolMessage(content="ok", tool_call_id="tc-abc", id="msg-1")
|
||||||
handler = AsyncMock(return_value=result_msg)
|
handler = AsyncMock(return_value=result_msg)
|
||||||
|
|
@ -142,6 +156,32 @@ class TestActionLogMiddlewarePersistence:
|
||||||
assert row.error is None
|
assert row.error is None
|
||||||
assert row.reverse_descriptor is None
|
assert row.reverse_descriptor is None
|
||||||
assert row.reversible is False
|
assert row.reversible is False
|
||||||
|
# Migration 135: ``turn_id`` is the deprecated alias of ``tool_call_id``;
|
||||||
|
# ``chat_turn_id`` comes from ``runtime.config['configurable']['turn_id']``.
|
||||||
|
assert row.tool_call_id == "tc-abc"
|
||||||
|
assert row.turn_id == "tc-abc"
|
||||||
|
assert row.chat_turn_id == "42:1700000000000"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_turn_id_none_when_runtime_missing(
|
||||||
|
self, patch_get_flags, fake_session_factory
|
||||||
|
) -> None:
|
||||||
|
"""``chat_turn_id`` falls back to NULL when ``runtime.config`` is absent."""
|
||||||
|
captured, factory = fake_session_factory
|
||||||
|
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
|
||||||
|
request = _FakeRequest(
|
||||||
|
tool_call={"name": "make_widget", "args": {}, "id": "tc-1"},
|
||||||
|
runtime=None,
|
||||||
|
)
|
||||||
|
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc-1"))
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||||
|
):
|
||||||
|
await mw.awrap_tool_call(request, handler)
|
||||||
|
row = captured["rows"][0]
|
||||||
|
assert row.tool_call_id == "tc-1"
|
||||||
|
assert row.chat_turn_id is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_writes_row_on_failure_and_reraises(
|
async def test_writes_row_on_failure_and_reraises(
|
||||||
|
|
@ -293,6 +333,76 @@ class TestReverseDescriptor:
|
||||||
assert row.reversible is False
|
assert row.reversible is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestActionLogDispatch:
|
||||||
|
"""Verify ``adispatch_custom_event`` fires after commit."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatches_action_log_event_on_success(
|
||||||
|
self, patch_get_flags, fake_session_factory
|
||||||
|
) -> None:
|
||||||
|
_captured, factory = fake_session_factory
|
||||||
|
mw = ActionLogMiddleware(thread_id=42, search_space_id=7, user_id="u1")
|
||||||
|
request = _FakeRequest(
|
||||||
|
tool_call={
|
||||||
|
"name": "make_widget",
|
||||||
|
"args": {"color": "red"},
|
||||||
|
"id": "tc-evt",
|
||||||
|
},
|
||||||
|
runtime=_FakeRuntime(
|
||||||
|
config={"configurable": {"turn_id": "42:1700000000000"}}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
result_msg = ToolMessage(content="ok", tool_call_id="tc-evt", id="msg-42")
|
||||||
|
handler = AsyncMock(return_value=result_msg)
|
||||||
|
|
||||||
|
dispatch_mock = AsyncMock()
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch("app.db.shielded_async_session", side_effect=lambda: factory()),
|
||||||
|
patch(
|
||||||
|
"app.agents.new_chat.middleware.action_log.adispatch_custom_event",
|
||||||
|
dispatch_mock,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await mw.awrap_tool_call(request, handler)
|
||||||
|
|
||||||
|
dispatch_mock.assert_awaited_once()
|
||||||
|
call_args = dispatch_mock.await_args
|
||||||
|
assert call_args is not None
|
||||||
|
assert call_args.args[0] == "action_log"
|
||||||
|
payload = call_args.args[1]
|
||||||
|
assert payload["lc_tool_call_id"] == "tc-evt"
|
||||||
|
assert payload["chat_turn_id"] == "42:1700000000000"
|
||||||
|
assert payload["tool_name"] == "make_widget"
|
||||||
|
assert payload["reversible"] is False
|
||||||
|
assert payload["reverse_descriptor_present"] is False
|
||||||
|
assert payload["error"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_dispatch_when_persistence_fails(self, patch_get_flags) -> None:
|
||||||
|
"""If commit fails the dispatch is suppressed (no row to surface)."""
|
||||||
|
mw = ActionLogMiddleware(thread_id=1, search_space_id=1, user_id=None)
|
||||||
|
request = _FakeRequest(
|
||||||
|
tool_call={"name": "make_widget", "args": {}, "id": "tc1"}
|
||||||
|
)
|
||||||
|
handler = AsyncMock(return_value=ToolMessage(content="ok", tool_call_id="tc1"))
|
||||||
|
dispatch_mock = AsyncMock()
|
||||||
|
|
||||||
|
def _exploding_session():
|
||||||
|
raise RuntimeError("DB is down")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch("app.db.shielded_async_session", side_effect=_exploding_session),
|
||||||
|
patch(
|
||||||
|
"app.agents.new_chat.middleware.action_log.adispatch_custom_event",
|
||||||
|
dispatch_mock,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await mw.awrap_tool_call(request, handler)
|
||||||
|
dispatch_mock.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
class TestArgsTruncation:
|
class TestArgsTruncation:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_huge_args_payload_is_truncated(
|
async def test_huge_args_payload_is_truncated(
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,9 @@ import pytest
|
||||||
from app.agents.new_chat.errors import BusyError
|
from app.agents.new_chat.errors import BusyError
|
||||||
from app.agents.new_chat.middleware.busy_mutex import (
|
from app.agents.new_chat.middleware.busy_mutex import (
|
||||||
BusyMutexMiddleware,
|
BusyMutexMiddleware,
|
||||||
|
end_turn,
|
||||||
get_cancel_event,
|
get_cancel_event,
|
||||||
|
is_cancel_requested,
|
||||||
manager,
|
manager,
|
||||||
request_cancel,
|
request_cancel,
|
||||||
reset_cancel,
|
reset_cancel,
|
||||||
|
|
@ -88,3 +90,65 @@ async def test_no_thread_id_skipped_when_not_required() -> None:
|
||||||
def test_reset_cancel_idempotent() -> None:
|
def test_reset_cancel_idempotent() -> None:
|
||||||
# Should not raise even if event was never created
|
# Should not raise even if event was never created
|
||||||
reset_cancel("never-seen")
|
reset_cancel("never-seen")
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_cancel_creates_event_for_unseen_thread() -> None:
|
||||||
|
thread_id = "never-seen-cancel"
|
||||||
|
reset_cancel(thread_id)
|
||||||
|
|
||||||
|
assert request_cancel(thread_id) is True
|
||||||
|
assert get_cancel_event(thread_id).is_set()
|
||||||
|
assert is_cancel_requested(thread_id) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_turn_force_clears_lock_and_cancel_state() -> None:
|
||||||
|
thread_id = "forced-end-turn"
|
||||||
|
mw = BusyMutexMiddleware()
|
||||||
|
runtime = _Runtime(thread_id)
|
||||||
|
|
||||||
|
await mw.abefore_agent({}, runtime)
|
||||||
|
assert manager.lock_for(thread_id).locked()
|
||||||
|
|
||||||
|
request_cancel(thread_id)
|
||||||
|
assert is_cancel_requested(thread_id) is True
|
||||||
|
|
||||||
|
end_turn(thread_id)
|
||||||
|
|
||||||
|
assert not manager.lock_for(thread_id).locked()
|
||||||
|
assert not get_cancel_event(thread_id).is_set()
|
||||||
|
assert is_cancel_requested(thread_id) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_busy_mutex_stale_aafter_does_not_release_new_attempt_lock() -> None:
|
||||||
|
"""A stale aafter call from attempt A must not unlock attempt B.
|
||||||
|
|
||||||
|
Repro flow:
|
||||||
|
1) attempt A acquires thread lock
|
||||||
|
2) forced end_turn clears A so retry can proceed
|
||||||
|
3) attempt B acquires same thread lock
|
||||||
|
4) stale attempt-A aafter runs late
|
||||||
|
|
||||||
|
Expected: B lock remains held.
|
||||||
|
"""
|
||||||
|
thread_id = "stale-aafter-lock"
|
||||||
|
runtime = _Runtime(thread_id)
|
||||||
|
attempt_a = BusyMutexMiddleware()
|
||||||
|
attempt_b = BusyMutexMiddleware()
|
||||||
|
|
||||||
|
await attempt_a.abefore_agent({}, runtime)
|
||||||
|
lock = manager.lock_for(thread_id)
|
||||||
|
assert lock.locked()
|
||||||
|
|
||||||
|
end_turn(thread_id)
|
||||||
|
assert not lock.locked()
|
||||||
|
|
||||||
|
await attempt_b.abefore_agent({}, runtime)
|
||||||
|
assert lock.locked()
|
||||||
|
|
||||||
|
# Stale cleanup from attempt A must not release attempt B's lock.
|
||||||
|
await attempt_a.aafter_agent({}, runtime)
|
||||||
|
assert lock.locked()
|
||||||
|
|
||||||
|
await attempt_b.aafter_agent({}, runtime)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,122 @@
|
||||||
|
"""Tests for the desktop-mode safety ruleset.
|
||||||
|
|
||||||
|
In desktop mode the agent operates against the user's real disk with no
|
||||||
|
revision history, so destructive filesystem operations must require
|
||||||
|
explicit approval. These tests pin the set of tools that get the ``ask``
|
||||||
|
gate so it cannot silently regress.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||||
|
from app.agents.new_chat.permissions import (
|
||||||
|
Rule,
|
||||||
|
Ruleset,
|
||||||
|
aggregate_action,
|
||||||
|
evaluate_many,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# Mirror the ruleset built inside ``chat_deepagent._build_compiled_agent_blocking``
|
||||||
|
# when ``filesystem_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER``. Keeping a
|
||||||
|
# copy here means the rule contract has a focused regression test even when
|
||||||
|
# the larger graph-build helper is hard to instantiate in unit tests.
|
||||||
|
DESKTOP_SAFETY_RULESET = Ruleset(
|
||||||
|
rules=[
|
||||||
|
Rule(permission="rm", pattern="*", action="ask"),
|
||||||
|
Rule(permission="rmdir", pattern="*", action="ask"),
|
||||||
|
Rule(permission="move_file", pattern="*", action="ask"),
|
||||||
|
Rule(permission="edit_file", pattern="*", action="ask"),
|
||||||
|
Rule(permission="write_file", pattern="*", action="ask"),
|
||||||
|
],
|
||||||
|
origin="desktop_safety",
|
||||||
|
)
|
||||||
|
|
||||||
|
SURFSENSE_DEFAULTS = Ruleset(
|
||||||
|
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||||
|
origin="surfsense_defaults",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _action_for(tool_name: str, *rulesets: Ruleset) -> str:
|
||||||
|
rules = evaluate_many(tool_name, [tool_name], *rulesets)
|
||||||
|
return aggregate_action(rules)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDesktopSafetyRulesGateDestructiveOps:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"tool_name",
|
||||||
|
["rm", "rmdir", "move_file", "edit_file", "write_file"],
|
||||||
|
)
|
||||||
|
def test_destructive_op_resolves_to_ask(self, tool_name: str) -> None:
|
||||||
|
# surfsense_defaults says "allow */*"; desktop_safety must override
|
||||||
|
# because it's layered later (last-match-wins).
|
||||||
|
action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
|
||||||
|
assert action == "ask", (
|
||||||
|
f"{tool_name} must require approval in desktop mode "
|
||||||
|
f"(no revert path on real disk); got {action!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"tool_name",
|
||||||
|
["read_file", "ls", "list_tree", "grep", "glob", "cd", "pwd", "mkdir"],
|
||||||
|
)
|
||||||
|
def test_safe_ops_remain_allowed(self, tool_name: str) -> None:
|
||||||
|
# Read-only and trivially-reversible tools must NOT get gated —
|
||||||
|
# otherwise every navigation in desktop mode pops an interrupt.
|
||||||
|
action = _action_for(tool_name, SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
|
||||||
|
assert action == "allow", (
|
||||||
|
f"{tool_name} should not be gated in desktop mode; got {action!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDesktopSafetyOverridesAllowDefault:
|
||||||
|
def test_layer_order_last_match_wins(self) -> None:
|
||||||
|
# If desktop_safety is layered BEFORE surfsense_defaults, the allow
|
||||||
|
# default would win and the safety net would be inert. This test
|
||||||
|
# protects against accidentally swapping the rulesets in
|
||||||
|
# ``_build_compiled_agent_blocking``.
|
||||||
|
action = _action_for("rm", DESKTOP_SAFETY_RULESET, SURFSENSE_DEFAULTS)
|
||||||
|
# Layered "wrong way" — the broad allow now wins.
|
||||||
|
assert action == "allow"
|
||||||
|
|
||||||
|
# Correct order: defaults < desktop_safety -> ask wins.
|
||||||
|
action = _action_for("rm", SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET)
|
||||||
|
assert action == "ask"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionMiddlewareIntegration:
|
||||||
|
def test_middleware_raises_interrupt_for_rm_in_desktop_mode(self) -> None:
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from app.agents.new_chat.errors import RejectedError
|
||||||
|
|
||||||
|
mw = PermissionMiddleware(rulesets=[SURFSENSE_DEFAULTS, DESKTOP_SAFETY_RULESET])
|
||||||
|
# Stub the interrupt to a "reject" decision so we can assert the
|
||||||
|
# ask path was taken without spinning up the LangGraph runtime.
|
||||||
|
mw._raise_interrupt = lambda **kw: {"decision_type": "reject"} # type: ignore[assignment]
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"name": "rm",
|
||||||
|
"args": {"path": "/Users/me/Documents/important.docx"},
|
||||||
|
"id": "tc-rm",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
class _FakeRuntime:
|
||||||
|
config: dict = {"configurable": {"thread_id": "test"}}
|
||||||
|
|
||||||
|
with pytest.raises(RejectedError):
|
||||||
|
mw.after_model(state, _FakeRuntime())
|
||||||
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""Tests for the default auto-approval list in ``hitl.request_approval``.
|
||||||
|
|
||||||
|
These pin the policy that low-stakes connector creation tools (drafts,
|
||||||
|
new-file creates) skip the HITL interrupt by default. Without this set,
|
||||||
|
every "draft my newsletter" turn used to fire ~3 interrupts before any
|
||||||
|
useful work happened.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.tools.hitl import (
|
||||||
|
DEFAULT_AUTO_APPROVED_TOOLS,
|
||||||
|
HITLResult,
|
||||||
|
request_approval,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultAutoApprovedToolsList:
|
||||||
|
def test_set_contains_expected_creation_tools(self) -> None:
|
||||||
|
# If anyone changes the policy list, we want a single test to
|
||||||
|
# update so the contract is explicit. Keep this in sync with
|
||||||
|
# ``hitl.DEFAULT_AUTO_APPROVED_TOOLS``.
|
||||||
|
expected = {
|
||||||
|
"create_gmail_draft",
|
||||||
|
"update_gmail_draft",
|
||||||
|
"create_notion_page",
|
||||||
|
"create_confluence_page",
|
||||||
|
"create_google_drive_file",
|
||||||
|
"create_dropbox_file",
|
||||||
|
"create_onedrive_file",
|
||||||
|
}
|
||||||
|
assert expected == DEFAULT_AUTO_APPROVED_TOOLS
|
||||||
|
|
||||||
|
def test_set_is_immutable(self) -> None:
|
||||||
|
# frozenset prevents accidental at-runtime mutation that would
|
||||||
|
# silently widen the auto-approval surface.
|
||||||
|
assert isinstance(DEFAULT_AUTO_APPROVED_TOOLS, frozenset)
|
||||||
|
|
||||||
|
def test_send_tools_are_not_auto_approved(self) -> None:
|
||||||
|
# External-broadcast tools must always prompt.
|
||||||
|
for tool_name in (
|
||||||
|
"send_gmail_email",
|
||||||
|
"send_discord_message",
|
||||||
|
"send_teams_message",
|
||||||
|
"delete_notion_page",
|
||||||
|
"create_calendar_event",
|
||||||
|
"delete_calendar_event",
|
||||||
|
):
|
||||||
|
assert tool_name not in DEFAULT_AUTO_APPROVED_TOOLS, (
|
||||||
|
f"{tool_name} must remain HITL-gated"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestApprovalAutoBypass:
|
||||||
|
def test_auto_approved_tool_skips_interrupt(self) -> None:
|
||||||
|
# No interrupt mock set up — if the function attempted to call
|
||||||
|
# ``langgraph.types.interrupt`` it would raise GraphInterrupt.
|
||||||
|
# The fact that we get a clean HITLResult proves the bypass.
|
||||||
|
result = request_approval(
|
||||||
|
action_type="gmail_draft_creation",
|
||||||
|
tool_name="create_gmail_draft",
|
||||||
|
params={"to": "alice@example.com", "subject": "hi", "body": "hey"},
|
||||||
|
)
|
||||||
|
assert isinstance(result, HITLResult)
|
||||||
|
assert result.rejected is False
|
||||||
|
assert result.decision_type == "auto_approved"
|
||||||
|
# Original params are preserved untouched (no user edits possible).
|
||||||
|
assert result.params == {
|
||||||
|
"to": "alice@example.com",
|
||||||
|
"subject": "hi",
|
||||||
|
"body": "hey",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_non_listed_tool_still_attempts_interrupt(self) -> None:
|
||||||
|
# A tool NOT in the default list must reach ``langgraph.interrupt``.
|
||||||
|
# Outside a runnable context that call raises a RuntimeError —
|
||||||
|
# which is exactly the signal we want: the bypass did NOT fire.
|
||||||
|
with pytest.raises(RuntimeError, match="runnable context"):
|
||||||
|
request_approval(
|
||||||
|
action_type="gmail_email_send",
|
||||||
|
tool_name="send_gmail_email",
|
||||||
|
params={"to": "alice@example.com", "subject": "hi", "body": "hey"},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_user_trusted_tools_still_take_precedence(self) -> None:
|
||||||
|
# ``trusted_tools`` (per-connector "always allow" from MCP/UI)
|
||||||
|
# was checked BEFORE the default list and must keep working
|
||||||
|
# for tools outside the default list.
|
||||||
|
result = request_approval(
|
||||||
|
action_type="mcp_tool_call",
|
||||||
|
tool_name="my_custom_mcp_tool",
|
||||||
|
params={"x": 1},
|
||||||
|
trusted_tools=["my_custom_mcp_tool"],
|
||||||
|
)
|
||||||
|
assert result.decision_type == "trusted"
|
||||||
|
assert result.rejected is False
|
||||||
|
|
||||||
|
def test_auto_approved_overrides_no_trusted_tools(self) -> None:
|
||||||
|
# When trusted_tools is empty and tool is in the default list,
|
||||||
|
# we should still bypass — proves the order in request_approval.
|
||||||
|
result = request_approval(
|
||||||
|
action_type="notion_page_creation",
|
||||||
|
tool_name="create_notion_page",
|
||||||
|
params={"title": "Plan"},
|
||||||
|
trusted_tools=[],
|
||||||
|
)
|
||||||
|
assert result.decision_type == "auto_approved"
|
||||||
|
|
@ -0,0 +1,350 @@
|
||||||
|
"""Tests for ``apply_litellm_prompt_caching`` in
|
||||||
|
:mod:`app.agents.new_chat.prompt_caching`.
|
||||||
|
|
||||||
|
The helper replaces the legacy ``AnthropicPromptCachingMiddleware`` (which
|
||||||
|
never activated for our LiteLLM stack) with LiteLLM-native multi-provider
|
||||||
|
prompt caching. It mutates ``llm.model_kwargs`` so the kwargs flow to
|
||||||
|
``litellm.completion(...)``. The tests below pin its public contract:
|
||||||
|
|
||||||
|
1. Always sets BOTH ``role: system`` and ``index: -1`` injection points so
|
||||||
|
savings compound across multi-turn conversations on Anthropic-family
|
||||||
|
providers.
|
||||||
|
2. Adds ``prompt_cache_key``/``prompt_cache_retention`` only for
|
||||||
|
single-model OPENAI/DEEPSEEK/XAI configs (where OpenAI's automatic
|
||||||
|
prompt-cache surface is available).
|
||||||
|
3. Treats ``ChatLiteLLMRouter`` (auto-mode) as universal-only — no
|
||||||
|
OpenAI-only kwargs because the router fans out across providers.
|
||||||
|
4. Idempotent: user-supplied values in ``model_kwargs`` are preserved.
|
||||||
|
5. Defensive: LLMs without a writable ``model_kwargs`` are silently
|
||||||
|
skipped rather than raising.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
|
from app.agents.new_chat.prompt_caching import apply_litellm_prompt_caching
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test doubles
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLLM:
|
||||||
|
"""Stand-in for ``ChatLiteLLM``/``SanitizedChatLiteLLM``.
|
||||||
|
|
||||||
|
The helper only inspects ``getattr(llm, "model_kwargs", None)``,
|
||||||
|
``getattr(llm, "model", None)``, and ``type(llm).__name__``. A simple
|
||||||
|
object suffices — we don't need to spin up real LangChain/LiteLLM
|
||||||
|
machinery for unit tests of the helper's logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "openai/gpt-4o",
|
||||||
|
model_kwargs: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.model = model
|
||||||
|
self.model_kwargs: dict[str, Any] = dict(model_kwargs) if model_kwargs else {}
|
||||||
|
|
||||||
|
|
||||||
|
class ChatLiteLLMRouter:
|
||||||
|
"""Class-name-only impostor of the real router.
|
||||||
|
|
||||||
|
The helper's router gate is ``type(llm).__name__ == "ChatLiteLLMRouter"``
|
||||||
|
(a deliberate stringly-typed check to avoid an import cycle with
|
||||||
|
``app.services.llm_router_service``). Reusing the same class name here
|
||||||
|
triggers the same code path without instantiating a real ``Router``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.model = "auto"
|
||||||
|
self.model_kwargs: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_cfg(**overrides: Any) -> AgentConfig:
|
||||||
|
"""Build an ``AgentConfig`` with sensible defaults for the helper test."""
|
||||||
|
defaults: dict[str, Any] = {
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-4o",
|
||||||
|
"api_key": "k",
|
||||||
|
}
|
||||||
|
return AgentConfig(**{**defaults, **overrides})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (a) Universal injection points
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_sets_both_cache_control_injection_points_with_no_config() -> None:
|
||||||
|
"""Bare call (no agent_config, no thread_id) still sets the two
|
||||||
|
universal breakpoints — these cost nothing on providers that don't
|
||||||
|
consume them and unlock caching on every supported provider."""
|
||||||
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm)
|
||||||
|
|
||||||
|
points = llm.model_kwargs["cache_control_injection_points"]
|
||||||
|
assert {"location": "message", "role": "system"} in points
|
||||||
|
assert {"location": "message", "index": -1} in points
|
||||||
|
assert len(points) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_injection_points_set_for_anthropic_config() -> None:
|
||||||
|
"""Anthropic-family configs need the marker — verify it lands."""
|
||||||
|
cfg = _make_cfg(provider="ANTHROPIC", model_name="claude-3-5-sonnet")
|
||||||
|
llm = _FakeLLM(model="anthropic/claude-3-5-sonnet")
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg)
|
||||||
|
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (b) Idempotency / user override wins
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_does_not_overwrite_user_supplied_cache_control_injection_points() -> None:
|
||||||
|
"""Users who set their own injection points (e.g. with ``ttl: "1h"``
|
||||||
|
via ``litellm_params``) keep them — the helper merges, never
|
||||||
|
clobbers."""
|
||||||
|
user_points = [
|
||||||
|
{"location": "message", "role": "system", "ttl": "1h"},
|
||||||
|
]
|
||||||
|
llm = _FakeLLM(
|
||||||
|
model_kwargs={"cache_control_injection_points": user_points},
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm)
|
||||||
|
|
||||||
|
assert llm.model_kwargs["cache_control_injection_points"] is user_points
|
||||||
|
|
||||||
|
|
||||||
|
def test_idempotent_when_called_multiple_times() -> None:
|
||||||
|
"""Build-time + thread-time double-call must be a no-op the second time."""
|
||||||
|
cfg = _make_cfg(provider="OPENAI")
|
||||||
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
|
||||||
|
snapshot = {
|
||||||
|
"cache_control_injection_points": list(
|
||||||
|
llm.model_kwargs["cache_control_injection_points"]
|
||||||
|
),
|
||||||
|
"prompt_cache_key": llm.model_kwargs["prompt_cache_key"],
|
||||||
|
"prompt_cache_retention": llm.model_kwargs["prompt_cache_retention"],
|
||||||
|
}
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=1)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
llm.model_kwargs["cache_control_injection_points"]
|
||||||
|
== snapshot["cache_control_injection_points"]
|
||||||
|
)
|
||||||
|
assert llm.model_kwargs["prompt_cache_key"] == snapshot["prompt_cache_key"]
|
||||||
|
assert (
|
||||||
|
llm.model_kwargs["prompt_cache_retention"] == snapshot["prompt_cache_retention"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_does_not_overwrite_user_supplied_prompt_cache_key() -> None:
|
||||||
|
"""A pre-set ``prompt_cache_key`` (e.g. tenant-aware override via
|
||||||
|
``litellm_params``) wins over our default per-thread key."""
|
||||||
|
cfg = _make_cfg(provider="OPENAI")
|
||||||
|
llm = _FakeLLM(model_kwargs={"prompt_cache_key": "tenant-abc"})
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
|
||||||
|
|
||||||
|
assert llm.model_kwargs["prompt_cache_key"] == "tenant-abc"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (c) OpenAI-family extras (OPENAI / DEEPSEEK / XAI)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("provider", ["OPENAI", "DEEPSEEK", "XAI"])
|
||||||
|
def test_sets_openai_family_extras(provider: str) -> None:
|
||||||
|
"""OpenAI-style providers gain ``prompt_cache_key`` (raises hit rate
|
||||||
|
via routing affinity) and ``prompt_cache_retention="24h"`` (extends
|
||||||
|
cache TTL beyond the default 5-10 min)."""
|
||||||
|
cfg = _make_cfg(provider=provider)
|
||||||
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
|
||||||
|
|
||||||
|
assert llm.model_kwargs["prompt_cache_key"] == "surfsense-thread-42"
|
||||||
|
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
||||||
|
|
||||||
|
|
||||||
|
def test_skips_prompt_cache_key_when_no_thread_id() -> None:
|
||||||
|
"""Without a thread id we can't construct a per-thread key. Retention
|
||||||
|
is still useful so we set it (it's free)."""
|
||||||
|
cfg = _make_cfg(provider="OPENAI")
|
||||||
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=None)
|
||||||
|
|
||||||
|
assert "prompt_cache_key" not in llm.model_kwargs
|
||||||
|
assert llm.model_kwargs["prompt_cache_retention"] == "24h"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"provider",
|
||||||
|
["ANTHROPIC", "BEDROCK", "VERTEX_AI", "GOOGLE_AI_STUDIO", "GROQ", "MOONSHOT"],
|
||||||
|
)
|
||||||
|
def test_no_openai_extras_for_other_providers(provider: str) -> None:
|
||||||
|
"""Non-OpenAI-family providers don't expose ``prompt_cache_key`` —
|
||||||
|
skip it. ``cache_control_injection_points`` is still set (universal)."""
|
||||||
|
cfg = _make_cfg(provider=provider)
|
||||||
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
|
||||||
|
|
||||||
|
assert "prompt_cache_key" not in llm.model_kwargs
|
||||||
|
assert "prompt_cache_retention" not in llm.model_kwargs
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_openai_extras_in_auto_mode() -> None:
|
||||||
|
"""Auto-mode fans out across mixed providers — we can't statically
|
||||||
|
target OpenAI-only kwargs."""
|
||||||
|
cfg = AgentConfig.from_auto_mode()
|
||||||
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
|
||||||
|
|
||||||
|
assert "prompt_cache_key" not in llm.model_kwargs
|
||||||
|
assert "prompt_cache_retention" not in llm.model_kwargs
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_openai_extras_for_custom_provider() -> None:
|
||||||
|
"""Custom providers route through arbitrary user-supplied prefixes —
|
||||||
|
we don't try to infer OpenAI-family compatibility."""
|
||||||
|
cfg = _make_cfg(provider="OPENAI", custom_provider="my_proxy")
|
||||||
|
llm = _FakeLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=cfg, thread_id=42)
|
||||||
|
|
||||||
|
assert "prompt_cache_key" not in llm.model_kwargs
|
||||||
|
assert "prompt_cache_retention" not in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (d) ChatLiteLLMRouter — universal injection points only
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_llm_gets_only_universal_injection_points() -> None:
|
||||||
|
"""Even with an OpenAI-flavoured config, a ``ChatLiteLLMRouter`` must
|
||||||
|
receive only the universal injection points — its requests dispatch
|
||||||
|
across provider deployments and OpenAI-only kwargs would be wasted
|
||||||
|
(or stripped by ``drop_params``) on non-OpenAI legs."""
|
||||||
|
router = ChatLiteLLMRouter()
|
||||||
|
cfg = _make_cfg(provider="OPENAI")
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(router, agent_config=cfg, thread_id=42)
|
||||||
|
|
||||||
|
assert "cache_control_injection_points" in router.model_kwargs
|
||||||
|
assert "prompt_cache_key" not in router.model_kwargs
|
||||||
|
assert "prompt_cache_retention" not in router.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (e) Defensive paths
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_llm_with_no_writable_model_kwargs() -> None:
|
||||||
|
"""Some LLM implementations (e.g. fakes / minimal subclasses) don't
|
||||||
|
expose a writable ``model_kwargs``. The helper must skip silently —
|
||||||
|
raising would crash the entire LLM build path on a non-critical
|
||||||
|
optimisation."""
|
||||||
|
|
||||||
|
class _ImmutableLLM:
|
||||||
|
# ``__slots__`` blocks attribute creation, so ``setattr`` raises.
|
||||||
|
__slots__ = ("model",)
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.model = "openai/gpt-4o"
|
||||||
|
|
||||||
|
llm = _ImmutableLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm)
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialises_missing_model_kwargs_dict() -> None:
|
||||||
|
"""When ``model_kwargs`` is present-but-None (Pydantic v2 default
|
||||||
|
pattern when no factory is set), the helper initialises it to an
|
||||||
|
empty dict before mutating."""
|
||||||
|
|
||||||
|
class _LazyLLM:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.model = "openai/gpt-4o"
|
||||||
|
self.model_kwargs: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
llm = _LazyLLM()
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm)
|
||||||
|
|
||||||
|
assert isinstance(llm.model_kwargs, dict)
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_falls_back_to_llm_model_prefix_when_no_agent_config() -> None:
|
||||||
|
"""Direct caller path (e.g. ``create_chat_litellm_from_config`` for
|
||||||
|
YAML configs without a structured ``AgentConfig``): without
|
||||||
|
``agent_config`` the helper sets only the universal injection points
|
||||||
|
— no OpenAI-family extras even if the prefix says ``openai/``.
|
||||||
|
Conservative: we'd rather miss the speedup than silently misroute."""
|
||||||
|
llm = _FakeLLM(model="openai/gpt-4o")
|
||||||
|
|
||||||
|
apply_litellm_prompt_caching(llm, agent_config=None, thread_id=99)
|
||||||
|
|
||||||
|
assert "cache_control_injection_points" in llm.model_kwargs
|
||||||
|
assert "prompt_cache_key" not in llm.model_kwargs
|
||||||
|
assert "prompt_cache_retention" not in llm.model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# (f) drop_params safety net (regression guard for #19346)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_litellm_drop_params_is_globally_enabled() -> None:
|
||||||
|
"""``litellm.drop_params=True`` is set globally in
|
||||||
|
:mod:`app.services.llm_service` so any ``prompt_cache_key`` /
|
||||||
|
``prompt_cache_retention`` we set on an OpenAI-family config is
|
||||||
|
auto-stripped if the request later routes to a non-supporting
|
||||||
|
provider (e.g. via auto-mode router fallback). This test pins that
|
||||||
|
invariant — losing it would mean Bedrock/Vertex 400s on ``prompt_cache_key``.
|
||||||
|
"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
import app.services.llm_service # noqa: F401 (side-effect: sets globals)
|
||||||
|
|
||||||
|
assert litellm.drop_params is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Regression note: LiteLLM #15696 (multi-content-block last message)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Before LiteLLM 1.81 a list-form last message ``[block_a, block_b]``
|
||||||
|
# would get ``cache_control`` applied to *every* content block instead
|
||||||
|
# of only the last one — wasting cache breakpoints and triggering 400s
|
||||||
|
# on Anthropic when it exceeded the 4-breakpoint limit. Fixed in
|
||||||
|
# https://github.com/BerriAI/litellm/pull/15699.
|
||||||
|
#
|
||||||
|
# We pin ``litellm>=1.83.7`` in ``pyproject.toml`` (well past the fix).
|
||||||
|
# An end-to-end behavioural test would need to run ``litellm.completion``
|
||||||
|
# through the Anthropic transformer, which is integration territory and
|
||||||
|
# better covered by LiteLLM's own test suite. The unit guard here is the
|
||||||
|
# version pin plus the build-time ``model_kwargs`` shape we verify above.
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""Tests for ``_resolve_prompt_model_name`` in :mod:`app.agents.new_chat.chat_deepagent`.
|
||||||
|
|
||||||
|
The helper picks the model id fed to ``detect_provider_variant`` so the
|
||||||
|
right ``<provider_hints>`` block lands in the system prompt. The tests
|
||||||
|
below pin its preference order:
|
||||||
|
|
||||||
|
1. ``agent_config.litellm_params["base_model"]`` (Azure-correct).
|
||||||
|
2. ``agent_config.model_name``.
|
||||||
|
3. ``getattr(llm, "model", None)``.
|
||||||
|
|
||||||
|
Without (1) an Azure deployment named e.g. ``"prod-chat-001"`` would
|
||||||
|
silently miss every provider regex.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.chat_deepagent import _resolve_prompt_model_name
|
||||||
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_cfg(**overrides) -> AgentConfig:
|
||||||
|
"""Build an ``AgentConfig`` with sensible defaults for the helper test."""
|
||||||
|
defaults = {
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "x",
|
||||||
|
"api_key": "k",
|
||||||
|
}
|
||||||
|
return AgentConfig(**{**defaults, **overrides})
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLLM:
|
||||||
|
"""Stand-in for a ``ChatLiteLLM`` / ``ChatLiteLLMRouter`` instance.
|
||||||
|
|
||||||
|
The resolver only reads the ``.model`` attribute via ``getattr``,
|
||||||
|
matching the established idiom in ``knowledge_search.py`` /
|
||||||
|
``stream_new_chat.py`` / ``document_summarizer.py``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: str | None) -> None:
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefers_litellm_params_base_model_over_deployment_name() -> None:
|
||||||
|
"""Azure deployment slug must NOT shadow the underlying model family.
|
||||||
|
|
||||||
|
This is the failure mode the helper exists to prevent: a deployment
|
||||||
|
named ``"azure/prod-chat-001"`` would not match any provider regex
|
||||||
|
on its own, but the family ``"gpt-4o"`` lives in
|
||||||
|
``litellm_params["base_model"]`` and routes to ``openai_classic``.
|
||||||
|
"""
|
||||||
|
cfg = _make_cfg(
|
||||||
|
model_name="azure/prod-chat-001",
|
||||||
|
litellm_params={"base_model": "gpt-4o"},
|
||||||
|
)
|
||||||
|
assert _resolve_prompt_model_name(cfg, _FakeLLM("azure/prod-chat-001")) == "gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_falls_back_to_model_name_when_litellm_params_is_none() -> None:
|
||||||
|
cfg = _make_cfg(
|
||||||
|
model_name="anthropic/claude-3-5-sonnet",
|
||||||
|
litellm_params=None,
|
||||||
|
)
|
||||||
|
got = _resolve_prompt_model_name(cfg, _FakeLLM("anthropic/claude-3-5-sonnet"))
|
||||||
|
assert got == "anthropic/claude-3-5-sonnet"
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_litellm_params_without_base_model_key() -> None:
|
||||||
|
cfg = _make_cfg(
|
||||||
|
model_name="openai/gpt-4o",
|
||||||
|
litellm_params={"temperature": 0.5},
|
||||||
|
)
|
||||||
|
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ignores_blank_base_model() -> None:
|
||||||
|
"""Whitespace-only ``base_model`` must not shadow ``model_name``."""
|
||||||
|
cfg = _make_cfg(
|
||||||
|
model_name="openai/gpt-4o",
|
||||||
|
litellm_params={"base_model": " "},
|
||||||
|
)
|
||||||
|
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_ignores_non_string_base_model() -> None:
|
||||||
|
"""Defensive: a non-string ``base_model`` should not crash the resolver."""
|
||||||
|
cfg = _make_cfg(
|
||||||
|
model_name="openai/gpt-4o",
|
||||||
|
litellm_params={"base_model": 42},
|
||||||
|
)
|
||||||
|
assert _resolve_prompt_model_name(cfg, _FakeLLM("openai/gpt-4o")) == "openai/gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_falls_back_to_llm_model_when_no_agent_config() -> None:
|
||||||
|
"""No ``agent_config`` -> use ``llm.model`` directly. Defensive path
|
||||||
|
for direct callers; production callers always supply a config."""
|
||||||
|
assert (
|
||||||
|
_resolve_prompt_model_name(None, _FakeLLM("openai/gpt-4o-mini"))
|
||||||
|
== "openai/gpt-4o-mini"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_returns_none_when_nothing_available() -> None:
|
||||||
|
"""``compose_system_prompt`` treats ``None`` as the ``"default"``
|
||||||
|
variant and emits no provider block."""
|
||||||
|
assert _resolve_prompt_model_name(None, _FakeLLM(None)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_mode_resolves_to_auto_string() -> None:
|
||||||
|
"""Auto mode -> ``"auto"``. ``detect_provider_variant("auto")``
|
||||||
|
returns ``"default"``, which is correct: the child model isn't
|
||||||
|
known until the LiteLLM Router dispatches."""
|
||||||
|
cfg = AgentConfig.from_auto_mode()
|
||||||
|
assert _resolve_prompt_model_name(cfg, _FakeLLM("auto")) == "auto"
|
||||||
|
|
@ -0,0 +1,333 @@
|
||||||
|
"""Cloud-mode behavior tests for the new ``rm`` and ``rmdir`` filesystem tools.
|
||||||
|
|
||||||
|
The tools build ``Command(update=...)`` payloads that the persistence
|
||||||
|
middleware applies at end of turn. These tests stub out the backend and
|
||||||
|
runtime to assert the staging payload shape:
|
||||||
|
|
||||||
|
* ``rm`` queues into ``pending_deletes`` and tombstones state files.
|
||||||
|
* ``rm`` rejects directories, ``/documents``, root, and the anonymous doc.
|
||||||
|
* ``rmdir`` queues into ``pending_dir_deletes`` and rejects non-empty dirs.
|
||||||
|
* ``rmdir`` un-stages a same-turn ``mkdir`` rather than queuing a delete.
|
||||||
|
* ``rmdir`` refuses to drop the cwd or any of its ancestors.
|
||||||
|
* ``KBPostgresBackend`` view-helpers honor staged deletes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||||
|
from app.agents.new_chat.middleware.filesystem import SurfSenseFilesystemMiddleware
|
||||||
|
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _make_middleware(mode: FilesystemMode = FilesystemMode.CLOUD):
|
||||||
|
middleware = SurfSenseFilesystemMiddleware.__new__(SurfSenseFilesystemMiddleware)
|
||||||
|
middleware._filesystem_mode = mode
|
||||||
|
middleware._custom_tool_descriptions = {}
|
||||||
|
return middleware
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime(state: dict[str, Any] | None = None, *, tool_call_id: str = "tc-abc"):
|
||||||
|
state = state or {}
|
||||||
|
state.setdefault("cwd", "/documents")
|
||||||
|
return SimpleNamespace(state=state, tool_call_id=tool_call_id)
|
||||||
|
|
||||||
|
|
||||||
|
class _KBBackendStub(KBPostgresBackend):
|
||||||
|
"""Construct-able subclass of :class:`KBPostgresBackend` for tests.
|
||||||
|
|
||||||
|
We bypass the real ``__init__`` (which expects a runtime + DB session)
|
||||||
|
and inject just the methods the rm/rmdir tools touch. The class
|
||||||
|
inheritance keeps ``isinstance(backend, KBPostgresBackend)`` checks
|
||||||
|
inside the tools happy, which is what gates them from the desktop
|
||||||
|
code path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, children=None, file_data=None) -> None:
|
||||||
|
self.als_info = AsyncMock(return_value=children or [])
|
||||||
|
self._load_file_data = AsyncMock(
|
||||||
|
return_value=(file_data, 17) if file_data is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_backend_stub(*, children=None, file_data=None) -> KBPostgresBackend:
|
||||||
|
return _KBBackendStub(children=children, file_data=file_data)
|
||||||
|
|
||||||
|
|
||||||
|
def _bind_backend(middleware, backend):
|
||||||
|
"""Inject a backend resolver onto the middleware test instance."""
|
||||||
|
middleware._get_backend = lambda runtime: backend
|
||||||
|
return backend
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rm
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmStaging:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stages_delete_and_tombstones_state(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
_bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
|
||||||
|
runtime = _runtime(
|
||||||
|
{
|
||||||
|
"cwd": "/documents",
|
||||||
|
"files": {"/documents/notes.md": {"content": ["hello"]}},
|
||||||
|
"doc_id_by_path": {"/documents/notes.md": 17},
|
||||||
|
},
|
||||||
|
tool_call_id="tc-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
tool = m._create_rm_tool()
|
||||||
|
result = await tool.coroutine("/documents/notes.md", runtime=runtime)
|
||||||
|
|
||||||
|
assert hasattr(result, "update"), f"expected Command, got {result!r}"
|
||||||
|
update = result.update
|
||||||
|
assert update["pending_deletes"] == [
|
||||||
|
{"path": "/documents/notes.md", "tool_call_id": "tc-1"}
|
||||||
|
]
|
||||||
|
assert update["files"] == {"/documents/notes.md": None}
|
||||||
|
assert update["doc_id_by_path"] == {"/documents/notes.md": None}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_documents_root(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime()
|
||||||
|
tool = m._create_rm_tool()
|
||||||
|
result = await tool.coroutine("/documents", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "refusing to rm" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_root(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime()
|
||||||
|
tool = m._create_rm_tool()
|
||||||
|
result = await tool.coroutine("/", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "refusing to rm" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_directory_via_staged_dirs(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime(
|
||||||
|
{
|
||||||
|
"staged_dirs": ["/documents/team-x"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tool = m._create_rm_tool()
|
||||||
|
result = await tool.coroutine("/documents/team-x", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "directory" in result.lower()
|
||||||
|
assert "rmdir" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_directory_via_listing(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
_bind_backend(
|
||||||
|
m,
|
||||||
|
_make_backend_stub(
|
||||||
|
children=[{"path": "/documents/foo/x.md", "is_dir": False}]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
runtime = _runtime()
|
||||||
|
tool = m._create_rm_tool()
|
||||||
|
result = await tool.coroutine("/documents/foo", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "directory" in result.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_anonymous_doc(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime(
|
||||||
|
{
|
||||||
|
"kb_anon_doc": {
|
||||||
|
"path": "/documents/uploaded.xml",
|
||||||
|
"title": "uploaded",
|
||||||
|
"content": "",
|
||||||
|
"chunks": [],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tool = m._create_rm_tool()
|
||||||
|
result = await tool.coroutine("/documents/uploaded.xml", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "read-only" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_drops_path_from_dirty_paths(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
_bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
|
||||||
|
runtime = _runtime(
|
||||||
|
{
|
||||||
|
"files": {"/documents/notes.md": {"content": ["x"]}},
|
||||||
|
"doc_id_by_path": {"/documents/notes.md": 17},
|
||||||
|
"dirty_paths": ["/documents/notes.md"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tool = m._create_rm_tool()
|
||||||
|
result = await tool.coroutine("/documents/notes.md", runtime=runtime)
|
||||||
|
update = result.update
|
||||||
|
# First element is _CLEAR sentinel; the rest must NOT contain the
|
||||||
|
# rm'd path.
|
||||||
|
dirty = update.get("dirty_paths") or []
|
||||||
|
assert "/documents/notes.md" not in dirty[1:]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rmdir
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmdirStaging:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stages_dir_delete_when_empty_and_db_backed(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
backend = _bind_backend(m, _make_backend_stub(children=[]))
|
||||||
|
# Override _load_file_data to return None (folder, not a file) and
|
||||||
|
# parent listing to claim the folder exists.
|
||||||
|
backend._load_file_data = AsyncMock(return_value=None)
|
||||||
|
backend.als_info = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
[], # children of /documents/proj
|
||||||
|
[
|
||||||
|
{"path": "/documents/proj", "is_dir": True},
|
||||||
|
], # parent listing
|
||||||
|
]
|
||||||
|
)
|
||||||
|
runtime = _runtime(
|
||||||
|
{
|
||||||
|
"cwd": "/documents",
|
||||||
|
},
|
||||||
|
tool_call_id="tc-rd",
|
||||||
|
)
|
||||||
|
|
||||||
|
tool = m._create_rmdir_tool()
|
||||||
|
result = await tool.coroutine("/documents/proj", runtime=runtime)
|
||||||
|
|
||||||
|
assert hasattr(result, "update")
|
||||||
|
update = result.update
|
||||||
|
assert update["pending_dir_deletes"] == [
|
||||||
|
{"path": "/documents/proj", "tool_call_id": "tc-rd"}
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_non_empty(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
_bind_backend(
|
||||||
|
m,
|
||||||
|
_make_backend_stub(
|
||||||
|
children=[{"path": "/documents/proj/x.md", "is_dir": False}]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
runtime = _runtime()
|
||||||
|
tool = m._create_rmdir_tool()
|
||||||
|
result = await tool.coroutine("/documents/proj", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "not empty" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unstages_same_turn_mkdir(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
_bind_backend(m, _make_backend_stub(children=[]))
|
||||||
|
runtime = _runtime(
|
||||||
|
{
|
||||||
|
"cwd": "/documents",
|
||||||
|
"staged_dirs": ["/documents/scratch"],
|
||||||
|
},
|
||||||
|
tool_call_id="tc-rd",
|
||||||
|
)
|
||||||
|
tool = m._create_rmdir_tool()
|
||||||
|
result = await tool.coroutine("/documents/scratch", runtime=runtime)
|
||||||
|
|
||||||
|
assert hasattr(result, "update")
|
||||||
|
update = result.update
|
||||||
|
assert "pending_dir_deletes" not in update
|
||||||
|
# _CLEAR sentinel + remaining items (in this case, none).
|
||||||
|
staged_after = update["staged_dirs"]
|
||||||
|
assert staged_after[0] == "\x00__SURFSENSE_FILESYSTEM_CLEAR__\x00"
|
||||||
|
assert "/documents/scratch" not in staged_after[1:]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_root(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime()
|
||||||
|
tool = m._create_rmdir_tool()
|
||||||
|
for victim in ("/", "/documents"):
|
||||||
|
result = await tool.coroutine(victim, runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "refusing to rmdir" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_cwd(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({"cwd": "/documents/proj"})
|
||||||
|
tool = m._create_rmdir_tool()
|
||||||
|
result = await tool.coroutine("/documents/proj", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "cwd" in result.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_ancestor_of_cwd(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
runtime = _runtime({"cwd": "/documents/proj/sub"})
|
||||||
|
tool = m._create_rmdir_tool()
|
||||||
|
result = await tool.coroutine("/documents/proj", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "cwd" in result.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_files(self):
|
||||||
|
m = _make_middleware()
|
||||||
|
_bind_backend(m, _make_backend_stub(children=[], file_data={"content": ["x"]}))
|
||||||
|
runtime = _runtime()
|
||||||
|
tool = m._create_rmdir_tool()
|
||||||
|
result = await tool.coroutine("/documents/notes.md", runtime=runtime)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert "is a file" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KBPostgresBackend view filter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBPostgresBackendDeleteFilter:
|
||||||
|
"""als_info / glob / grep should suppress paths queued for delete."""
|
||||||
|
|
||||||
|
def _make_backend(self, state: dict[str, Any]) -> KBPostgresBackend:
|
||||||
|
runtime = SimpleNamespace(state=state)
|
||||||
|
backend = KBPostgresBackend(search_space_id=1, runtime=runtime)
|
||||||
|
return backend
|
||||||
|
|
||||||
|
def test_pending_filesystem_view_returns_deleted_paths(self):
|
||||||
|
backend = self._make_backend(
|
||||||
|
{
|
||||||
|
"pending_deletes": [
|
||||||
|
{"path": "/documents/x.md", "tool_call_id": "t1"},
|
||||||
|
],
|
||||||
|
"pending_dir_deletes": [
|
||||||
|
{"path": "/documents/d1", "tool_call_id": "t2"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
removed, alias, deleted_dirs = backend._pending_filesystem_view({})
|
||||||
|
assert "/documents/x.md" in removed
|
||||||
|
assert "/documents/d1" in deleted_dirs
|
||||||
|
assert alias == {}
|
||||||
|
|
||||||
|
def test_dir_suppressed_covers_descendants(self):
|
||||||
|
backend = self._make_backend({})
|
||||||
|
deleted_dirs = {"/documents/d"}
|
||||||
|
assert backend._is_dir_suppressed("/documents/d", deleted_dirs)
|
||||||
|
assert backend._is_dir_suppressed("/documents/d/x.md", deleted_dirs)
|
||||||
|
assert backend._is_dir_suppressed("/documents/d/sub/y.md", deleted_dirs)
|
||||||
|
assert not backend._is_dir_suppressed("/documents/other.md", deleted_dirs)
|
||||||
|
|
@ -98,10 +98,54 @@ class TestInitialFilesystemState:
|
||||||
state = _initial_filesystem_state()
|
state = _initial_filesystem_state()
|
||||||
assert state["cwd"] == "/documents"
|
assert state["cwd"] == "/documents"
|
||||||
assert state["staged_dirs"] == []
|
assert state["staged_dirs"] == []
|
||||||
|
assert state["staged_dir_tool_calls"] == {}
|
||||||
assert state["pending_moves"] == []
|
assert state["pending_moves"] == []
|
||||||
|
assert state["pending_deletes"] == []
|
||||||
|
assert state["pending_dir_deletes"] == []
|
||||||
assert state["doc_id_by_path"] == {}
|
assert state["doc_id_by_path"] == {}
|
||||||
assert state["dirty_paths"] == []
|
assert state["dirty_paths"] == []
|
||||||
|
assert state["dirty_path_tool_calls"] == {}
|
||||||
assert state["kb_priority"] == []
|
assert state["kb_priority"] == []
|
||||||
assert state["kb_matched_chunk_ids"] == {}
|
assert state["kb_matched_chunk_ids"] == {}
|
||||||
assert state["kb_anon_doc"] is None
|
assert state["kb_anon_doc"] is None
|
||||||
assert state["tree_version"] == 0
|
assert state["tree_version"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiEditSamePathCoalescing:
|
||||||
|
"""Multi-edit-same-path turns must coalesce into ONE binding record.
|
||||||
|
|
||||||
|
The persistence body uses ``dirty_path_tool_calls[path]`` to find the
|
||||||
|
tool_call_id that produced the current state on disk. Because
|
||||||
|
``dirty_paths`` dedupes via :func:`_add_unique_reducer` the second
|
||||||
|
edit doesn't append a new path entry — and because
|
||||||
|
``_dict_merge_with_tombstones_reducer`` lets the right-hand side
|
||||||
|
overwrite, the LATEST tool_call_id wins. That's the correct behavior
|
||||||
|
for snapshotting: revert restores to the pre-mutation state, and
|
||||||
|
multiple back-to-back edits in one turn coalesce into a single
|
||||||
|
revisible op (the user sees ONE Revert button per turn-per-path,
|
||||||
|
not N).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_dirty_paths_dedupes_repeated_writes(self):
|
||||||
|
# ``_add_unique_reducer`` is applied to ``dirty_paths``. Two writes
|
||||||
|
# to the same path produce one entry, not two.
|
||||||
|
first = _add_unique_reducer([], ["/documents/a.md"])
|
||||||
|
second = _add_unique_reducer(first, ["/documents/a.md"])
|
||||||
|
assert second == ["/documents/a.md"]
|
||||||
|
|
||||||
|
def test_dirty_path_tool_calls_keeps_latest_tool_call_id(self):
|
||||||
|
# First write tags the path with tcid-1.
|
||||||
|
merged = _dict_merge_with_tombstones_reducer({}, {"/documents/a.md": "tcid-1"})
|
||||||
|
# Second write to the same path tags it with tcid-2 (latest wins).
|
||||||
|
merged = _dict_merge_with_tombstones_reducer(
|
||||||
|
merged, {"/documents/a.md": "tcid-2"}
|
||||||
|
)
|
||||||
|
assert merged == {"/documents/a.md": "tcid-2"}
|
||||||
|
|
||||||
|
def test_rm_tombstones_dirty_path_tool_call(self):
|
||||||
|
# ``rm`` writes ``{path: None}`` into dirty_path_tool_calls to
|
||||||
|
# prevent a stale binding from leaking past the delete.
|
||||||
|
merged = _dict_merge_with_tombstones_reducer(
|
||||||
|
{"/documents/a.md": "tcid-1"}, {"/documents/a.md": None}
|
||||||
|
)
|
||||||
|
assert merged == {}
|
||||||
|
|
|
||||||
0
surfsense_backend/tests/unit/db/__init__.py
Normal file
0
surfsense_backend/tests/unit/db/__init__.py
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""Smoke test for the ``134_relax_revision_fks`` Alembic migration.
|
||||||
|
|
||||||
|
A full apply/rollback test would require a live Postgres; here we verify
|
||||||
|
the migration module's static contract:
|
||||||
|
|
||||||
|
* The chain wires it as a successor of ``133_drop_documents_content_hash_unique``.
|
||||||
|
* ``upgrade()`` declares two FK creations with ``ondelete='SET NULL'``
|
||||||
|
(one for ``document_revisions.document_id``, one for
|
||||||
|
``folder_revisions.folder_id``).
|
||||||
|
* ``downgrade()`` re-establishes ``ondelete='CASCADE'`` after draining
|
||||||
|
orphaned revisions.
|
||||||
|
|
||||||
|
If any of these invariants regress the snapshot/revert pipeline silently
|
||||||
|
loses the ability to undo ``rm`` / ``rmdir`` on environments that ran the
|
||||||
|
migration "down" or never ran it at all.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import inspect
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
_MIGRATION_PATH = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "alembic"
|
||||||
|
/ "versions"
|
||||||
|
/ "134_relax_revision_fks.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_migration():
|
||||||
|
"""Load the migration module by file path (no package import needed)."""
|
||||||
|
spec = importlib.util.spec_from_file_location("_migration_134", _MIGRATION_PATH)
|
||||||
|
assert spec and spec.loader, "could not load migration spec"
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_chain_revision_ids() -> None:
|
||||||
|
module = _load_migration()
|
||||||
|
# The migration file uses short numeric revision IDs to match the
|
||||||
|
# in-tree convention (cf. ``133`` -> ``134``); the ``134_<slug>.py``
|
||||||
|
# filename is documentation, not the canonical revision string.
|
||||||
|
assert getattr(module, "revision", None) == "134"
|
||||||
|
assert getattr(module, "down_revision", None) == "133"
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_exposes_upgrade_and_downgrade() -> None:
|
||||||
|
module = _load_migration()
|
||||||
|
upgrade = getattr(module, "upgrade", None)
|
||||||
|
downgrade = getattr(module, "downgrade", None)
|
||||||
|
assert callable(upgrade), "upgrade() is required"
|
||||||
|
assert callable(downgrade), "downgrade() is required"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upgrade_creates_set_null_fks_for_both_revision_tables() -> None:
|
||||||
|
module = _load_migration()
|
||||||
|
src = inspect.getsource(module.upgrade)
|
||||||
|
assert "document_revisions" in src
|
||||||
|
assert "folder_revisions" in src
|
||||||
|
# Both new FKs MUST be ON DELETE SET NULL — that's the entire point
|
||||||
|
# of the migration: snapshots must outlive their parent row.
|
||||||
|
assert src.count('ondelete="SET NULL"') >= 2
|
||||||
|
# And the ``document_id`` / ``folder_id`` columns become nullable.
|
||||||
|
assert "nullable=True" in src
|
||||||
|
|
||||||
|
|
||||||
|
def test_downgrade_drains_orphans_then_restores_cascade() -> None:
|
||||||
|
module = _load_migration()
|
||||||
|
src = inspect.getsource(module.downgrade)
|
||||||
|
# Drain orphaned rows BEFORE we can re-impose NOT NULL.
|
||||||
|
assert "DELETE FROM document_revisions WHERE document_id IS NULL" in src
|
||||||
|
assert "DELETE FROM folder_revisions WHERE folder_id IS NULL" in src
|
||||||
|
# Then restore the original CASCADE/NOT NULL contract.
|
||||||
|
assert src.count('ondelete="CASCADE"') >= 2
|
||||||
|
assert "nullable=False" in src
|
||||||
|
|
@ -168,6 +168,8 @@ class TestModeSpecificPrompts:
|
||||||
"edit_file",
|
"edit_file",
|
||||||
"move_file",
|
"move_file",
|
||||||
"mkdir",
|
"mkdir",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
"list_tree",
|
"list_tree",
|
||||||
"grep",
|
"grep",
|
||||||
):
|
):
|
||||||
|
|
@ -182,6 +184,8 @@ class TestModeSpecificPrompts:
|
||||||
"edit_file",
|
"edit_file",
|
||||||
"move_file",
|
"move_file",
|
||||||
"mkdir",
|
"mkdir",
|
||||||
|
"rm",
|
||||||
|
"rmdir",
|
||||||
"list_tree",
|
"list_tree",
|
||||||
"grep",
|
"grep",
|
||||||
):
|
):
|
||||||
|
|
@ -190,6 +194,18 @@ class TestModeSpecificPrompts:
|
||||||
assert "/documents/" not in text, f"{name} mentions cloud namespace"
|
assert "/documents/" not in text, f"{name} mentions cloud namespace"
|
||||||
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
|
assert "temp_" not in text, f"{name} mentions cloud temp_ semantics"
|
||||||
|
|
||||||
|
def test_cloud_descs_include_rm_and_rmdir(self):
|
||||||
|
descs = _build_tool_descriptions(FilesystemMode.CLOUD)
|
||||||
|
assert "rm" in descs and "rmdir" in descs
|
||||||
|
assert "Deletes a single file" in descs["rm"]
|
||||||
|
assert "Deletes an empty directory" in descs["rmdir"]
|
||||||
|
assert "rmdir" in descs["rmdir"] and "POSIX" in descs["rmdir"]
|
||||||
|
|
||||||
|
def test_desktop_descs_warn_about_irreversibility(self):
|
||||||
|
descs = _build_tool_descriptions(FilesystemMode.DESKTOP_LOCAL_FOLDER)
|
||||||
|
assert "NOT reversible" in descs["rm"]
|
||||||
|
assert "NOT reversible" in descs["rmdir"]
|
||||||
|
|
||||||
def test_sandbox_addendum_appended_when_available(self):
|
def test_sandbox_addendum_appended_when_available(self):
|
||||||
prompt = _build_filesystem_system_prompt(
|
prompt = _build_filesystem_system_prompt(
|
||||||
FilesystemMode.CLOUD, sandbox_available=True
|
FilesystemMode.CLOUD, sandbox_available=True
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,309 @@
|
||||||
|
"""Unit tests for the kb_persistence snapshot helpers.
|
||||||
|
|
||||||
|
The full ``commit_staged_filesystem_state`` body exercises a real session
|
||||||
|
in integration tests; here we verify the building blocks used by the
|
||||||
|
snapshot/revert pipeline:
|
||||||
|
|
||||||
|
* ``_find_action_ids_batch`` issues a SINGLE query for N tool_call_ids
|
||||||
|
(regression guard against the N+1 lookup pattern).
|
||||||
|
* ``_mark_action_reversible`` is a no-op when ``action_id`` is ``None``.
|
||||||
|
* ``_doc_revision_payload`` and ``_load_chunks_for_snapshot`` produce the
|
||||||
|
shape the snapshot helpers consume.
|
||||||
|
|
||||||
|
These tests use ``MagicMock`` / ``AsyncMock`` against a fake session so
|
||||||
|
the assertions run in milliseconds and don't require Postgres.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware import kb_persistence
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResult:
|
||||||
|
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
|
||||||
|
self._rows = rows or []
|
||||||
|
self._scalar = scalar
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
return list(self._rows)
|
||||||
|
|
||||||
|
def scalar_one_or_none(self) -> Any:
|
||||||
|
return self._scalar
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.execute = AsyncMock()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_action_ids_batch_issues_single_query() -> None:
|
||||||
|
"""The lookup MUST be a single ``IN (...)`` SELECT, not N selects."""
|
||||||
|
session = _FakeSession()
|
||||||
|
session.execute.return_value = _FakeResult(
|
||||||
|
rows=[
|
||||||
|
MagicMock(id=11, tool_call_id="tc-a"),
|
||||||
|
MagicMock(id=22, tool_call_id="tc-b"),
|
||||||
|
MagicMock(id=33, tool_call_id="tc-c"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
mapping = await kb_persistence._find_action_ids_batch(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
thread_id=1,
|
||||||
|
tool_call_ids={"tc-a", "tc-b", "tc-c"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mapping == {"tc-a": 11, "tc-b": 22, "tc-c": 33}
|
||||||
|
assert session.execute.await_count == 1, (
|
||||||
|
"Snapshot binding must batch into ONE query; got "
|
||||||
|
f"{session.execute.await_count} (regression: N+1 lookup pattern)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_action_ids_batch_short_circuits_when_thread_id_missing() -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
mapping = await kb_persistence._find_action_ids_batch(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
thread_id=None,
|
||||||
|
tool_call_ids={"tc-a"},
|
||||||
|
)
|
||||||
|
assert mapping == {}
|
||||||
|
assert session.execute.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_action_ids_batch_short_circuits_when_no_calls() -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
mapping = await kb_persistence._find_action_ids_batch(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
thread_id=42,
|
||||||
|
tool_call_ids=set(),
|
||||||
|
)
|
||||||
|
assert mapping == {}
|
||||||
|
assert session.execute.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mark_action_reversible_is_noop_for_null_id() -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
await kb_persistence._mark_action_reversible(session, action_id=None) # type: ignore[arg-type]
|
||||||
|
assert session.execute.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mark_action_reversible_runs_update_for_real_id() -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
await kb_persistence._mark_action_reversible(session, action_id=99) # type: ignore[arg-type]
|
||||||
|
assert session.execute.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_revision_payload_captures_metadata_virtual_path() -> None:
|
||||||
|
"""Snapshot helpers must capture ``metadata_before`` for revert reuse."""
|
||||||
|
doc = MagicMock()
|
||||||
|
doc.content = "body"
|
||||||
|
doc.title = "notes.md"
|
||||||
|
doc.folder_id = 7
|
||||||
|
doc.document_metadata = {"virtual_path": "/documents/team/notes.md"}
|
||||||
|
|
||||||
|
payload = kb_persistence._doc_revision_payload(
|
||||||
|
doc, chunks_before=[{"content": "x"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert payload["title_before"] == "notes.md"
|
||||||
|
assert payload["folder_id_before"] == 7
|
||||||
|
assert payload["content_before"] == "body"
|
||||||
|
assert payload["chunks_before"] == [{"content": "x"}]
|
||||||
|
assert payload["metadata_before"] == {"virtual_path": "/documents/team/notes.md"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_doc_revision_payload_handles_missing_metadata() -> None:
|
||||||
|
doc = MagicMock()
|
||||||
|
doc.content = ""
|
||||||
|
doc.title = ""
|
||||||
|
doc.folder_id = None
|
||||||
|
doc.document_metadata = None
|
||||||
|
payload = kb_persistence._doc_revision_payload(doc)
|
||||||
|
assert payload["metadata_before"] is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_chunks_for_snapshot_returns_content_only() -> None:
|
||||||
|
"""Snapshot chunks intentionally omit embeddings (regenerated on revert)."""
|
||||||
|
session = _FakeSession()
|
||||||
|
session.execute.return_value = _FakeResult(
|
||||||
|
rows=[
|
||||||
|
MagicMock(content="alpha"),
|
||||||
|
MagicMock(content="beta"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
chunks = await kb_persistence._load_chunks_for_snapshot(
|
||||||
|
session,
|
||||||
|
doc_id=42, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
assert chunks == [{"content": "alpha"}, {"content": "beta"}]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Deferred reversibility-flip dispatches.
|
||||||
|
#
|
||||||
|
# The snapshot helpers used to dispatch ``action_log_updated`` directly
|
||||||
|
# from inside the SAVEPOINT block. That meant the SSE side-channel
|
||||||
|
# could tell the UI a row was reversible while the OUTER transaction
|
||||||
|
# was still pending — and if the outer commit failed, every SAVEPOINT
|
||||||
|
# rolled back too, leaving the UI in a state inconsistent with
|
||||||
|
# durable storage. The deferred-dispatch contract fixes that:
|
||||||
|
#
|
||||||
|
# • when a ``deferred_dispatches`` list is provided, the helper
|
||||||
|
# APPENDS the action_id and does NOT dispatch;
|
||||||
|
# • the caller (``commit_staged_filesystem_state``) flushes the list
|
||||||
|
# only AFTER ``await session.commit()`` succeeds; on rollback it
|
||||||
|
# clears the list so nothing is emitted.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _NestedCtx:
|
||||||
|
"""Async context manager mimicking ``session.begin_nested()``."""
|
||||||
|
|
||||||
|
async def __aenter__(self) -> _NestedCtx:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_write_snapshot_defers_dispatch_when_list_provided(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Helpers MUST queue dispatches when ``deferred_dispatches`` is set."""
|
||||||
|
session = MagicMock()
|
||||||
|
session.begin_nested = MagicMock(return_value=_NestedCtx())
|
||||||
|
session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
|
||||||
|
session.flush = AsyncMock()
|
||||||
|
|
||||||
|
def _add(rev: Any) -> None:
|
||||||
|
rev.id = 17
|
||||||
|
|
||||||
|
session.add = MagicMock(side_effect=_add)
|
||||||
|
|
||||||
|
dispatched: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_dispatch(action_id: int | None) -> None:
|
||||||
|
if action_id is not None:
|
||||||
|
dispatched.append(int(action_id))
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
|
||||||
|
)
|
||||||
|
|
||||||
|
deferred: list[int] = []
|
||||||
|
doc = MagicMock(id=99, document_metadata={"virtual_path": "/documents/x.md"})
|
||||||
|
doc.title = "x.md"
|
||||||
|
doc.folder_id = None
|
||||||
|
doc.content = "body"
|
||||||
|
|
||||||
|
rev_id = await kb_persistence._snapshot_document_pre_write(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
doc=doc,
|
||||||
|
action_id=42,
|
||||||
|
search_space_id=1,
|
||||||
|
turn_id="t-1",
|
||||||
|
deferred_dispatches=deferred,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert rev_id == 17
|
||||||
|
# Inline dispatch must NOT have fired; the action_id is queued.
|
||||||
|
assert dispatched == []
|
||||||
|
assert deferred == [42]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_write_snapshot_dispatches_inline_when_list_omitted(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Direct callers (no outer transaction) keep the legacy inline dispatch."""
|
||||||
|
session = MagicMock()
|
||||||
|
session.begin_nested = MagicMock(return_value=_NestedCtx())
|
||||||
|
session.execute = AsyncMock(return_value=_FakeResult(rows=[]))
|
||||||
|
session.flush = AsyncMock()
|
||||||
|
|
||||||
|
def _add(rev: Any) -> None:
|
||||||
|
rev.id = 7
|
||||||
|
|
||||||
|
session.add = MagicMock(side_effect=_add)
|
||||||
|
|
||||||
|
dispatched: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_dispatch(action_id: int | None) -> None:
|
||||||
|
if action_id is not None:
|
||||||
|
dispatched.append(int(action_id))
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
|
||||||
|
)
|
||||||
|
|
||||||
|
doc = MagicMock(id=11, document_metadata={"virtual_path": "/documents/y.md"})
|
||||||
|
doc.title = "y.md"
|
||||||
|
doc.folder_id = None
|
||||||
|
doc.content = "body"
|
||||||
|
|
||||||
|
await kb_persistence._snapshot_document_pre_write(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
doc=doc,
|
||||||
|
action_id=88,
|
||||||
|
search_space_id=1,
|
||||||
|
turn_id="t-1",
|
||||||
|
# No deferred_dispatches arg — fall back to inline dispatch.
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dispatched == [88]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_mkdir_snapshot_defers_dispatch_when_list_provided(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Folder mkdir snapshots honour the same deferred-dispatch contract."""
|
||||||
|
session = MagicMock()
|
||||||
|
session.begin_nested = MagicMock(return_value=_NestedCtx())
|
||||||
|
session.execute = AsyncMock() # _mark_action_reversible calls execute
|
||||||
|
session.flush = AsyncMock()
|
||||||
|
|
||||||
|
def _add(rev: Any) -> None:
|
||||||
|
rev.id = 3
|
||||||
|
|
||||||
|
session.add = MagicMock(side_effect=_add)
|
||||||
|
|
||||||
|
dispatched: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_dispatch(action_id: int | None) -> None:
|
||||||
|
if action_id is not None:
|
||||||
|
dispatched.append(int(action_id))
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
kb_persistence, "_dispatch_reversibility_update", _fake_dispatch
|
||||||
|
)
|
||||||
|
|
||||||
|
deferred: list[int] = []
|
||||||
|
folder = MagicMock(id=2, name="f", parent_id=None, position="a0")
|
||||||
|
|
||||||
|
await kb_persistence._snapshot_folder_pre_mkdir(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
folder=folder,
|
||||||
|
action_id=55,
|
||||||
|
search_space_id=1,
|
||||||
|
turn_id="t-1",
|
||||||
|
deferred_dispatches=deferred,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert dispatched == []
|
||||||
|
assert deferred == [55]
|
||||||
139
surfsense_backend/tests/unit/middleware/test_knowledge_tree.py
Normal file
139
surfsense_backend/tests/unit/middleware/test_knowledge_tree.py
Normal file
|
|
@ -0,0 +1,139 @@
|
||||||
|
"""Unit tests for ``KnowledgeTreeMiddleware`` rendering.
|
||||||
|
|
||||||
|
The empty-folder marker is critical UX: without it, the LLM cannot
|
||||||
|
distinguish a leaf folder containing one document from a leaf folder
|
||||||
|
that has no descendants at all, and ends up firing ``rmdir`` on
|
||||||
|
non-empty folders. These tests pin the rendering contract so that
|
||||||
|
contract cannot silently regress.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.knowledge_tree import KnowledgeTreeMiddleware
|
||||||
|
from app.agents.new_chat.path_resolver import DOCUMENTS_ROOT
|
||||||
|
|
||||||
|
|
||||||
|
def _compute(folder_paths: list[str], doc_paths: list[str]) -> set[str]:
|
||||||
|
return KnowledgeTreeMiddleware._compute_non_empty_folders(folder_paths, doc_paths)
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeNonEmptyFolders:
|
||||||
|
def test_folder_with_direct_document_is_non_empty(self):
|
||||||
|
folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"]
|
||||||
|
doc_paths = [
|
||||||
|
f"{DOCUMENTS_ROOT}/Travel/Boarding Pass/southwest.pdf.xml",
|
||||||
|
]
|
||||||
|
non_empty = _compute(folder_paths, doc_paths)
|
||||||
|
assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" in non_empty
|
||||||
|
|
||||||
|
def test_truly_empty_leaf_folder_is_not_non_empty(self):
|
||||||
|
folder_paths = [f"{DOCUMENTS_ROOT}/Travel/Boarding Pass"]
|
||||||
|
doc_paths: list[str] = []
|
||||||
|
assert _compute(folder_paths, doc_paths) == set()
|
||||||
|
|
||||||
|
def test_documents_propagate_up_to_all_ancestors(self):
|
||||||
|
folder_paths = [
|
||||||
|
f"{DOCUMENTS_ROOT}/A",
|
||||||
|
f"{DOCUMENTS_ROOT}/A/B",
|
||||||
|
f"{DOCUMENTS_ROOT}/A/B/C",
|
||||||
|
]
|
||||||
|
doc_paths = [f"{DOCUMENTS_ROOT}/A/B/C/file.xml"]
|
||||||
|
non_empty = _compute(folder_paths, doc_paths)
|
||||||
|
assert non_empty == {
|
||||||
|
f"{DOCUMENTS_ROOT}/A",
|
||||||
|
f"{DOCUMENTS_ROOT}/A/B",
|
||||||
|
f"{DOCUMENTS_ROOT}/A/B/C",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_chain_with_subfolders_marks_only_leaf_empty(self):
|
||||||
|
# POSIX-like semantic: a folder is "empty" only if it has no
|
||||||
|
# immediate children (docs OR sub-folders). The model needs this
|
||||||
|
# because parallel ``rmdir`` calls all see the same starting state,
|
||||||
|
# so trying to rmdir a parent before its children is never safe.
|
||||||
|
folder_paths = [
|
||||||
|
f"{DOCUMENTS_ROOT}/X",
|
||||||
|
f"{DOCUMENTS_ROOT}/X/Y",
|
||||||
|
f"{DOCUMENTS_ROOT}/X/Y/Z",
|
||||||
|
]
|
||||||
|
non_empty = _compute(folder_paths, [])
|
||||||
|
# Only ``X/Y/Z`` (the leaf) is empty. ``X`` and ``X/Y`` each have a
|
||||||
|
# sub-folder child, so they are non-empty and should NOT carry the
|
||||||
|
# ``(empty)`` marker.
|
||||||
|
assert non_empty == {f"{DOCUMENTS_ROOT}/X", f"{DOCUMENTS_ROOT}/X/Y"}
|
||||||
|
|
||||||
|
def test_sibling_with_doc_does_not_mark_other_sibling_non_empty(self):
|
||||||
|
# Mirrors a real DB layout where every intermediate folder is
|
||||||
|
# materialized in the ``folders`` table.
|
||||||
|
folder_paths = [
|
||||||
|
f"{DOCUMENTS_ROOT}/Travel",
|
||||||
|
f"{DOCUMENTS_ROOT}/Travel/Boarding Pass",
|
||||||
|
f"{DOCUMENTS_ROOT}/Travel/Notes",
|
||||||
|
]
|
||||||
|
doc_paths = [f"{DOCUMENTS_ROOT}/Travel/Notes/itinerary.xml"]
|
||||||
|
non_empty = _compute(folder_paths, doc_paths)
|
||||||
|
# ``Travel`` is non-empty because it has children, ``Notes`` is non-empty
|
||||||
|
# because of the doc, but ``Boarding Pass`` (sibling leaf) is empty.
|
||||||
|
assert f"{DOCUMENTS_ROOT}/Travel" in non_empty
|
||||||
|
assert f"{DOCUMENTS_ROOT}/Travel/Notes" in non_empty
|
||||||
|
assert f"{DOCUMENTS_ROOT}/Travel/Boarding Pass" not in non_empty
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatTreeRendering:
|
||||||
|
"""Integration check: empty leaf gets ``(empty)`` marker; non-empty doesn't."""
|
||||||
|
|
||||||
|
def _render(
|
||||||
|
self,
|
||||||
|
folder_paths: list[str],
|
||||||
|
doc_specs: list[dict],
|
||||||
|
) -> str:
|
||||||
|
from app.agents.new_chat.path_resolver import PathIndex
|
||||||
|
|
||||||
|
index = PathIndex(
|
||||||
|
folder_paths={i + 1: p for i, p in enumerate(folder_paths)},
|
||||||
|
)
|
||||||
|
|
||||||
|
class _Row:
|
||||||
|
def __init__(self, **kw):
|
||||||
|
self.__dict__.update(kw)
|
||||||
|
|
||||||
|
docs = [_Row(**spec) for spec in doc_specs]
|
||||||
|
|
||||||
|
mw = KnowledgeTreeMiddleware(
|
||||||
|
search_space_id=1,
|
||||||
|
filesystem_mode=None, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
return mw._format_tree(index, docs)
|
||||||
|
|
||||||
|
def test_renders_empty_marker_only_for_truly_empty_folders(self):
|
||||||
|
# Reproduces the failure scenario from the bug report:
|
||||||
|
# ``Boarding Pass`` is empty (its only doc was just deleted), while
|
||||||
|
# ``Tax Returns`` still has ``federal.pdf``. All intermediate
|
||||||
|
# folders are present in the index, mirroring the real DB layout.
|
||||||
|
folder_paths = [
|
||||||
|
"/documents/File Upload",
|
||||||
|
"/documents/File Upload/2026-04-08",
|
||||||
|
"/documents/File Upload/2026-04-08/Travel",
|
||||||
|
"/documents/File Upload/2026-04-08/Travel/Boarding Pass",
|
||||||
|
"/documents/File Upload/2026-04-15",
|
||||||
|
"/documents/File Upload/2026-04-15/Finance",
|
||||||
|
"/documents/File Upload/2026-04-15/Finance/Tax Returns",
|
||||||
|
]
|
||||||
|
tax_returns_folder_id = (
|
||||||
|
folder_paths.index("/documents/File Upload/2026-04-15/Finance/Tax Returns")
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
rendered = self._render(
|
||||||
|
folder_paths=folder_paths,
|
||||||
|
doc_specs=[
|
||||||
|
{
|
||||||
|
"id": 100,
|
||||||
|
"title": "federal.pdf",
|
||||||
|
"folder_id": tax_returns_folder_id,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert "Boarding Pass/ (empty)" in rendered
|
||||||
|
assert "Tax Returns/ (empty)" not in rendered
|
||||||
|
# Intermediate ancestors of the doc must NOT be marked empty.
|
||||||
|
assert "Finance/ (empty)" not in rendered
|
||||||
|
assert "2026-04-15/ (empty)" not in rendered
|
||||||
|
|
@ -69,3 +69,74 @@ def test_local_backend_write_rejects_missing_parent_directory(tmp_path: Path):
|
||||||
assert write.error is not None
|
assert write.error is not None
|
||||||
assert "parent directory" in write.error
|
assert "parent directory" in write.error
|
||||||
assert not (tmp_path / "tempoo").exists()
|
assert not (tmp_path / "tempoo").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_delete_file_success(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
(tmp_path / "delete-me.md").write_text("bye")
|
||||||
|
|
||||||
|
res = backend.delete_file("/delete-me.md")
|
||||||
|
assert res.error is None
|
||||||
|
assert res.path == "/delete-me.md"
|
||||||
|
assert not (tmp_path / "delete-me.md").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_delete_file_rejects_directory(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
(tmp_path / "subdir").mkdir()
|
||||||
|
|
||||||
|
res = backend.delete_file("/subdir")
|
||||||
|
assert res.error is not None
|
||||||
|
assert "directory" in res.error
|
||||||
|
assert (tmp_path / "subdir").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_delete_file_missing_returns_error(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
|
||||||
|
res = backend.delete_file("/nope.md")
|
||||||
|
assert res.error is not None
|
||||||
|
assert "not found" in res.error
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_rmdir_success(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
(tmp_path / "empty").mkdir()
|
||||||
|
|
||||||
|
res = backend.rmdir("/empty")
|
||||||
|
assert res.error is None
|
||||||
|
assert res.path == "/empty"
|
||||||
|
assert not (tmp_path / "empty").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_rmdir_rejects_non_empty(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
(tmp_path / "withkid").mkdir()
|
||||||
|
(tmp_path / "withkid" / "child.md").write_text("x")
|
||||||
|
|
||||||
|
res = backend.rmdir("/withkid")
|
||||||
|
assert res.error is not None
|
||||||
|
assert "not empty" in res.error
|
||||||
|
assert (tmp_path / "withkid" / "child.md").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_rmdir_rejects_file(tmp_path: Path):
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
(tmp_path / "f.md").write_text("x")
|
||||||
|
|
||||||
|
res = backend.rmdir("/f.md")
|
||||||
|
assert res.error is not None
|
||||||
|
assert "not a directory" in res.error
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_backend_rmdir_rejects_root(tmp_path: Path):
|
||||||
|
"""``rmdir /`` MUST fail. The exact error wording comes from
|
||||||
|
``_resolve_virtual`` (root resolves to outside the sandbox); what
|
||||||
|
matters is that the call returns an error and does NOT delete the
|
||||||
|
sandbox root on disk."""
|
||||||
|
backend = LocalFolderBackend(str(tmp_path))
|
||||||
|
|
||||||
|
res = backend.rmdir("/")
|
||||||
|
assert res.error is not None
|
||||||
|
assert "Invalid path" in res.error or "root" in res.error
|
||||||
|
assert tmp_path.exists()
|
||||||
|
|
|
||||||
0
surfsense_backend/tests/unit/routes/__init__.py
Normal file
0
surfsense_backend/tests/unit/routes/__init__.py
Normal file
|
|
@ -0,0 +1,143 @@
|
||||||
|
"""Unit tests for the edit-from-arbitrary-position helpers inside ``new_chat_routes``.
|
||||||
|
|
||||||
|
The regenerate route's edit-from-position path introduces:
|
||||||
|
* ``_find_pre_turn_checkpoint_id`` — walks LangGraph checkpoint tuples
|
||||||
|
newest-first and picks the first one whose ``metadata["turn_id"]``
|
||||||
|
differs from the edited turn. That checkpoint is the rewind target
|
||||||
|
(state immediately before the edited turn started).
|
||||||
|
* ``RegenerateRequest`` accepts ``from_message_id`` + ``revert_actions``
|
||||||
|
with a validator that prevents callers from requesting a revert pass
|
||||||
|
without specifying which turn to roll back.
|
||||||
|
|
||||||
|
These are pure-Python helpers that don't need a live DB, so we exercise
|
||||||
|
them with a small ``CheckpointTuple``-shaped namespace and direct
|
||||||
|
schema instantiation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.routes.new_chat_routes import _find_pre_turn_checkpoint_id
|
||||||
|
from app.schemas.new_chat import RegenerateRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _cp(checkpoint_id: str, turn_id: str | None) -> SimpleNamespace:
|
||||||
|
"""Build a fake ``CheckpointTuple`` with the metadata shape we read."""
|
||||||
|
return SimpleNamespace(
|
||||||
|
config={"configurable": {"checkpoint_id": checkpoint_id}},
|
||||||
|
metadata={"turn_id": turn_id} if turn_id is not None else {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindPreTurnCheckpointId:
|
||||||
|
def test_returns_last_pre_turn_checkpoint_when_editing_latest_turn(self) -> None:
|
||||||
|
# Newest-first: T2 is the most-recent turn. The latest non-T2
|
||||||
|
# checkpoint (cp2) is the rewind target — state immediately
|
||||||
|
# before T2 began.
|
||||||
|
tuples = [
|
||||||
|
_cp("cp4", "T2"),
|
||||||
|
_cp("cp3", "T2"),
|
||||||
|
_cp("cp2", "T1"),
|
||||||
|
_cp("cp1", "T1"),
|
||||||
|
]
|
||||||
|
assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"
|
||||||
|
|
||||||
|
def test_returns_pre_turn_checkpoint_when_later_turns_exist(self) -> None:
|
||||||
|
# Regression for the bug where walking newest-first returned the
|
||||||
|
# FIRST cp with ``turn_id != target`` — which is one of the
|
||||||
|
# later-turn checkpoints, NOT the pre-turn boundary. Editing
|
||||||
|
# T2 must rewind to the latest T1 checkpoint (cp2), not to the
|
||||||
|
# latest T3 checkpoint (cp6).
|
||||||
|
tuples = [
|
||||||
|
_cp("cp6", "T3"),
|
||||||
|
_cp("cp5", "T3"),
|
||||||
|
_cp("cp4", "T2"),
|
||||||
|
_cp("cp3", "T2"),
|
||||||
|
_cp("cp2", "T1"),
|
||||||
|
_cp("cp1", "T1"),
|
||||||
|
]
|
||||||
|
assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"
|
||||||
|
|
||||||
|
def test_returns_none_when_editing_first_turn(self) -> None:
|
||||||
|
# No pre-turn boundary exists; caller is expected to fall back
|
||||||
|
# to the oldest checkpoint or special-case "first turn of the
|
||||||
|
# thread".
|
||||||
|
tuples = [
|
||||||
|
_cp("cp4", "T2"),
|
||||||
|
_cp("cp3", "T2"),
|
||||||
|
_cp("cp2", "T1"),
|
||||||
|
_cp("cp1", "T1"),
|
||||||
|
]
|
||||||
|
assert _find_pre_turn_checkpoint_id(tuples, turn_id="T1") is None
|
||||||
|
|
||||||
|
def test_returns_none_when_only_edited_turn_present(self) -> None:
|
||||||
|
tuples = [_cp("cp2", "T2"), _cp("cp1", "T2")]
|
||||||
|
assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") is None
|
||||||
|
|
||||||
|
def test_returns_none_for_empty_history(self) -> None:
|
||||||
|
assert _find_pre_turn_checkpoint_id([], turn_id="T1") is None
|
||||||
|
|
||||||
|
def test_legacy_checkpoints_without_turn_id_count_as_pre_turn(self) -> None:
|
||||||
|
# Checkpoints written before migration 136 have no
|
||||||
|
# ``metadata.turn_id``. They should be eligible rewind targets
|
||||||
|
# — they came before the
|
||||||
|
# edited turn began.
|
||||||
|
tuples = [
|
||||||
|
_cp("cp3", "T2"),
|
||||||
|
SimpleNamespace(
|
||||||
|
config={"configurable": {"checkpoint_id": "cp2"}},
|
||||||
|
metadata=None,
|
||||||
|
),
|
||||||
|
_cp("cp1", "T1"),
|
||||||
|
]
|
||||||
|
# Walking oldest-first: cp1(T1) tracked, cp2(legacy/None) tracked,
|
||||||
|
# then cp3(T2) crosses the boundary -> return cp2.
|
||||||
|
assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp2"
|
||||||
|
|
||||||
|
def test_skips_checkpoint_missing_checkpoint_id_in_config(self) -> None:
|
||||||
|
# If a checkpoint tuple's ``config["configurable"]`` is missing
|
||||||
|
# the ``checkpoint_id`` key (corrupt / partial), we keep the
|
||||||
|
# last known good target instead of crashing.
|
||||||
|
broken = SimpleNamespace(
|
||||||
|
config={"configurable": {}}, metadata={"turn_id": "T1"}
|
||||||
|
)
|
||||||
|
tuples = [
|
||||||
|
_cp("cp3", "T2"),
|
||||||
|
broken,
|
||||||
|
_cp("cp1", "T1"),
|
||||||
|
]
|
||||||
|
# cp1(T1) tracked, broken skipped, cp3(T2) -> return cp1.
|
||||||
|
assert _find_pre_turn_checkpoint_id(tuples, turn_id="T2") == "cp1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegenerateRequestValidation:
|
||||||
|
def test_revert_actions_requires_from_message_id(self) -> None:
|
||||||
|
with pytest.raises(Exception) as exc:
|
||||||
|
RegenerateRequest(
|
||||||
|
search_space_id=1,
|
||||||
|
user_query="hi",
|
||||||
|
revert_actions=True,
|
||||||
|
)
|
||||||
|
msg = str(exc.value).lower()
|
||||||
|
assert "from_message_id" in msg
|
||||||
|
|
||||||
|
def test_from_message_id_without_revert_is_allowed(self) -> None:
|
||||||
|
req = RegenerateRequest(
|
||||||
|
search_space_id=1,
|
||||||
|
user_query="hi",
|
||||||
|
from_message_id=42,
|
||||||
|
)
|
||||||
|
assert req.from_message_id == 42
|
||||||
|
assert req.revert_actions is False
|
||||||
|
|
||||||
|
def test_revert_actions_with_from_message_id_passes(self) -> None:
|
||||||
|
req = RegenerateRequest(
|
||||||
|
search_space_id=1,
|
||||||
|
user_query="hi",
|
||||||
|
from_message_id=42,
|
||||||
|
revert_actions=True,
|
||||||
|
)
|
||||||
|
assert req.revert_actions is True
|
||||||
530
surfsense_backend/tests/unit/routes/test_revert_turn_route.py
Normal file
530
surfsense_backend/tests/unit/routes/test_revert_turn_route.py
Normal file
|
|
@ -0,0 +1,530 @@
|
||||||
|
"""Unit tests for ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||||
|
|
||||||
|
The per-turn batch revert route walks rows in reverse ``created_at``
|
||||||
|
order, reverts each independently, and returns a per-action result
|
||||||
|
list. Partial success is normal — the response status
|
||||||
|
is ``"partial"`` whenever any row could not be reverted, but we never
|
||||||
|
collapse the whole batch into a 4xx.
|
||||||
|
|
||||||
|
These tests stub ``load_thread`` / ``revert_action`` and feed a fake
|
||||||
|
session, so they exercise the route's dispatch logic without a real DB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.routes import agent_revert_route
|
||||||
|
from app.services.revert_service import RevertOutcome
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeAction:
|
||||||
|
id: int
|
||||||
|
tool_name: str
|
||||||
|
user_id: str | None = "u1"
|
||||||
|
reverse_of: int | None = None
|
||||||
|
error: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeUser:
|
||||||
|
id: str = "u1"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ScalarResult:
|
||||||
|
rows: list[Any]
|
||||||
|
|
||||||
|
def first(self) -> Any:
|
||||||
|
return self.rows[0] if self.rows else None
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
return list(self.rows)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Result:
|
||||||
|
rows: list[Any] = field(default_factory=list)
|
||||||
|
|
||||||
|
def scalars(self) -> _ScalarResult:
|
||||||
|
return _ScalarResult(self.rows)
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
# ``_was_already_reverted_batch`` calls ``.all()`` directly on
|
||||||
|
# the row-tuple result (no ``.scalars()`` indirection). The
|
||||||
|
# rows queued for that helper are list[(revert_id, original_id)].
|
||||||
|
return list(self.rows)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeNestedCtx:
|
||||||
|
"""Async context manager that mimics ``session.begin_nested()``.
|
||||||
|
|
||||||
|
The route raises a sentinel exception inside this block to roll back
|
||||||
|
bad rows. We just pass the exception through.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __aenter__(self) -> _FakeNestedCtx:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
# Returning False (or None) propagates the exception; the route
|
||||||
|
# catches its own sentinel above this layer.
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
"""Minimal AsyncSession stand-in for the revert-turn route.
|
||||||
|
|
||||||
|
Holds a queue of result objects; each ``execute(...)`` pops the next
|
||||||
|
one. The route calls ``execute`` exactly once per query so this maps
|
||||||
|
cleanly onto the assertion order of the test.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._results: list[_Result] = []
|
||||||
|
self.committed = False
|
||||||
|
self.rolled_back = False
|
||||||
|
# Count execute() calls to assert "no N+1 reverts".
|
||||||
|
self.execute_call_count = 0
|
||||||
|
|
||||||
|
def queue(self, *results: _Result) -> None:
|
||||||
|
self._results.extend(results)
|
||||||
|
|
||||||
|
async def execute(self, _stmt: Any) -> _Result:
|
||||||
|
self.execute_call_count += 1
|
||||||
|
if not self._results:
|
||||||
|
return _Result(rows=[])
|
||||||
|
return self._results.pop(0)
|
||||||
|
|
||||||
|
def begin_nested(self) -> _FakeNestedCtx:
|
||||||
|
return _FakeNestedCtx()
|
||||||
|
|
||||||
|
async def commit(self) -> None:
|
||||||
|
self.committed = True
|
||||||
|
|
||||||
|
async def rollback(self) -> None:
|
||||||
|
self.rolled_back = True
|
||||||
|
|
||||||
|
|
||||||
|
def _enabled_flags() -> AgentFeatureFlags:
|
||||||
|
return AgentFeatureFlags(
|
||||||
|
disable_new_agent_stack=False,
|
||||||
|
enable_action_log=True,
|
||||||
|
enable_revert_route=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patch_get_flags():
|
||||||
|
def _patch(flags: AgentFeatureFlags):
|
||||||
|
return patch(
|
||||||
|
"app.routes.agent_revert_route.get_flags",
|
||||||
|
return_value=flags,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlagGuard:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_503_when_revert_route_disabled(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
flags = AgentFeatureFlags(
|
||||||
|
disable_new_agent_stack=False,
|
||||||
|
enable_action_log=True,
|
||||||
|
enable_revert_route=False,
|
||||||
|
)
|
||||||
|
session = _FakeSession()
|
||||||
|
with patch_get_flags(flags), pytest.raises(Exception) as exc:
|
||||||
|
await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="42:1700000000000",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(),
|
||||||
|
)
|
||||||
|
assert getattr(exc.value, "status_code", None) == 503
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevertTurnDispatch:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_turn_returns_ok_with_no_rows(self, patch_get_flags) -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=[])) # rows query returns nothing
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-empty",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(),
|
||||||
|
)
|
||||||
|
assert response.status == "ok"
|
||||||
|
assert response.total == 0
|
||||||
|
assert response.results == []
|
||||||
|
assert session.committed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_walks_rows_in_reverse_and_reverts_each(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
rows = [
|
||||||
|
_FakeAction(id=10, tool_name="rm"),
|
||||||
|
_FakeAction(id=9, tool_name="write_file"),
|
||||||
|
_FakeAction(id=8, tool_name="mkdir"),
|
||||||
|
]
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=rows))
|
||||||
|
# Single batched ``_was_already_reverted_batch`` probe replaces
|
||||||
|
# the previous N per-row SELECTs.
|
||||||
|
session.queue(_Result(rows=[]))
|
||||||
|
|
||||||
|
async def _fake_revert(_session, *, action, requester_user_id):
|
||||||
|
return RevertOutcome(
|
||||||
|
status="ok",
|
||||||
|
message=f"reverted-{action.id}",
|
||||||
|
new_action_id=100 + action.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
|
||||||
|
),
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-3",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == "ok"
|
||||||
|
assert response.total == 3
|
||||||
|
assert response.reverted == 3
|
||||||
|
assert [r.action_id for r in response.results] == [10, 9, 8]
|
||||||
|
assert all(r.status == "reverted" for r in response.results)
|
||||||
|
assert response.results[0].new_action_id == 110
|
||||||
|
# Only TWO ``execute`` calls regardless of the row count: one
|
||||||
|
# for the rows query, one for the batched
|
||||||
|
# ``_was_already_reverted_batch`` probe. Regression guard
|
||||||
|
# against re-introducing the per-row N+1 lookup.
|
||||||
|
assert session.execute_call_count == 2, (
|
||||||
|
"revert-turn loop must batch idempotency probes; got "
|
||||||
|
f"{session.execute_call_count} execute() calls (expected 2)."
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_already_reverted_rows_are_marked_idempotent(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
rows = [_FakeAction(id=5, tool_name="edit_file")]
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=rows))
|
||||||
|
# Batch probe returns ``[(revert_id, original_id)]``.
|
||||||
|
session.queue(_Result(rows=[(42, 5)]))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-i",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(),
|
||||||
|
)
|
||||||
|
assert response.status == "ok"
|
||||||
|
assert response.already_reverted == 1
|
||||||
|
assert response.results[0].status == "already_reverted"
|
||||||
|
assert response.results[0].new_action_id == 42
|
||||||
|
revert.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revert_action_skips_existing_revert_rows(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
rows = [_FakeAction(id=99, tool_name="_revert:edit_file", reverse_of=42)]
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=rows))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-rev",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(),
|
||||||
|
)
|
||||||
|
assert response.status == "ok"
|
||||||
|
assert response.results[0].status == "skipped"
|
||||||
|
revert.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_partial_success_when_some_rows_not_reversible(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
rows = [
|
||||||
|
_FakeAction(id=2, tool_name="send_email"),
|
||||||
|
_FakeAction(id=1, tool_name="edit_file"),
|
||||||
|
]
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=rows))
|
||||||
|
# Single batched idempotency probe.
|
||||||
|
session.queue(_Result(rows=[]))
|
||||||
|
|
||||||
|
async def _fake_revert(_session, *, action, requester_user_id):
|
||||||
|
if action.tool_name == "send_email":
|
||||||
|
return RevertOutcome(
|
||||||
|
status="not_reversible",
|
||||||
|
message="connector revert not yet implemented",
|
||||||
|
)
|
||||||
|
return RevertOutcome(status="ok", message="ok", new_action_id=500)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
|
||||||
|
),
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-mix",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(),
|
||||||
|
)
|
||||||
|
assert response.status == "partial"
|
||||||
|
assert response.reverted == 1
|
||||||
|
assert response.not_reversible == 1
|
||||||
|
statuses = sorted(r.status for r in response.results)
|
||||||
|
assert statuses == ["not_reversible", "reverted"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unexpected_exception_marks_row_failed_not_batch(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
rows = [
|
||||||
|
_FakeAction(id=20, tool_name="edit_file"),
|
||||||
|
_FakeAction(id=21, tool_name="edit_file"),
|
||||||
|
]
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=rows))
|
||||||
|
# Single batched idempotency probe.
|
||||||
|
session.queue(_Result(rows=[]))
|
||||||
|
|
||||||
|
async def _fake_revert(_session, *, action, requester_user_id):
|
||||||
|
if action.id == 20:
|
||||||
|
raise RuntimeError("disk on fire")
|
||||||
|
return RevertOutcome(status="ok", message="ok", new_action_id=999)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "revert_action", AsyncMock(side_effect=_fake_revert)
|
||||||
|
),
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-fail",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(),
|
||||||
|
)
|
||||||
|
assert response.status == "partial"
|
||||||
|
assert response.failed == 1
|
||||||
|
assert response.reverted == 1
|
||||||
|
bad = next(r for r in response.results if r.action_id == 20)
|
||||||
|
assert bad.status == "failed"
|
||||||
|
assert "disk on fire" in (bad.error or "")
|
||||||
|
good = next(r for r in response.results if r.action_id == 21)
|
||||||
|
assert good.status == "reverted"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_permission_denied_when_other_user_owns_action(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
rows = [_FakeAction(id=7, tool_name="edit_file", user_id="someone-else")]
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=rows))
|
||||||
|
# Batch idempotency probe (no prior reverts).
|
||||||
|
session.queue(_Result(rows=[]))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
patch.object(agent_revert_route, "revert_action", AsyncMock()) as revert,
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-perm",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(id="not-owner"),
|
||||||
|
)
|
||||||
|
assert response.status == "partial"
|
||||||
|
assert response.results[0].status == "permission_denied"
|
||||||
|
# ``permission_denied`` has its own dedicated counter so the
|
||||||
|
# response invariant ``total == sum(counters)`` always holds
|
||||||
|
# without overloading ``not_reversible`` (which historically
|
||||||
|
# absorbed this case and confused frontend toasts).
|
||||||
|
assert response.permission_denied == 1
|
||||||
|
assert response.not_reversible == 0
|
||||||
|
revert.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_counter_invariant_holds_across_mixed_outcomes(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
"""Every row is accounted for in EXACTLY ONE counter.
|
||||||
|
|
||||||
|
Mixes one of every supported outcome (reverted, already_reverted,
|
||||||
|
not_reversible, permission_denied, failed, skipped) and asserts
|
||||||
|
that the sum of counters equals ``response.total``.
|
||||||
|
"""
|
||||||
|
rows = [
|
||||||
|
_FakeAction(id=10, tool_name="edit_file"), # ok
|
||||||
|
_FakeAction(id=9, tool_name="edit_file"), # already_reverted
|
||||||
|
_FakeAction(id=8, tool_name="send_email"), # not_reversible
|
||||||
|
_FakeAction(id=7, tool_name="rm", user_id="other"), # permission_denied
|
||||||
|
_FakeAction(id=6, tool_name="edit_file"), # failed
|
||||||
|
_FakeAction(id=5, tool_name="_revert:edit_file", reverse_of=99), # skipped
|
||||||
|
]
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=rows))
|
||||||
|
# Single batched probe; only id=9 has a prior revert.
|
||||||
|
# Schema: list[(revert_id, original_id)].
|
||||||
|
session.queue(_Result(rows=[(42, 9)]))
|
||||||
|
|
||||||
|
async def _fake_revert(_session, *, action, requester_user_id):
|
||||||
|
if action.id == 10:
|
||||||
|
return RevertOutcome(status="ok", message="ok", new_action_id=500)
|
||||||
|
if action.id == 8:
|
||||||
|
return RevertOutcome(
|
||||||
|
status="not_reversible",
|
||||||
|
message="connector revert not yet implemented",
|
||||||
|
)
|
||||||
|
if action.id == 6:
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
raise AssertionError(f"unexpected revert call for {action.id}")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route,
|
||||||
|
"revert_action",
|
||||||
|
AsyncMock(side_effect=_fake_revert),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-mixed-all",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(), # only id=7 has a different user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.total == len(rows) == 6
|
||||||
|
bucket_sum = (
|
||||||
|
response.reverted
|
||||||
|
+ response.already_reverted
|
||||||
|
+ response.not_reversible
|
||||||
|
+ response.permission_denied
|
||||||
|
+ response.failed
|
||||||
|
+ response.skipped
|
||||||
|
)
|
||||||
|
assert bucket_sum == response.total, (
|
||||||
|
"Counter invariant broken: total "
|
||||||
|
f"({response.total}) != sum of counters ({bucket_sum}). "
|
||||||
|
f"Counters: reverted={response.reverted}, "
|
||||||
|
f"already_reverted={response.already_reverted}, "
|
||||||
|
f"not_reversible={response.not_reversible}, "
|
||||||
|
f"permission_denied={response.permission_denied}, "
|
||||||
|
f"failed={response.failed}, skipped={response.skipped}"
|
||||||
|
)
|
||||||
|
assert response.reverted == 1
|
||||||
|
assert response.already_reverted == 1
|
||||||
|
assert response.not_reversible == 1
|
||||||
|
assert response.permission_denied == 1
|
||||||
|
assert response.failed == 1
|
||||||
|
assert response.skipped == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_integrity_error_translates_to_already_reverted(
|
||||||
|
self, patch_get_flags
|
||||||
|
) -> None:
|
||||||
|
"""The partial unique index on ``reverse_of`` raises
|
||||||
|
``IntegrityError`` when a concurrent revert wins the race against
|
||||||
|
the pre-flight ``_was_already_reverted`` SELECT. The route MUST
|
||||||
|
recover by re-querying for the winning revert id and returning
|
||||||
|
``status="already_reverted"`` (not ``"failed"``) so racing
|
||||||
|
clients see consistent idempotent semantics.
|
||||||
|
"""
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
rows = [_FakeAction(id=33, tool_name="edit_file")]
|
||||||
|
session = _FakeSession()
|
||||||
|
session.queue(_Result(rows=rows))
|
||||||
|
# Batch pre-flight probe: nothing yet (we'll race).
|
||||||
|
session.queue(_Result(rows=[]))
|
||||||
|
# Post-IntegrityError fallback uses the SCALAR
|
||||||
|
# ``_was_already_reverted`` (single-id lookup) so it pulls
|
||||||
|
# ``[777]`` via ``.scalars().first()``.
|
||||||
|
session.queue(_Result(rows=[777]))
|
||||||
|
|
||||||
|
async def _racing_revert(_session, *, action, requester_user_id):
|
||||||
|
raise IntegrityError("INSERT", {}, Exception("dup reverse_of"))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch_get_flags(_enabled_flags()),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route, "load_thread", AsyncMock(return_value=object())
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
agent_revert_route,
|
||||||
|
"revert_action",
|
||||||
|
AsyncMock(side_effect=_racing_revert),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
response = await agent_revert_route.revert_agent_turn(
|
||||||
|
thread_id=1,
|
||||||
|
chat_turn_id="ct-race",
|
||||||
|
session=session,
|
||||||
|
user=_FakeUser(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.failed == 0, (
|
||||||
|
"IntegrityError must NOT surface as a failed row; the unique "
|
||||||
|
"index is the durable expression of idempotency."
|
||||||
|
)
|
||||||
|
assert response.already_reverted == 1
|
||||||
|
assert response.results[0].status == "already_reverted"
|
||||||
|
assert response.results[0].new_action_id == 777
|
||||||
|
|
@ -0,0 +1,921 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.auto_model_pin_service import (
|
||||||
|
clear_healthy,
|
||||||
|
clear_runtime_cooldown,
|
||||||
|
is_recently_healthy,
|
||||||
|
mark_healthy,
|
||||||
|
mark_runtime_cooldown,
|
||||||
|
resolve_or_get_pinned_llm_config_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clear_runtime_cooldown_map():
|
||||||
|
clear_runtime_cooldown()
|
||||||
|
clear_healthy()
|
||||||
|
yield
|
||||||
|
clear_runtime_cooldown()
|
||||||
|
clear_healthy()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeQuotaResult:
|
||||||
|
allowed: bool
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeExecResult:
|
||||||
|
def __init__(self, thread):
|
||||||
|
self._thread = thread
|
||||||
|
|
||||||
|
def unique(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def scalar_one_or_none(self):
|
||||||
|
return self._thread
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self, thread):
|
||||||
|
self.thread = thread
|
||||||
|
self.commit_count = 0
|
||||||
|
|
||||||
|
async def execute(self, _stmt):
|
||||||
|
return _FakeExecResult(self.thread)
|
||||||
|
|
||||||
|
async def commit(self):
|
||||||
|
self.commit_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _thread(
|
||||||
|
*,
|
||||||
|
search_space_id: int = 10,
|
||||||
|
pinned_llm_config_id: int | None = None,
|
||||||
|
):
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=1,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
pinned_llm_config_id=pinned_llm_config_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_first_turn_pins_one_model(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-prem",
|
||||||
|
"api_key": "k2",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _allowed(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id in {-1, -2}
|
||||||
|
assert session.thread.pinned_llm_config_id == result.resolved_llm_config_id
|
||||||
|
assert session.commit_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_next_turn_reuses_existing_pin(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-prem",
|
||||||
|
"api_key": "k2",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _must_not_call(*_args, **_kwargs):
|
||||||
|
raise AssertionError(
|
||||||
|
"premium_get_usage should not be called for valid pin reuse"
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_must_not_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -1
|
||||||
|
assert result.from_existing_pin is True
|
||||||
|
assert session.commit_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_premium_eligible_auto_can_pin_premium(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-prem",
|
||||||
|
"api_key": "k2",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _allowed(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -1
|
||||||
|
assert result.resolved_tier == "premium"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_premium_ineligible_auto_pins_free_only(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-free",
|
||||||
|
"api_key": "k1",
|
||||||
|
"billing_tier": "free",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-prem",
|
||||||
|
"api_key": "k2",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _blocked(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=False)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_blocked,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -2
|
||||||
|
assert result.resolved_tier == "free"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pinned_premium_stays_premium_after_quota_exhaustion(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-free",
|
||||||
|
"api_key": "k1",
|
||||||
|
"billing_tier": "free",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-prem",
|
||||||
|
"api_key": "k2",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _blocked(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=False)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_blocked,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -1
|
||||||
|
assert result.from_existing_pin is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_force_repin_free_switches_auto_premium_pin_to_free(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-free",
|
||||||
|
"api_key": "k1",
|
||||||
|
"billing_tier": "free",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": "gpt-prem",
|
||||||
|
"api_key": "k2",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _blocked(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=False)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_blocked,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
force_repin_free=True,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -2
|
||||||
|
assert result.resolved_tier == "free"
|
||||||
|
assert result.from_existing_pin is False
|
||||||
|
assert session.thread.pinned_llm_config_id == -2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_explicit_user_model_change_clears_pin(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-2))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=7,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == 7
|
||||||
|
assert session.thread.pinned_llm_config_id is None
|
||||||
|
assert session.commit_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_pinned_config_repairs_with_new_pin(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-999))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{"id": -2, "provider": "OPENAI", "model_name": "gpt-free", "api_key": "k1"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _allowed(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -2
|
||||||
|
assert session.thread.pinned_llm_config_id == -2
|
||||||
|
assert session.commit_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quality-aware pin selection (Auto Fastest upgrade)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_gated_config_is_excluded_from_selection(monkeypatch):
|
||||||
|
"""A cfg flagged ``health_gated`` must never be picked even if it has
|
||||||
|
the highest score among eligible cfgs."""
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "venice/dead-model",
|
||||||
|
"api_key": "k1",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"auto_pin_tier": "C",
|
||||||
|
"quality_score": 95,
|
||||||
|
"health_gated": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemini-flash",
|
||||||
|
"api_key": "k1",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"auto_pin_tier": "C",
|
||||||
|
"quality_score": 60,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _blocked(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=False)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_blocked,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tier_a_locks_first_premium_user_skips_or(monkeypatch):
|
||||||
|
"""Premium-eligible users with Tier A available should never spill to
|
||||||
|
Tier B even if a B cfg ranks higher by ``quality_score``."""
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5",
|
||||||
|
"api_key": "k-yaml",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "A",
|
||||||
|
"quality_score": 70,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "openai/gpt-5",
|
||||||
|
"api_key": "k-or",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "B",
|
||||||
|
"quality_score": 95,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _allowed(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -1
|
||||||
|
assert result.resolved_tier == "premium"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tier_a_falls_through_to_or_when_a_pool_empty_for_user(monkeypatch):
|
||||||
|
"""Free-only user with no Tier A free cfg should pick from Tier C."""
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5",
|
||||||
|
"api_key": "k-yaml",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "A",
|
||||||
|
"quality_score": 100,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemini-flash:free",
|
||||||
|
"api_key": "k-or",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"auto_pin_tier": "C",
|
||||||
|
"quality_score": 60,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _blocked(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=False)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_blocked,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_top_k_picks_only_high_score_models(monkeypatch):
|
||||||
|
"""Different thread IDs should spread across top-K, never pick the
|
||||||
|
obvious low-quality cfg even when it sits in the candidate list."""
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
high_score_cfgs = [
|
||||||
|
{
|
||||||
|
"id": -i,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": f"gpt-x-{i}",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "A",
|
||||||
|
"quality_score": 90,
|
||||||
|
"health_gated": False,
|
||||||
|
}
|
||||||
|
for i in range(1, 6) # 5 high-quality Tier A cfgs
|
||||||
|
]
|
||||||
|
low_score_trap = {
|
||||||
|
"id": -99,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "tiny-legacy",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "A",
|
||||||
|
"quality_score": 10,
|
||||||
|
"health_gated": False,
|
||||||
|
}
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[*high_score_cfgs, low_score_trap],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _allowed(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
high_score_ids = {c["id"] for c in high_score_cfgs}
|
||||||
|
seen = set()
|
||||||
|
for thread_id in range(1, 50):
|
||||||
|
session = _FakeSession(_thread())
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=thread_id,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
seen.add(result.resolved_llm_config_id)
|
||||||
|
assert result.resolved_llm_config_id != -99, (
|
||||||
|
"low-score trap cfg should never be picked"
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id in high_score_ids
|
||||||
|
|
||||||
|
# Spread across at least a couple of top-K cfgs.
|
||||||
|
assert len(seen) > 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pin_reuse_survives_health_gating_for_existing_pin(monkeypatch):
|
||||||
|
"""An *already* pinned cfg that later flips to ``health_gated`` should
|
||||||
|
still not be reused — gated cfgs are filtered out of the candidate
|
||||||
|
pool, which forces a repair to a healthy cfg.
|
||||||
|
|
||||||
|
This guards the no-silent-tier-switch invariant: we don't keep using
|
||||||
|
a known-broken model just because the thread happened to be pinned
|
||||||
|
to it before the gate fired."""
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "venice/dead-model",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "B",
|
||||||
|
"quality_score": 50,
|
||||||
|
"health_gated": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "A",
|
||||||
|
"quality_score": 90,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _allowed(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_allowed,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -2
|
||||||
|
assert result.from_existing_pin is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pin_reuse_regression_existing_healthy_pin(monkeypatch):
|
||||||
|
"""Existing pin reuse must short-circuit the new tier/score logic."""
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "A",
|
||||||
|
"quality_score": 50, # lower than -2
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5-pro",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "A",
|
||||||
|
"quality_score": 99,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _must_not_call(*_args, **_kwargs):
|
||||||
|
raise AssertionError("premium_get_usage should not run on pin reuse")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_must_not_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -1
|
||||||
|
assert result.from_existing_pin is True
|
||||||
|
assert session.commit_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_cooled_down_pin_is_not_reused(monkeypatch):
|
||||||
|
"""A runtime-cooled config should be excluded from candidate reuse.
|
||||||
|
|
||||||
|
This enables one-shot recovery from transient provider 429 bursts: we can
|
||||||
|
mark the pinned cfg as cooled down and force a repair to another eligible
|
||||||
|
cfg on the next resolution.
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"auto_pin_tier": "C",
|
||||||
|
"quality_score": 90,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemini-2.5-flash:free",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"auto_pin_tier": "C",
|
||||||
|
"quality_score": 80,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _blocked(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=False)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_blocked,
|
||||||
|
)
|
||||||
|
|
||||||
|
mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -2
|
||||||
|
assert result.from_existing_pin is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clearing_runtime_cooldown_restores_pin_reuse(monkeypatch):
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"auto_pin_tier": "C",
|
||||||
|
"quality_score": 90,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _must_not_call(*_args, **_kwargs):
|
||||||
|
raise AssertionError("premium_get_usage should not run on healthy pin reuse")
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_must_not_call,
|
||||||
|
)
|
||||||
|
|
||||||
|
mark_runtime_cooldown(-1, reason="provider_rate_limited", cooldown_seconds=600)
|
||||||
|
clear_runtime_cooldown(-1)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -1
|
||||||
|
assert result.from_existing_pin is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_pin_repin_excludes_previous_config_on_runtime_retry(monkeypatch):
|
||||||
|
"""Runtime retry should never repin the just-failed config."""
|
||||||
|
from app.config import config
|
||||||
|
|
||||||
|
session = _FakeSession(_thread(pinned_llm_config_id=-1))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"GLOBAL_LLM_CONFIGS",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": -1,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemma-4-26b-a4b-it:free",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"auto_pin_tier": "C",
|
||||||
|
"quality_score": 90,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": -2,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": "google/gemini-2.5-flash:free",
|
||||||
|
"api_key": "k",
|
||||||
|
"billing_tier": "free",
|
||||||
|
"auto_pin_tier": "C",
|
||||||
|
"quality_score": 80,
|
||||||
|
"health_gated": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _blocked(*_args, **_kwargs):
|
||||||
|
return _FakeQuotaResult(allowed=False)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.auto_model_pin_service.TokenQuotaService.premium_get_usage",
|
||||||
|
_blocked,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await resolve_or_get_pinned_llm_config_id(
|
||||||
|
session,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=10,
|
||||||
|
user_id="00000000-0000-0000-0000-000000000001",
|
||||||
|
selected_llm_config_id=0,
|
||||||
|
exclude_config_ids={-1},
|
||||||
|
)
|
||||||
|
assert result.resolved_llm_config_id == -2
|
||||||
|
assert result.from_existing_pin is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Healthy-status cache (preflight TTL companion)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_mark_healthy_then_is_recently_healthy_true_within_ttl():
|
||||||
|
mark_healthy(-42, ttl_seconds=60)
|
||||||
|
assert is_recently_healthy(-42) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_healthy_expires_after_ttl(monkeypatch):
|
||||||
|
import app.services.auto_model_pin_service as svc
|
||||||
|
|
||||||
|
real_time = svc.time.time
|
||||||
|
base = real_time()
|
||||||
|
|
||||||
|
monkeypatch.setattr(svc.time, "time", lambda: base)
|
||||||
|
mark_healthy(-7, ttl_seconds=10)
|
||||||
|
assert is_recently_healthy(-7) is True
|
||||||
|
|
||||||
|
monkeypatch.setattr(svc.time, "time", lambda: base + 11)
|
||||||
|
assert is_recently_healthy(-7) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_mark_runtime_cooldown_invalidates_healthy_cache():
|
||||||
|
mark_healthy(-9, ttl_seconds=60)
|
||||||
|
assert is_recently_healthy(-9) is True
|
||||||
|
|
||||||
|
mark_runtime_cooldown(-9, reason="test", cooldown_seconds=60)
|
||||||
|
assert is_recently_healthy(-9) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_healthy_removes_single_entry():
|
||||||
|
mark_healthy(-11, ttl_seconds=60)
|
||||||
|
mark_healthy(-12, ttl_seconds=60)
|
||||||
|
clear_healthy(-11)
|
||||||
|
assert is_recently_healthy(-11) is False
|
||||||
|
assert is_recently_healthy(-12) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_healthy_no_args_drops_all_entries():
|
||||||
|
mark_healthy(-21, ttl_seconds=60)
|
||||||
|
mark_healthy(-22, ttl_seconds=60)
|
||||||
|
clear_healthy()
|
||||||
|
assert is_recently_healthy(-21) is False
|
||||||
|
assert is_recently_healthy(-22) is False
|
||||||
|
|
@ -0,0 +1,226 @@
|
||||||
|
"""LLMRouterService pool-filter / rebuild tests.
|
||||||
|
|
||||||
|
These tests focus on the *config plumbing* (which configs enter the router
|
||||||
|
pool, rebuild resets state correctly). They stub out the underlying
|
||||||
|
``litellm.Router`` so we don't need real API keys or network access.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.llm_router_service import LLMRouterService
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_yaml_config(
|
||||||
|
*,
|
||||||
|
id: int,
|
||||||
|
model_name: str,
|
||||||
|
billing_tier: str = "free",
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"id": id,
|
||||||
|
"name": f"yaml-{id}",
|
||||||
|
"provider": "OPENAI",
|
||||||
|
"model_name": model_name,
|
||||||
|
"api_key": "sk-test",
|
||||||
|
"api_base": "",
|
||||||
|
"billing_tier": billing_tier,
|
||||||
|
"rpm": 100,
|
||||||
|
"tpm": 100_000,
|
||||||
|
"litellm_params": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_openrouter_config(
|
||||||
|
*,
|
||||||
|
id: int,
|
||||||
|
model_name: str,
|
||||||
|
billing_tier: str,
|
||||||
|
router_pool_eligible: bool | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Build a synthetic dynamic-OR config dict for router-pool tests.
|
||||||
|
|
||||||
|
Defaults mirror Strategy 3: premium OR enters the pool, free OR stays
|
||||||
|
out. Callers can override ``router_pool_eligible`` to simulate legacy
|
||||||
|
configs or to regression-test the filter mechanics directly.
|
||||||
|
"""
|
||||||
|
if router_pool_eligible is None:
|
||||||
|
router_pool_eligible = billing_tier == "premium"
|
||||||
|
return {
|
||||||
|
"id": id,
|
||||||
|
"name": f"or-{id}",
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": model_name,
|
||||||
|
"api_key": "sk-or-test",
|
||||||
|
"api_base": "",
|
||||||
|
"billing_tier": billing_tier,
|
||||||
|
"rpm": 20 if billing_tier == "free" else 200,
|
||||||
|
"tpm": 100_000 if billing_tier == "free" else 1_000_000,
|
||||||
|
"litellm_params": {},
|
||||||
|
"router_pool_eligible": router_pool_eligible,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_router_singleton() -> None:
|
||||||
|
instance = LLMRouterService.get_instance()
|
||||||
|
instance._initialized = False
|
||||||
|
instance._router = None
|
||||||
|
instance._model_list = []
|
||||||
|
instance._premium_model_strings = set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_pool_includes_or_premium_excludes_or_free():
|
||||||
|
"""Strategy 3: premium OR joins the pool, free OR stays out.
|
||||||
|
|
||||||
|
Dynamic OpenRouter premium entries opt into load balancing alongside
|
||||||
|
curated YAML configs. Dynamic OR free entries are intentionally kept
|
||||||
|
out because OpenRouter's free tier enforces a single account-global
|
||||||
|
quota bucket that per-deployment router accounting can't represent.
|
||||||
|
"""
|
||||||
|
_reset_router_singleton()
|
||||||
|
configs = [
|
||||||
|
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
|
||||||
|
_fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
|
||||||
|
_fake_openrouter_config(
|
||||||
|
id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
|
||||||
|
),
|
||||||
|
_fake_openrouter_config(
|
||||||
|
id=-10_002,
|
||||||
|
model_name="meta-llama/llama-3.3-70b:free",
|
||||||
|
billing_tier="free",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.llm_router_service.Router") as mock_router,
|
||||||
|
patch(
|
||||||
|
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
|
||||||
|
) as mock_ctx_fb,
|
||||||
|
):
|
||||||
|
mock_ctx_fb.side_effect = lambda ml: (ml, None)
|
||||||
|
mock_router.return_value = object()
|
||||||
|
LLMRouterService.initialize(configs)
|
||||||
|
|
||||||
|
pool_models = {
|
||||||
|
dep["litellm_params"]["model"]
|
||||||
|
for dep in LLMRouterService.get_instance()._model_list
|
||||||
|
}
|
||||||
|
# YAML premium + YAML free + dynamic OR premium are all in the pool.
|
||||||
|
# Dynamic OR free is NOT (shared-bucket rate limits can't be load-balanced).
|
||||||
|
assert pool_models == {
|
||||||
|
"openai/gpt-4o",
|
||||||
|
"openai/gpt-4o-mini",
|
||||||
|
"openrouter/openai/gpt-4o",
|
||||||
|
}
|
||||||
|
|
||||||
|
prem = LLMRouterService.get_instance()._premium_model_strings
|
||||||
|
# YAML premium is fingerprinted under both its model_string and its
|
||||||
|
# ``base_model`` form (existing behavior we don't want to regress).
|
||||||
|
assert "openai/gpt-4o" in prem
|
||||||
|
# Dynamic OR premium is now fingerprinted as premium so pool-level
|
||||||
|
# calls through the router are billed against premium quota.
|
||||||
|
assert "openrouter/openai/gpt-4o" in prem
|
||||||
|
assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is True
|
||||||
|
# Dynamic OR free never enters the pool, so it's never counted as premium.
|
||||||
|
assert (
|
||||||
|
LLMRouterService.is_premium_model("openrouter/meta-llama/llama-3.3-70b:free")
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_pool_filter_mechanics_respect_override():
|
||||||
|
"""The ``router_pool_eligible`` filter itself works independently of tier.
|
||||||
|
|
||||||
|
Regression guard: if a future refactor ever sets the flag False on a
|
||||||
|
premium config (e.g. for maintenance), that config MUST be skipped by
|
||||||
|
``initialize`` even though its tier is premium.
|
||||||
|
"""
|
||||||
|
_reset_router_singleton()
|
||||||
|
configs = [
|
||||||
|
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
|
||||||
|
_fake_openrouter_config(
|
||||||
|
id=-10_001,
|
||||||
|
model_name="openai/gpt-4o",
|
||||||
|
billing_tier="premium",
|
||||||
|
router_pool_eligible=False, # opt out despite being premium
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.llm_router_service.Router") as mock_router,
|
||||||
|
patch(
|
||||||
|
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
|
||||||
|
) as mock_ctx_fb,
|
||||||
|
):
|
||||||
|
mock_ctx_fb.side_effect = lambda ml: (ml, None)
|
||||||
|
mock_router.return_value = object()
|
||||||
|
LLMRouterService.initialize(configs)
|
||||||
|
|
||||||
|
pool_models = {
|
||||||
|
dep["litellm_params"]["model"]
|
||||||
|
for dep in LLMRouterService.get_instance()._model_list
|
||||||
|
}
|
||||||
|
assert pool_models == {"openai/gpt-4o"}
|
||||||
|
assert LLMRouterService.is_premium_model("openrouter/openai/gpt-4o") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_rebuild_refreshes_pool_after_configs_change():
|
||||||
|
_reset_router_singleton()
|
||||||
|
configs_v1 = [
|
||||||
|
_fake_yaml_config(id=-1, model_name="gpt-4o", billing_tier="premium"),
|
||||||
|
]
|
||||||
|
configs_v2 = [
|
||||||
|
*configs_v1,
|
||||||
|
_fake_yaml_config(id=-2, model_name="gpt-4o-mini", billing_tier="free"),
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.llm_router_service.Router") as mock_router,
|
||||||
|
patch(
|
||||||
|
"app.services.llm_router_service.LLMRouterService._build_context_fallback_groups"
|
||||||
|
) as mock_ctx_fb,
|
||||||
|
):
|
||||||
|
mock_ctx_fb.side_effect = lambda ml: (ml, None)
|
||||||
|
mock_router.return_value = object()
|
||||||
|
|
||||||
|
LLMRouterService.initialize(configs_v1)
|
||||||
|
assert len(LLMRouterService.get_instance()._model_list) == 1
|
||||||
|
|
||||||
|
# ``initialize`` should be a no-op here (already initialized).
|
||||||
|
LLMRouterService.initialize(configs_v2)
|
||||||
|
assert len(LLMRouterService.get_instance()._model_list) == 1
|
||||||
|
|
||||||
|
# ``rebuild`` must clear the guard and re-run with the new configs.
|
||||||
|
LLMRouterService.rebuild(configs_v2)
|
||||||
|
assert len(LLMRouterService.get_instance()._model_list) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_model_pin_candidates_include_dynamic_openrouter():
|
||||||
|
"""Dynamic OR configs must remain Auto-mode thread-pin candidates.
|
||||||
|
|
||||||
|
Guards against a future regression where someone adds the
|
||||||
|
``router_pool_eligible`` filter to ``auto_model_pin_service._global_candidates``.
|
||||||
|
"""
|
||||||
|
from app.config import config
|
||||||
|
from app.services.auto_model_pin_service import _global_candidates
|
||||||
|
|
||||||
|
or_premium = _fake_openrouter_config(
|
||||||
|
id=-10_001, model_name="openai/gpt-4o", billing_tier="premium"
|
||||||
|
)
|
||||||
|
or_free = _fake_openrouter_config(
|
||||||
|
id=-10_002,
|
||||||
|
model_name="meta-llama/llama-3.3-70b:free",
|
||||||
|
billing_tier="free",
|
||||||
|
)
|
||||||
|
original = config.GLOBAL_LLM_CONFIGS
|
||||||
|
try:
|
||||||
|
config.GLOBAL_LLM_CONFIGS = [or_premium, or_free]
|
||||||
|
candidate_ids = {c["id"] for c in _global_candidates()}
|
||||||
|
assert candidate_ids == {-10_001, -10_002}
|
||||||
|
finally:
|
||||||
|
config.GLOBAL_LLM_CONFIGS = original
|
||||||
|
|
@ -0,0 +1,216 @@
|
||||||
|
"""Unit tests for the dynamic OpenRouter integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.openrouter_integration_service import (
|
||||||
|
_OPENROUTER_DYNAMIC_MARKER,
|
||||||
|
_generate_configs,
|
||||||
|
_openrouter_tier,
|
||||||
|
_stable_config_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _minimal_openrouter_model(
|
||||||
|
*,
|
||||||
|
model_id: str,
|
||||||
|
pricing: dict | None = None,
|
||||||
|
name: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Return a synthetic OpenRouter /api/v1/models entry.
|
||||||
|
|
||||||
|
The real API payload includes a lot of fields; we only populate what
|
||||||
|
``_generate_configs`` actually inspects (architecture, tool support,
|
||||||
|
context, pricing, id).
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"id": model_id,
|
||||||
|
"name": name or model_id,
|
||||||
|
"architecture": {"output_modalities": ["text"]},
|
||||||
|
"supported_parameters": ["tools"],
|
||||||
|
"context_length": 200_000,
|
||||||
|
"pricing": pricing or {"prompt": "0.000003", "completion": "0.000015"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _openrouter_tier
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_tier_free_suffix():
|
||||||
|
assert _openrouter_tier({"id": "foo/bar:free"}) == "free"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_tier_zero_pricing():
|
||||||
|
model = {
|
||||||
|
"id": "foo/bar",
|
||||||
|
"pricing": {"prompt": "0", "completion": "0"},
|
||||||
|
}
|
||||||
|
assert _openrouter_tier(model) == "free"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_tier_paid():
|
||||||
|
model = {
|
||||||
|
"id": "foo/bar",
|
||||||
|
"pricing": {"prompt": "0.000003", "completion": "0.000015"},
|
||||||
|
}
|
||||||
|
assert _openrouter_tier(model) == "premium"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_tier_missing_pricing_is_premium():
|
||||||
|
assert _openrouter_tier({"id": "foo/bar"}) == "premium"
|
||||||
|
assert _openrouter_tier({"id": "foo/bar", "pricing": {}}) == "premium"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _stable_config_id
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stable_config_id_deterministic():
|
||||||
|
taken1: set[int] = set()
|
||||||
|
taken2: set[int] = set()
|
||||||
|
a = _stable_config_id("openai/gpt-4o", -10_000, taken1)
|
||||||
|
b = _stable_config_id("openai/gpt-4o", -10_000, taken2)
|
||||||
|
assert a == b
|
||||||
|
assert a < 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_stable_config_id_collision_decrements():
|
||||||
|
"""When two model_ids hash to the same slot, the second should decrement."""
|
||||||
|
taken: set[int] = set()
|
||||||
|
a = _stable_config_id("openai/gpt-4o", -10_000, taken)
|
||||||
|
# Force a collision by pre-populating ``taken`` with a slot we know will be
|
||||||
|
# picked.
|
||||||
|
taken_forced = {a}
|
||||||
|
b = _stable_config_id("openai/gpt-4o", -10_000, taken_forced)
|
||||||
|
assert b != a
|
||||||
|
assert b == a - 1
|
||||||
|
assert b in taken_forced
|
||||||
|
|
||||||
|
|
||||||
|
def test_stable_config_id_different_models_different_ids():
|
||||||
|
taken: set[int] = set()
|
||||||
|
ids = {
|
||||||
|
_stable_config_id("openai/gpt-4o", -10_000, taken),
|
||||||
|
_stable_config_id("anthropic/claude-3.5-sonnet", -10_000, taken),
|
||||||
|
_stable_config_id("google/gemini-2.0-flash", -10_000, taken),
|
||||||
|
}
|
||||||
|
assert len(ids) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_stable_config_id_survives_catalogue_churn():
|
||||||
|
"""Removing a model should not shift other models' IDs (the bug we fix)."""
|
||||||
|
taken1: set[int] = set()
|
||||||
|
id_a1 = _stable_config_id("openai/gpt-4o", -10_000, taken1)
|
||||||
|
_ = _stable_config_id("anthropic/claude-3-haiku", -10_000, taken1)
|
||||||
|
id_c1 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken1)
|
||||||
|
|
||||||
|
taken2: set[int] = set()
|
||||||
|
id_a2 = _stable_config_id("openai/gpt-4o", -10_000, taken2)
|
||||||
|
id_c2 = _stable_config_id("google/gemini-2.0-flash", -10_000, taken2)
|
||||||
|
|
||||||
|
assert id_a1 == id_a2
|
||||||
|
assert id_c1 == id_c2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _generate_configs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_SETTINGS_BASE: dict = {
|
||||||
|
"api_key": "sk-or-test",
|
||||||
|
"id_offset": -10_000,
|
||||||
|
"rpm": 200,
|
||||||
|
"tpm": 1_000_000,
|
||||||
|
"free_rpm": 20,
|
||||||
|
"free_tpm": 100_000,
|
||||||
|
"anonymous_enabled_paid": False,
|
||||||
|
"anonymous_enabled_free": True,
|
||||||
|
"quota_reserve_tokens": 4000,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_configs_respects_tier():
|
||||||
|
"""Premium OR models opt into the router pool; free OR models stay out.
|
||||||
|
|
||||||
|
Strategy-3 split: premium participates in LiteLLM Router load balancing,
|
||||||
|
free stays excluded because OpenRouter enforces a shared global free-tier
|
||||||
|
bucket that per-deployment router accounting can't represent.
|
||||||
|
"""
|
||||||
|
raw = [
|
||||||
|
_minimal_openrouter_model(model_id="openai/gpt-4o"),
|
||||||
|
_minimal_openrouter_model(
|
||||||
|
model_id="meta-llama/llama-3.3-70b-instruct:free",
|
||||||
|
pricing={"prompt": "0", "completion": "0"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
|
||||||
|
by_model = {c["model_name"]: c for c in cfgs}
|
||||||
|
|
||||||
|
paid = by_model["openai/gpt-4o"]
|
||||||
|
assert paid["billing_tier"] == "premium"
|
||||||
|
assert paid["rpm"] == 200
|
||||||
|
assert paid["tpm"] == 1_000_000
|
||||||
|
assert paid["anonymous_enabled"] is False
|
||||||
|
assert paid["router_pool_eligible"] is True
|
||||||
|
assert paid[_OPENROUTER_DYNAMIC_MARKER] is True
|
||||||
|
|
||||||
|
free = by_model["meta-llama/llama-3.3-70b-instruct:free"]
|
||||||
|
assert free["billing_tier"] == "free"
|
||||||
|
assert free["rpm"] == 20
|
||||||
|
assert free["tpm"] == 100_000
|
||||||
|
assert free["anonymous_enabled"] is True
|
||||||
|
assert free["router_pool_eligible"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_configs_excludes_upstream_openrouter_free_router():
|
||||||
|
"""OpenRouter's own ``openrouter/free`` meta-router must never become a card.
|
||||||
|
|
||||||
|
The upstream API returns this as a first-class zero-priced model, so
|
||||||
|
without an explicit blocklist entry it would slip through every other
|
||||||
|
filter (text output, tool calling, 200k context, non-Amazon) and land
|
||||||
|
in the selector as a duplicate of the concrete ``:free`` cards. The
|
||||||
|
exclusion in ``_EXCLUDED_MODEL_IDS`` prevents that.
|
||||||
|
"""
|
||||||
|
raw = [
|
||||||
|
_minimal_openrouter_model(model_id="openai/gpt-4o"),
|
||||||
|
_minimal_openrouter_model(
|
||||||
|
model_id="openrouter/free",
|
||||||
|
pricing={"prompt": "0", "completion": "0"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
|
||||||
|
model_names = {c["model_name"] for c in cfgs}
|
||||||
|
assert "openrouter/free" not in model_names
|
||||||
|
assert "openai/gpt-4o" in model_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_configs_drops_non_text_and_non_tool_models():
|
||||||
|
raw = [
|
||||||
|
_minimal_openrouter_model(model_id="openai/gpt-4o"),
|
||||||
|
{ # image-output model
|
||||||
|
"id": "openai/dall-e",
|
||||||
|
"architecture": {"output_modalities": ["image"]},
|
||||||
|
"supported_parameters": ["tools"],
|
||||||
|
"context_length": 200_000,
|
||||||
|
"pricing": {"prompt": "0.01", "completion": "0.01"},
|
||||||
|
},
|
||||||
|
{ # text but no tool calling
|
||||||
|
"id": "openai/completion-only",
|
||||||
|
"architecture": {"output_modalities": ["text"]},
|
||||||
|
"supported_parameters": [],
|
||||||
|
"context_length": 200_000,
|
||||||
|
"pricing": {"prompt": "0.01", "completion": "0.01"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
cfgs = _generate_configs(raw, dict(_SETTINGS_BASE))
|
||||||
|
model_names = [c["model_name"] for c in cfgs]
|
||||||
|
assert "openai/gpt-4o" in model_names
|
||||||
|
assert "openai/dall-e" not in model_names
|
||||||
|
assert "openai/completion-only" not in model_names
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
"""Tests for deprecated-key warnings and back-compat in
|
||||||
|
``load_openrouter_integration_settings``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _write_yaml(tmp_path: Path, body: str) -> Path:
|
||||||
|
cfg_dir = tmp_path / "app" / "config"
|
||||||
|
cfg_dir.mkdir(parents=True)
|
||||||
|
cfg_path = cfg_dir / "global_llm_config.yaml"
|
||||||
|
cfg_path.write_text(body, encoding="utf-8")
|
||||||
|
return cfg_path
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_base_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||||
|
from app import config as config_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(config_module, "BASE_DIR", tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_billing_tier_emits_warning(monkeypatch, tmp_path, capsys):
|
||||||
|
_write_yaml(
|
||||||
|
tmp_path,
|
||||||
|
"""
|
||||||
|
openrouter_integration:
|
||||||
|
enabled: true
|
||||||
|
api_key: "sk-or-test"
|
||||||
|
billing_tier: "premium"
|
||||||
|
""".lstrip(),
|
||||||
|
)
|
||||||
|
_patch_base_dir(monkeypatch, tmp_path)
|
||||||
|
|
||||||
|
from app.config import load_openrouter_integration_settings
|
||||||
|
|
||||||
|
settings = load_openrouter_integration_settings()
|
||||||
|
captured = capsys.readouterr().out
|
||||||
|
assert settings is not None
|
||||||
|
assert "billing_tier is deprecated" in captured
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_anonymous_enabled_back_compat(monkeypatch, tmp_path, capsys):
|
||||||
|
_write_yaml(
|
||||||
|
tmp_path,
|
||||||
|
"""
|
||||||
|
openrouter_integration:
|
||||||
|
enabled: true
|
||||||
|
api_key: "sk-or-test"
|
||||||
|
anonymous_enabled: true
|
||||||
|
""".lstrip(),
|
||||||
|
)
|
||||||
|
_patch_base_dir(monkeypatch, tmp_path)
|
||||||
|
|
||||||
|
from app.config import load_openrouter_integration_settings
|
||||||
|
|
||||||
|
settings = load_openrouter_integration_settings()
|
||||||
|
captured = capsys.readouterr().out
|
||||||
|
assert settings is not None
|
||||||
|
assert settings["anonymous_enabled_paid"] is True
|
||||||
|
assert settings["anonymous_enabled_free"] is True
|
||||||
|
assert "anonymous_enabled is" in captured
|
||||||
|
assert "deprecated" in captured
|
||||||
|
|
||||||
|
|
||||||
|
def test_new_keys_take_priority_over_legacy_back_compat(monkeypatch, tmp_path, capsys):
|
||||||
|
"""If both legacy and new keys are present, new keys win (setdefault)."""
|
||||||
|
_write_yaml(
|
||||||
|
tmp_path,
|
||||||
|
"""
|
||||||
|
openrouter_integration:
|
||||||
|
enabled: true
|
||||||
|
api_key: "sk-or-test"
|
||||||
|
anonymous_enabled: true
|
||||||
|
anonymous_enabled_paid: false
|
||||||
|
anonymous_enabled_free: false
|
||||||
|
""".lstrip(),
|
||||||
|
)
|
||||||
|
_patch_base_dir(monkeypatch, tmp_path)
|
||||||
|
|
||||||
|
from app.config import load_openrouter_integration_settings
|
||||||
|
|
||||||
|
settings = load_openrouter_integration_settings()
|
||||||
|
capsys.readouterr()
|
||||||
|
assert settings is not None
|
||||||
|
assert settings["anonymous_enabled_paid"] is False
|
||||||
|
assert settings["anonymous_enabled_free"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_disabled_integration_returns_none(monkeypatch, tmp_path):
|
||||||
|
_write_yaml(
|
||||||
|
tmp_path,
|
||||||
|
"""
|
||||||
|
openrouter_integration:
|
||||||
|
enabled: false
|
||||||
|
api_key: "sk-or-test"
|
||||||
|
""".lstrip(),
|
||||||
|
)
|
||||||
|
_patch_base_dir(monkeypatch, tmp_path)
|
||||||
|
|
||||||
|
from app.config import load_openrouter_integration_settings
|
||||||
|
|
||||||
|
assert load_openrouter_integration_settings() is None
|
||||||
|
|
@ -0,0 +1,331 @@
|
||||||
|
"""Unit tests for the OpenRouter ``_enrich_health`` background task."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.openrouter_integration_service import (
|
||||||
|
OpenRouterIntegrationService,
|
||||||
|
)
|
||||||
|
from app.services.quality_score import (
|
||||||
|
_HEALTH_FAIL_RATIO_FALLBACK,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
def _or_cfg(
|
||||||
|
*,
|
||||||
|
cid: int,
|
||||||
|
model_name: str,
|
||||||
|
tier: str = "premium",
|
||||||
|
static_score: int = 50,
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"id": cid,
|
||||||
|
"provider": "OPENROUTER",
|
||||||
|
"model_name": model_name,
|
||||||
|
"billing_tier": tier,
|
||||||
|
"auto_pin_tier": "B" if tier == "premium" else "C",
|
||||||
|
"quality_score_static": static_score,
|
||||||
|
"quality_score_health": None,
|
||||||
|
"quality_score": static_score,
|
||||||
|
"health_gated": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _StubResponse:
|
||||||
|
def __init__(self, *, payload: dict, status_code: int = 200):
|
||||||
|
self._payload = payload
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
def raise_for_status(self) -> None:
|
||||||
|
if self.status_code >= 400:
|
||||||
|
raise RuntimeError(f"HTTP {self.status_code}")
|
||||||
|
|
||||||
|
def json(self) -> dict:
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
|
||||||
|
class _StubAsyncClient:
|
||||||
|
"""Minimal drop-in for ``httpx.AsyncClient`` used by ``_fetch_endpoints``."""
|
||||||
|
|
||||||
|
def __init__(self, responder):
|
||||||
|
self._responder = responder
|
||||||
|
self.requests: list[str] = []
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get(self, url: str, headers: dict | None = None) -> _StubResponse:
|
||||||
|
self.requests.append(url)
|
||||||
|
return self._responder(url)
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_async_client(monkeypatch, responder) -> _StubAsyncClient:
|
||||||
|
"""Replace ``httpx.AsyncClient`` for the duration of the test."""
|
||||||
|
client = _StubAsyncClient(responder)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.services.openrouter_integration_service.httpx.AsyncClient",
|
||||||
|
lambda *_args, **_kwargs: client,
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def _healthy_payload() -> dict:
|
||||||
|
return {
|
||||||
|
"data": {
|
||||||
|
"endpoints": [
|
||||||
|
{
|
||||||
|
"status": 0,
|
||||||
|
"uptime_last_30m": 0.99,
|
||||||
|
"uptime_last_1d": 0.995,
|
||||||
|
"uptime_last_5m": 0.99,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _unhealthy_payload() -> dict:
|
||||||
|
return {
|
||||||
|
"data": {
|
||||||
|
"endpoints": [
|
||||||
|
{
|
||||||
|
"status": 0,
|
||||||
|
"uptime_last_30m": 0.55,
|
||||||
|
"uptime_last_1d": 0.62,
|
||||||
|
"uptime_last_5m": 0.50,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bounded fan-out + happy path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_enrich_health_marks_healthy_and_gates_unhealthy(monkeypatch):
|
||||||
|
cfgs = [
|
||||||
|
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
|
||||||
|
_or_cfg(cid=-2, model_name="venice/dead-model", static_score=60),
|
||||||
|
]
|
||||||
|
|
||||||
|
def responder(url: str) -> _StubResponse:
|
||||||
|
if "anthropic" in url:
|
||||||
|
return _StubResponse(payload=_healthy_payload())
|
||||||
|
return _StubResponse(payload=_unhealthy_payload())
|
||||||
|
|
||||||
|
_patch_async_client(monkeypatch, responder)
|
||||||
|
|
||||||
|
service = OpenRouterIntegrationService()
|
||||||
|
service._settings = {"api_key": ""}
|
||||||
|
await service._enrich_health(cfgs)
|
||||||
|
|
||||||
|
healthy = next(c for c in cfgs if c["id"] == -1)
|
||||||
|
gated = next(c for c in cfgs if c["id"] == -2)
|
||||||
|
|
||||||
|
assert healthy["health_gated"] is False
|
||||||
|
assert healthy["quality_score_health"] is not None
|
||||||
|
assert healthy["quality_score"] >= healthy["quality_score_static"]
|
||||||
|
|
||||||
|
assert gated["health_gated"] is True
|
||||||
|
assert gated["quality_score"] == gated["quality_score_static"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_enrich_health_only_touches_or_provider(monkeypatch):
|
||||||
|
"""YAML cfgs that aren't OPENROUTER must be skipped entirely."""
|
||||||
|
yaml_cfg = {
|
||||||
|
"id": -1,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
"auto_pin_tier": "A",
|
||||||
|
"quality_score_static": 80,
|
||||||
|
"quality_score": 80,
|
||||||
|
"health_gated": False,
|
||||||
|
}
|
||||||
|
or_cfg = _or_cfg(cid=-2, model_name="anthropic/claude-haiku")
|
||||||
|
|
||||||
|
requests: list[str] = []
|
||||||
|
|
||||||
|
def responder(url: str) -> _StubResponse:
|
||||||
|
requests.append(url)
|
||||||
|
return _StubResponse(payload=_healthy_payload())
|
||||||
|
|
||||||
|
_patch_async_client(monkeypatch, responder)
|
||||||
|
|
||||||
|
service = OpenRouterIntegrationService()
|
||||||
|
service._settings = {}
|
||||||
|
await service._enrich_health([yaml_cfg, or_cfg])
|
||||||
|
|
||||||
|
assert all("anthropic/claude-haiku" in r for r in requests)
|
||||||
|
# YAML cfg is untouched.
|
||||||
|
assert yaml_cfg["quality_score"] == 80
|
||||||
|
assert yaml_cfg["health_gated"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Failure ratio fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_enrich_health_falls_back_to_last_good_when_failure_ratio_high(
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
"""If >= 25% of fetches fail, keep last-good cache instead of writing
|
||||||
|
partial data."""
|
||||||
|
cfgs = [
|
||||||
|
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
|
||||||
|
_or_cfg(cid=-2, model_name="openai/gpt-5", static_score=80),
|
||||||
|
_or_cfg(cid=-3, model_name="google/gemini-flash", static_score=65),
|
||||||
|
_or_cfg(cid=-4, model_name="venice/something", static_score=50),
|
||||||
|
]
|
||||||
|
|
||||||
|
service = OpenRouterIntegrationService()
|
||||||
|
service._settings = {}
|
||||||
|
# Pre-seed last-good cache with a known-healthy snapshot.
|
||||||
|
service._health_cache = {
|
||||||
|
"anthropic/claude-haiku": {"gated": False, "score": 95.0},
|
||||||
|
}
|
||||||
|
|
||||||
|
def all_fail(_url: str) -> _StubResponse:
|
||||||
|
return _StubResponse(payload={}, status_code=500)
|
||||||
|
|
||||||
|
_patch_async_client(monkeypatch, all_fail)
|
||||||
|
await service._enrich_health(cfgs)
|
||||||
|
|
||||||
|
# Above threshold ⇒ degraded; last-good cache wins for the cached cfg.
|
||||||
|
cached_hit = next(c for c in cfgs if c["model_name"] == "anthropic/claude-haiku")
|
||||||
|
assert cached_hit["quality_score_health"] == 95.0
|
||||||
|
assert cached_hit["health_gated"] is False
|
||||||
|
# Confirm the threshold constant we're testing against is real.
|
||||||
|
assert _HEALTH_FAIL_RATIO_FALLBACK <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_enrich_health_keeps_static_only_with_no_cache_and_failures(
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
"""If a fetch fails and there's no last-good cache, the cfg keeps its
|
||||||
|
static-only ``quality_score`` and is *not* gated by default."""
|
||||||
|
cfgs = [
|
||||||
|
_or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70),
|
||||||
|
]
|
||||||
|
|
||||||
|
def fail(_url: str) -> _StubResponse:
|
||||||
|
return _StubResponse(payload={}, status_code=500)
|
||||||
|
|
||||||
|
_patch_async_client(monkeypatch, fail)
|
||||||
|
|
||||||
|
service = OpenRouterIntegrationService()
|
||||||
|
service._settings = {}
|
||||||
|
await service._enrich_health(cfgs)
|
||||||
|
|
||||||
|
cfg = cfgs[0]
|
||||||
|
assert cfg["health_gated"] is False
|
||||||
|
assert cfg["quality_score"] == cfg["quality_score_static"]
|
||||||
|
assert cfg["quality_score_health"] is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Last-good cache: success populates, next failure reuses
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_enrich_health_populates_cache_on_success_then_reuses_on_failure(
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
cfg = _or_cfg(cid=-1, model_name="anthropic/claude-haiku", static_score=70)
|
||||||
|
|
||||||
|
service = OpenRouterIntegrationService()
|
||||||
|
service._settings = {}
|
||||||
|
|
||||||
|
def healthy(_url: str) -> _StubResponse:
|
||||||
|
return _StubResponse(payload=_healthy_payload())
|
||||||
|
|
||||||
|
_patch_async_client(monkeypatch, healthy)
|
||||||
|
await service._enrich_health([cfg])
|
||||||
|
|
||||||
|
assert "anthropic/claude-haiku" in service._health_cache
|
||||||
|
cached_score = service._health_cache["anthropic/claude-haiku"]["score"]
|
||||||
|
assert cached_score is not None
|
||||||
|
|
||||||
|
# Next cycle: enough other healthy cfgs so failure ratio stays below
|
||||||
|
# the 25% threshold even when this one fails individually.
|
||||||
|
other_cfgs = [
|
||||||
|
_or_cfg(cid=-2 - i, model_name=f"healthy/m-{i}", static_score=60)
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
cfg["quality_score_health"] = None
|
||||||
|
cfg["quality_score"] = cfg["quality_score_static"]
|
||||||
|
|
||||||
|
def mixed(url: str) -> _StubResponse:
|
||||||
|
if "anthropic" in url:
|
||||||
|
return _StubResponse(payload={}, status_code=500)
|
||||||
|
return _StubResponse(payload=_healthy_payload())
|
||||||
|
|
||||||
|
_patch_async_client(monkeypatch, mixed)
|
||||||
|
await service._enrich_health([cfg, *other_cfgs])
|
||||||
|
|
||||||
|
assert cfg["quality_score_health"] == cached_score
|
||||||
|
assert cfg["health_gated"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bounded fan-out: respects top-N caps
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_enrich_health_bounds_premium_fanout(monkeypatch):
|
||||||
|
"""Top-N premium cap is honoured even when many cfgs are present."""
|
||||||
|
from app.services.quality_score import _HEALTH_ENRICH_TOP_N_PREMIUM
|
||||||
|
|
||||||
|
cfgs = [
|
||||||
|
_or_cfg(
|
||||||
|
cid=-i, model_name=f"openai/m-{i}", tier="premium", static_score=100 - i
|
||||||
|
)
|
||||||
|
for i in range(1, _HEALTH_ENRICH_TOP_N_PREMIUM + 20)
|
||||||
|
]
|
||||||
|
|
||||||
|
seen: list[str] = []
|
||||||
|
|
||||||
|
def responder(url: str) -> _StubResponse:
|
||||||
|
seen.append(url)
|
||||||
|
return _StubResponse(payload=_healthy_payload())
|
||||||
|
|
||||||
|
_patch_async_client(monkeypatch, responder)
|
||||||
|
|
||||||
|
service = OpenRouterIntegrationService()
|
||||||
|
service._settings = {}
|
||||||
|
await service._enrich_health(cfgs)
|
||||||
|
|
||||||
|
assert len(seen) == _HEALTH_ENRICH_TOP_N_PREMIUM
|
||||||
|
|
||||||
|
|
||||||
|
async def test_enrich_health_no_or_cfgs_is_noop(monkeypatch):
|
||||||
|
"""When the catalogue has no OR cfgs at all, no HTTP calls fire."""
|
||||||
|
yaml_cfg: dict[str, Any] = {
|
||||||
|
"id": -1,
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5",
|
||||||
|
"billing_tier": "premium",
|
||||||
|
}
|
||||||
|
requests: list[str] = []
|
||||||
|
|
||||||
|
def responder(url: str) -> _StubResponse:
|
||||||
|
requests.append(url)
|
||||||
|
return _StubResponse(payload=_healthy_payload())
|
||||||
|
|
||||||
|
_patch_async_client(monkeypatch, responder)
|
||||||
|
|
||||||
|
service = OpenRouterIntegrationService()
|
||||||
|
service._settings = {}
|
||||||
|
await service._enrich_health([yaml_cfg])
|
||||||
|
assert requests == []
|
||||||
345
surfsense_backend/tests/unit/services/test_quality_score.py
Normal file
345
surfsense_backend/tests/unit/services/test_quality_score.py
Normal file
|
|
@ -0,0 +1,345 @@
|
||||||
|
"""Unit tests for the Auto (Fastest) quality scoring module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.quality_score import (
|
||||||
|
_HEALTH_GATE_UPTIME_PCT,
|
||||||
|
_OPERATOR_TRUST_BONUS,
|
||||||
|
aggregate_health,
|
||||||
|
capabilities_signal,
|
||||||
|
context_signal,
|
||||||
|
created_recency_signal,
|
||||||
|
pricing_band,
|
||||||
|
slug_penalty,
|
||||||
|
static_score_or,
|
||||||
|
static_score_yaml,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# created_recency_signal
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_created_recency_signal_recent_model_scores_high():
|
||||||
|
now = 1_750_000_000 # ~mid-2025
|
||||||
|
one_month_ago = now - (30 * 86_400)
|
||||||
|
assert created_recency_signal(one_month_ago, now) == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_created_recency_signal_old_model_scores_zero():
|
||||||
|
now = 1_750_000_000
|
||||||
|
five_years_ago = now - (5 * 365 * 86_400)
|
||||||
|
assert created_recency_signal(five_years_ago, now) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_created_recency_signal_missing_timestamp_is_neutral():
|
||||||
|
now = 1_750_000_000
|
||||||
|
assert created_recency_signal(None, now) == 0
|
||||||
|
assert created_recency_signal(0, now) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_created_recency_signal_monotonic_decay():
|
||||||
|
now = 1_750_000_000
|
||||||
|
scores = [
|
||||||
|
created_recency_signal(now - days * 86_400, now)
|
||||||
|
for days in (30, 120, 300, 500, 700, 1000, 1500)
|
||||||
|
]
|
||||||
|
assert scores == sorted(scores, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# pricing_band
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_pricing_band_free_returns_zero():
|
||||||
|
assert pricing_band("0", "0") == 0
|
||||||
|
assert pricing_band(0.0, 0.0) == 0
|
||||||
|
assert pricing_band(None, None) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_pricing_band_handles_unparseable():
|
||||||
|
assert pricing_band("not-a-number", "0") == 0
|
||||||
|
assert pricing_band({}, []) == 0 # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pricing_band_premium_tiers_increase_with_price():
|
||||||
|
cheap = pricing_band("0.0000003", "0.0000005")
|
||||||
|
mid = pricing_band("0.000003", "0.000015")
|
||||||
|
flagship = pricing_band("0.00001", "0.00005")
|
||||||
|
assert 0 < cheap < mid < flagship
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# context_signal
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"ctx,expected",
|
||||||
|
[
|
||||||
|
(1_500_000, 10),
|
||||||
|
(1_000_000, 10),
|
||||||
|
(500_000, 8),
|
||||||
|
(200_000, 6),
|
||||||
|
(128_000, 4),
|
||||||
|
(100_000, 2),
|
||||||
|
(50_000, 0),
|
||||||
|
(0, 0),
|
||||||
|
(None, 0),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_context_signal_bands(ctx, expected):
|
||||||
|
assert context_signal(ctx) == expected
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# capabilities_signal
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_capabilities_signal_caps_at_five():
|
||||||
|
assert (
|
||||||
|
capabilities_signal(
|
||||||
|
["tools", "structured_outputs", "reasoning", "include_reasoning"]
|
||||||
|
)
|
||||||
|
<= 5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_capabilities_signal_tools_only():
|
||||||
|
assert capabilities_signal(["tools"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_capabilities_signal_empty():
|
||||||
|
assert capabilities_signal(None) == 0
|
||||||
|
assert capabilities_signal([]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# slug_penalty
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_slug_penalty_demotes_tiny_models():
|
||||||
|
assert slug_penalty("meta-llama/llama-3.2-1b-instruct") < 0
|
||||||
|
assert slug_penalty("liquid/lfm-7b") < 0
|
||||||
|
assert slug_penalty("google/gemma-3n-e4b-it") < 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_slug_penalty_skips_capable_mini_nano_lite_models():
|
||||||
|
"""Critical Option C+ regression: don't penalise modern frontier
|
||||||
|
models named ``-nano`` / ``-mini`` / ``-lite`` (gpt-5-mini, etc.)."""
|
||||||
|
assert slug_penalty("openai/gpt-5-mini") == 0
|
||||||
|
assert slug_penalty("openai/gpt-5-nano") == 0
|
||||||
|
assert slug_penalty("google/gemini-2.5-flash-lite") == 0
|
||||||
|
assert slug_penalty("anthropic/claude-haiku-4.5") == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_slug_penalty_demotes_legacy_variants():
|
||||||
|
assert slug_penalty("openai/o1-preview") < 0
|
||||||
|
assert slug_penalty("foo/bar-base") < 0
|
||||||
|
assert slug_penalty("foo/bar-distill") < 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_slug_penalty_empty_input():
|
||||||
|
assert slug_penalty("") == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# static_score_or
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _or_model(
|
||||||
|
*,
|
||||||
|
model_id: str,
|
||||||
|
created: int | None = None,
|
||||||
|
prompt: str = "0.000003",
|
||||||
|
completion: str = "0.000015",
|
||||||
|
context: int = 200_000,
|
||||||
|
params: list[str] | None = None,
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"id": model_id,
|
||||||
|
"created": created,
|
||||||
|
"pricing": {"prompt": prompt, "completion": completion},
|
||||||
|
"context_length": context,
|
||||||
|
"supported_parameters": params if params is not None else ["tools"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_score_or_frontier_premium_beats_free_tiny():
|
||||||
|
now = 1_750_000_000
|
||||||
|
frontier = _or_model(
|
||||||
|
model_id="openai/gpt-5",
|
||||||
|
created=now - (60 * 86_400),
|
||||||
|
prompt="0.000005",
|
||||||
|
completion="0.000020",
|
||||||
|
context=400_000,
|
||||||
|
params=["tools", "structured_outputs", "reasoning"],
|
||||||
|
)
|
||||||
|
tiny_free = _or_model(
|
||||||
|
model_id="meta-llama/llama-3.2-1b-instruct:free",
|
||||||
|
created=now - (5 * 365 * 86_400),
|
||||||
|
prompt="0",
|
||||||
|
completion="0",
|
||||||
|
context=128_000,
|
||||||
|
params=["tools"],
|
||||||
|
)
|
||||||
|
assert static_score_or(frontier, now_ts=now) > static_score_or(
|
||||||
|
tiny_free, now_ts=now
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_score_or_score_is_clamped_0_to_100():
|
||||||
|
now = int(time.time())
|
||||||
|
score = static_score_or(_or_model(model_id="openai/gpt-4o"), now_ts=now)
|
||||||
|
assert 0 <= score <= 100
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_score_or_unknown_provider_is_neutral_not_zero():
|
||||||
|
now = int(time.time())
|
||||||
|
score = static_score_or(
|
||||||
|
_or_model(model_id="some-new-lab/some-model"),
|
||||||
|
now_ts=now,
|
||||||
|
)
|
||||||
|
assert score > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_score_or_recent_release_beats_year_old_same_provider():
|
||||||
|
now = 1_750_000_000
|
||||||
|
fresh = _or_model(model_id="openai/gpt-5", created=now - (60 * 86_400))
|
||||||
|
old = _or_model(model_id="openai/gpt-4-turbo", created=now - (700 * 86_400))
|
||||||
|
assert static_score_or(fresh, now_ts=now) > static_score_or(old, now_ts=now)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# static_score_yaml
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_score_yaml_includes_operator_bonus():
|
||||||
|
cfg = {
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5",
|
||||||
|
"litellm_params": {"base_model": "azure/gpt-5"},
|
||||||
|
}
|
||||||
|
score = static_score_yaml(cfg)
|
||||||
|
assert score >= _OPERATOR_TRUST_BONUS
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_score_yaml_unknown_provider_still_carries_bonus():
|
||||||
|
cfg = {
|
||||||
|
"provider": "SOME_NEW_PROVIDER",
|
||||||
|
"model_name": "weird-model",
|
||||||
|
}
|
||||||
|
score = static_score_yaml(cfg)
|
||||||
|
assert score >= _OPERATOR_TRUST_BONUS
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_score_yaml_clamped_0_to_100():
|
||||||
|
cfg = {
|
||||||
|
"provider": "AZURE_OPENAI",
|
||||||
|
"model_name": "gpt-5",
|
||||||
|
"litellm_params": {"base_model": "azure/gpt-5"},
|
||||||
|
}
|
||||||
|
assert 0 <= static_score_yaml(cfg) <= 100
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# aggregate_health
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_health_gates_when_uptime_below_threshold():
|
||||||
|
"""Live data showed Venice-routed cfgs at 53-68%; this guards that the
|
||||||
|
90% gate excludes them."""
|
||||||
|
venice_endpoints = [
|
||||||
|
{
|
||||||
|
"status": 0,
|
||||||
|
"uptime_last_30m": 0.55,
|
||||||
|
"uptime_last_1d": 0.60,
|
||||||
|
"uptime_last_5m": 0.50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"status": 0,
|
||||||
|
"uptime_last_30m": 0.65,
|
||||||
|
"uptime_last_1d": 0.68,
|
||||||
|
"uptime_last_5m": 0.62,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
gated, score = aggregate_health(venice_endpoints)
|
||||||
|
assert gated is True
|
||||||
|
assert score is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_health_passes_for_healthy_provider():
|
||||||
|
healthy = [
|
||||||
|
{
|
||||||
|
"status": 0,
|
||||||
|
"uptime_last_30m": 0.99,
|
||||||
|
"uptime_last_1d": 0.995,
|
||||||
|
"uptime_last_5m": 0.99,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
gated, score = aggregate_health(healthy)
|
||||||
|
assert gated is False
|
||||||
|
assert score is not None
|
||||||
|
assert score >= _HEALTH_GATE_UPTIME_PCT
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_health_picks_best_endpoint_across_multiple():
|
||||||
|
"""Multi-endpoint aggregation should reward the best non-null uptime."""
|
||||||
|
mixed = [
|
||||||
|
{"status": 0, "uptime_last_30m": 0.55},
|
||||||
|
{"status": 0, "uptime_last_30m": 0.97}, # this one passes the gate
|
||||||
|
]
|
||||||
|
gated, score = aggregate_health(mixed)
|
||||||
|
assert gated is False
|
||||||
|
assert score is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_health_empty_endpoints_gated():
|
||||||
|
gated, score = aggregate_health([])
|
||||||
|
assert gated is True
|
||||||
|
assert score is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_health_no_status_zero_gated():
|
||||||
|
"""Even with high uptime, no OK status means the cfg is broken upstream."""
|
||||||
|
endpoints = [
|
||||||
|
{"status": 1, "uptime_last_30m": 0.99},
|
||||||
|
{"status": 2, "uptime_last_30m": 0.98},
|
||||||
|
]
|
||||||
|
gated, score = aggregate_health(endpoints)
|
||||||
|
assert gated is True
|
||||||
|
assert score is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_health_all_uptime_null_gated():
|
||||||
|
endpoints = [
|
||||||
|
{"status": 0, "uptime_last_30m": None, "uptime_last_1d": None},
|
||||||
|
]
|
||||||
|
gated, score = aggregate_health(endpoints)
|
||||||
|
assert gated is True
|
||||||
|
assert score is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_health_pct_normalisation():
|
||||||
|
"""OpenRouter returns 0-1 fractions; some endpoints surface 0-100%
|
||||||
|
percentages. Both should reach the same gate decision."""
|
||||||
|
fraction_form = [{"status": 0, "uptime_last_30m": 0.95}]
|
||||||
|
pct_form = [{"status": 0, "uptime_last_30m": 95.0}]
|
||||||
|
g1, s1 = aggregate_health(fraction_form)
|
||||||
|
g2, s2 = aggregate_health(pct_form)
|
||||||
|
assert g1 == g2 == False # noqa: E712
|
||||||
|
assert s1 is not None and s2 is not None
|
||||||
|
assert abs(s1 - s2) < 0.5
|
||||||
|
|
@ -0,0 +1,370 @@
|
||||||
|
"""Unit tests for the filesystem-tool branches of ``revert_service``.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
|
||||||
|
* Exact-name dispatch — ``rmdir`` does NOT mis-route to the document
|
||||||
|
branch (``"rmdir".startswith("rm")`` would mis-route under the legacy
|
||||||
|
prefix-based dispatch).
|
||||||
|
* ``rm`` revert re-INSERTs a fresh document from the snapshot, including
|
||||||
|
re-creating chunks. Falls back to ``(folder_id_before, title_before)``
|
||||||
|
when ``metadata_before["virtual_path"]`` is missing.
|
||||||
|
* ``write_file`` create-revert (``content_before IS NULL``) DELETEs the
|
||||||
|
document.
|
||||||
|
* ``rmdir`` revert re-INSERTs a fresh folder from the snapshot.
|
||||||
|
* ``mkdir`` revert DELETEs the empty folder; reports ``tool_unavailable``
|
||||||
|
when the folder gained children.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services import revert_service
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _stub_embeddings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
revert_service,
|
||||||
|
"embed_texts",
|
||||||
|
lambda texts: [np.zeros(8, dtype=np.float32) for _ in texts],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResult:
|
||||||
|
def __init__(self, rows: list[Any] | None = None, scalar: Any = None) -> None:
|
||||||
|
self._rows = rows or []
|
||||||
|
self._scalar = scalar
|
||||||
|
|
||||||
|
def all(self) -> list[Any]:
|
||||||
|
return list(self._rows)
|
||||||
|
|
||||||
|
def scalar_one_or_none(self) -> Any:
|
||||||
|
return self._scalar
|
||||||
|
|
||||||
|
def scalars(self) -> Any:
|
||||||
|
return _FakeScalarsProxy(self._rows)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeScalarsProxy:
|
||||||
|
def __init__(self, rows: list[Any]) -> None:
|
||||||
|
self._rows = rows
|
||||||
|
|
||||||
|
def first(self) -> Any:
|
||||||
|
return self._rows[0] if self._rows else None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.execute = AsyncMock()
|
||||||
|
self.added: list[Any] = []
|
||||||
|
self.deleted: list[Any] = []
|
||||||
|
self.flush = AsyncMock()
|
||||||
|
# session.get(Model, pk) lookup
|
||||||
|
self.get = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
async def _flush_assigning_ids() -> None:
|
||||||
|
for obj in self.added:
|
||||||
|
if getattr(obj, "id", None) is None:
|
||||||
|
obj.id = 999
|
||||||
|
|
||||||
|
self.flush.side_effect = _flush_assigning_ids
|
||||||
|
|
||||||
|
def add(self, obj: Any) -> None:
|
||||||
|
self.added.append(obj)
|
||||||
|
|
||||||
|
def add_all(self, objs: list[Any]) -> None:
|
||||||
|
self.added.extend(objs)
|
||||||
|
|
||||||
|
|
||||||
|
def _action(*, tool_name: str, action_id: int = 7):
|
||||||
|
return MagicMock(
|
||||||
|
id=action_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
thread_id=1,
|
||||||
|
search_space_id=2,
|
||||||
|
user_id="user-1",
|
||||||
|
reverse_descriptor=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _doc_revision(
|
||||||
|
*,
|
||||||
|
document_id: int | None = None,
|
||||||
|
content_before: str | None = "old content",
|
||||||
|
title_before: str | None = "notes.md",
|
||||||
|
folder_id_before: int | None = 5,
|
||||||
|
chunks_before: list[dict[str, str]] | None = None,
|
||||||
|
metadata_before: dict[str, str] | None = None,
|
||||||
|
):
|
||||||
|
revision = MagicMock()
|
||||||
|
revision.id = 100
|
||||||
|
revision.document_id = document_id
|
||||||
|
revision.search_space_id = 2
|
||||||
|
revision.content_before = content_before
|
||||||
|
revision.title_before = title_before
|
||||||
|
revision.folder_id_before = folder_id_before
|
||||||
|
revision.chunks_before = chunks_before or []
|
||||||
|
revision.metadata_before = metadata_before
|
||||||
|
return revision
|
||||||
|
|
||||||
|
|
||||||
|
def _folder_revision(
|
||||||
|
*,
|
||||||
|
folder_id: int | None = None,
|
||||||
|
name_before: str | None = "team",
|
||||||
|
parent_id_before: int | None = None,
|
||||||
|
position_before: str | None = "a0",
|
||||||
|
):
|
||||||
|
revision = MagicMock()
|
||||||
|
revision.id = 200
|
||||||
|
revision.folder_id = folder_id
|
||||||
|
revision.search_space_id = 2
|
||||||
|
revision.name_before = name_before
|
||||||
|
revision.parent_id_before = parent_id_before
|
||||||
|
revision.position_before = position_before
|
||||||
|
return revision
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Exact-name dispatch regression guards
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExactDispatch:
|
||||||
|
"""Regression: ``rmdir`` MUST NOT route to the document branch."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rmdir_does_not_misroute_to_document(self) -> None:
|
||||||
|
# If dispatch used `startswith("rm")` we'd hit the document branch
|
||||||
|
# here. With exact-name lookup `rmdir` lands in `_FOLDER_TOOLS`.
|
||||||
|
session = _FakeSession()
|
||||||
|
action = _action(tool_name="rmdir")
|
||||||
|
# No folder revisions exist for this action.
|
||||||
|
session.execute.return_value = _FakeResult(rows=[])
|
||||||
|
outcome = await revert_service.revert_action(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
action=action,
|
||||||
|
requester_user_id="user-1",
|
||||||
|
)
|
||||||
|
assert outcome.status == "not_reversible"
|
||||||
|
assert "folder_revisions" in outcome.message
|
||||||
|
|
||||||
|
def test_dispatch_sets_split_doc_and_folder(self) -> None:
|
||||||
|
# Static guards on the dispatch tables themselves so a future
|
||||||
|
# refactor doesn't accidentally reintroduce the prefix bug.
|
||||||
|
assert "rm" in revert_service._DOC_TOOLS
|
||||||
|
assert "rmdir" in revert_service._FOLDER_TOOLS
|
||||||
|
assert "rmdir" not in revert_service._DOC_TOOLS
|
||||||
|
assert "rm" not in revert_service._FOLDER_TOOLS
|
||||||
|
# ``move_file`` lives only in document tools (it's a doc rename).
|
||||||
|
assert "move_file" in revert_service._DOC_TOOLS
|
||||||
|
assert "move_file" not in revert_service._FOLDER_TOOLS
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rm revert (re-INSERT)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmRevert:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_re_inserts_document_with_chunks(self) -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
revision = _doc_revision(
|
||||||
|
document_id=None, # row was hard-deleted
|
||||||
|
content_before="hello world",
|
||||||
|
title_before="x.md",
|
||||||
|
folder_id_before=None,
|
||||||
|
chunks_before=[{"content": "alpha"}, {"content": "beta"}],
|
||||||
|
metadata_before={"virtual_path": "/documents/x.md"},
|
||||||
|
)
|
||||||
|
# No collision check hit and the resulting query returns nothing.
|
||||||
|
session.execute.return_value = _FakeResult(scalar=None)
|
||||||
|
|
||||||
|
outcome = await revert_service._reinsert_document_from_revision(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert outcome.status == "ok"
|
||||||
|
# New Document + 2 chunks must have been added.
|
||||||
|
from app.db import Chunk, Document
|
||||||
|
|
||||||
|
added_docs = [obj for obj in session.added if isinstance(obj, Document)]
|
||||||
|
added_chunks = [obj for obj in session.added if isinstance(obj, Chunk)]
|
||||||
|
assert len(added_docs) == 1
|
||||||
|
assert added_docs[0].title == "x.md"
|
||||||
|
assert len(added_chunks) == 2
|
||||||
|
# Snapshot was repointed at the new doc id so a follow-up revert works.
|
||||||
|
assert revision.document_id == added_docs[0].id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_falls_back_to_folder_id_and_title_for_virtual_path(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
# Snapshot with NO metadata_before — the fallback path must kick in.
|
||||||
|
revision = _doc_revision(
|
||||||
|
document_id=None,
|
||||||
|
content_before="hello",
|
||||||
|
title_before="cap.md",
|
||||||
|
folder_id_before=42,
|
||||||
|
chunks_before=[],
|
||||||
|
metadata_before=None,
|
||||||
|
)
|
||||||
|
# session.get(Folder, 42) returns a folder with a name.
|
||||||
|
folder = MagicMock()
|
||||||
|
folder.name = "team"
|
||||||
|
folder.parent_id = None
|
||||||
|
# First .get is for the folder lookup in the path-derivation.
|
||||||
|
session.get = AsyncMock(return_value=folder)
|
||||||
|
session.execute.return_value = _FakeResult(scalar=None)
|
||||||
|
|
||||||
|
outcome = await revert_service._reinsert_document_from_revision(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
assert outcome.status == "ok"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_falls_back_to_root_path_when_no_folder(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""metadata_before is None and folder_id_before is None still
|
||||||
|
resolves: title fallback yields ``/documents/<title>`` so revert
|
||||||
|
proceeds at the root of the documents tree."""
|
||||||
|
session = _FakeSession()
|
||||||
|
revision = _doc_revision(
|
||||||
|
document_id=None,
|
||||||
|
content_before="hello",
|
||||||
|
title_before="x.md",
|
||||||
|
folder_id_before=None,
|
||||||
|
metadata_before=None,
|
||||||
|
)
|
||||||
|
# No collision in the documents tree at /documents/x.md.
|
||||||
|
session.execute.return_value = _FakeResult(scalar=None)
|
||||||
|
outcome = await revert_service._reinsert_document_from_revision(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
assert outcome.status == "ok"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_collision_with_live_doc_returns_tool_unavailable(self) -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
revision = _doc_revision(
|
||||||
|
document_id=None,
|
||||||
|
content_before="hi",
|
||||||
|
title_before="x.md",
|
||||||
|
folder_id_before=None,
|
||||||
|
metadata_before={"virtual_path": "/documents/x.md"},
|
||||||
|
)
|
||||||
|
# SELECT for unique_identifier_hash collision hits an existing row.
|
||||||
|
session.execute.return_value = _FakeResult(scalar=42)
|
||||||
|
outcome = await revert_service._reinsert_document_from_revision(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
assert outcome.status == "tool_unavailable"
|
||||||
|
assert "collide" in outcome.message
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# write_file create revert (DELETE)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestWriteFileCreateRevert:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deletes_created_doc(self) -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
revision = _doc_revision(
|
||||||
|
document_id=99,
|
||||||
|
content_before=None, # marker for "created in this action"
|
||||||
|
title_before=None,
|
||||||
|
)
|
||||||
|
outcome = await revert_service._delete_created_document(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
assert outcome.status == "ok"
|
||||||
|
# Exactly one DELETE was issued.
|
||||||
|
assert session.execute.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rmdir revert (re-INSERT folder)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmdirRevert:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_re_inserts_folder_from_snapshot(self) -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
revision = _folder_revision(
|
||||||
|
folder_id=None,
|
||||||
|
name_before="team",
|
||||||
|
parent_id_before=None,
|
||||||
|
position_before="a0",
|
||||||
|
)
|
||||||
|
outcome = await revert_service._reinsert_folder_from_revision(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
from app.db import Folder
|
||||||
|
|
||||||
|
assert outcome.status == "ok"
|
||||||
|
added_folders = [obj for obj in session.added if isinstance(obj, Folder)]
|
||||||
|
assert len(added_folders) == 1
|
||||||
|
assert added_folders[0].name == "team"
|
||||||
|
assert revision.folder_id == added_folders[0].id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# mkdir revert (DELETE folder)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMkdirRevert:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deletes_empty_folder(self) -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
revision = _folder_revision(folder_id=42)
|
||||||
|
# Both the doc-existence check and the child-folder check return None.
|
||||||
|
session.execute.side_effect = [
|
||||||
|
_FakeResult(scalar=None), # docs
|
||||||
|
_FakeResult(scalar=None), # children
|
||||||
|
_FakeResult(scalar=None), # delete (no return value)
|
||||||
|
]
|
||||||
|
outcome = await revert_service._delete_created_folder(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
assert outcome.status == "ok"
|
||||||
|
# 3 executes: docs check, children check, delete.
|
||||||
|
assert session.execute.await_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reports_tool_unavailable_when_folder_has_children(self) -> None:
|
||||||
|
session = _FakeSession()
|
||||||
|
revision = _folder_revision(folder_id=42)
|
||||||
|
# First check (docs) returns "row found".
|
||||||
|
session.execute.return_value = _FakeResult(scalar=1)
|
||||||
|
outcome = await revert_service._delete_created_folder(
|
||||||
|
session, # type: ignore[arg-type]
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
assert outcome.status == "tool_unavailable"
|
||||||
|
assert "no longer empty" in outcome.message
|
||||||
0
surfsense_backend/tests/unit/tasks/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/chat/__init__.py
Normal file
0
surfsense_backend/tests/unit/tasks/chat/__init__.py
Normal file
|
|
@ -0,0 +1,228 @@
|
||||||
|
"""Unit tests for ``stream_new_chat._extract_chunk_parts``.
|
||||||
|
|
||||||
|
Earlier versions only handled ``isinstance(chunk.content, str)`` and
|
||||||
|
silently dropped every other shape (Anthropic typed-block lists,
|
||||||
|
Bedrock reasoning blocks, ``additional_kwargs.reasoning_content`` from
|
||||||
|
a few providers). These regression tests pin those four shapes plus the
|
||||||
|
defensive cases (``None`` chunk, mixed types, missing fields).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.tasks.chat.stream_new_chat import _extract_chunk_parts
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeChunk:
|
||||||
|
"""Minimal stand-in for ``AIMessageChunk`` used in unit tests."""
|
||||||
|
|
||||||
|
content: Any = ""
|
||||||
|
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStringContent:
|
||||||
|
def test_plain_string_content_extracts_as_text(self) -> None:
|
||||||
|
chunk = _FakeChunk(content="hello world")
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == "hello world"
|
||||||
|
assert out["reasoning"] == ""
|
||||||
|
assert out["tool_call_chunks"] == []
|
||||||
|
|
||||||
|
def test_empty_string_content_yields_empty_text(self) -> None:
|
||||||
|
chunk = _FakeChunk(content="")
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == ""
|
||||||
|
assert out["reasoning"] == ""
|
||||||
|
assert out["tool_call_chunks"] == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestListContent:
|
||||||
|
def test_list_of_text_blocks_concatenates(self) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "Hello "},
|
||||||
|
{"type": "text", "text": "world"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == "Hello world"
|
||||||
|
assert out["reasoning"] == ""
|
||||||
|
|
||||||
|
def test_mixed_text_and_reasoning_blocks(self) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content=[
|
||||||
|
{"type": "reasoning", "reasoning": "Let me think... "},
|
||||||
|
{"type": "reasoning", "text": "still thinking."},
|
||||||
|
{"type": "text", "text": "The answer is 42."},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == "The answer is 42."
|
||||||
|
assert out["reasoning"] == "Let me think... still thinking."
|
||||||
|
|
||||||
|
def test_tool_call_chunks_in_content_list_extracted(self) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "Calling tool..."},
|
||||||
|
{
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
"id": "call_123",
|
||||||
|
"name": "make_widget",
|
||||||
|
"args": '{"color":"red"}',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == "Calling tool..."
|
||||||
|
assert out["reasoning"] == ""
|
||||||
|
assert len(out["tool_call_chunks"]) == 1
|
||||||
|
assert out["tool_call_chunks"][0]["id"] == "call_123"
|
||||||
|
assert out["tool_call_chunks"][0]["name"] == "make_widget"
|
||||||
|
|
||||||
|
def test_tool_use_blocks_also_extracted(self) -> None:
|
||||||
|
"""Some providers (Anthropic) emit ``type='tool_use'`` instead."""
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "call_xyz",
|
||||||
|
"name": "search",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["tool_call_chunks"] == [
|
||||||
|
{"type": "tool_use", "id": "call_xyz", "name": "search"}
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_unknown_block_types_are_ignored(self) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content=[
|
||||||
|
{"type": "image_url", "url": "https://example.com/x.png"},
|
||||||
|
{"type": "text", "text": "ok"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == "ok"
|
||||||
|
|
||||||
|
def test_blocks_without_text_field_are_ignored(self) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content=[
|
||||||
|
{"type": "text"}, # no text/content key
|
||||||
|
{"type": "text", "text": "kept"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == "kept"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdditionalKwargsReasoning:
|
||||||
|
def test_reasoning_content_in_additional_kwargs(self) -> None:
|
||||||
|
"""Some providers stash reasoning in ``additional_kwargs.reasoning_content``."""
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content="visible answer",
|
||||||
|
additional_kwargs={"reasoning_content": "internal monologue"},
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == "visible answer"
|
||||||
|
assert out["reasoning"] == "internal monologue"
|
||||||
|
|
||||||
|
def test_reasoning_appended_to_typed_block_reasoning(self) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content=[{"type": "reasoning", "text": "from blocks. "}],
|
||||||
|
additional_kwargs={"reasoning_content": "from kwargs."},
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["reasoning"] == "from blocks. from kwargs."
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCallChunksAttribute:
|
||||||
|
def test_tool_call_chunks_attribute_extracted_alongside_string_content(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content="streaming text",
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"name": "save_document", "args": '{"title":"x"}', "id": "tc-9"}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert out["text"] == "streaming text"
|
||||||
|
assert len(out["tool_call_chunks"]) == 1
|
||||||
|
assert out["tool_call_chunks"][0]["id"] == "tc-9"
|
||||||
|
|
||||||
|
def test_attribute_and_typed_block_chunks_both_collected(self) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "tool_call_chunk",
|
||||||
|
"id": "from-block",
|
||||||
|
"name": "x",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_call_chunks=[{"id": "from-attr", "name": "y"}],
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
ids = [tcc.get("id") for tcc in out["tool_call_chunks"]]
|
||||||
|
assert ids == ["from-block", "from-attr"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefensive:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"chunk_value",
|
||||||
|
[None, _FakeChunk(content=None), _FakeChunk(content=42)],
|
||||||
|
)
|
||||||
|
def test_invalid_chunk_returns_empty_parts(self, chunk_value: Any) -> None:
|
||||||
|
out = _extract_chunk_parts(chunk_value)
|
||||||
|
assert out["text"] == ""
|
||||||
|
assert out["reasoning"] == ""
|
||||||
|
assert out["tool_call_chunks"] == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestIdlessContinuationChunks:
|
||||||
|
"""Per LangChain ``ToolCallChunk`` semantics, the FIRST chunk for a
|
||||||
|
tool call carries id+name; later chunks for the same call have
|
||||||
|
``id=None, name=None`` and only ``args`` + ``index``. Live tool-call
|
||||||
|
argument streaming relies on those idless continuation chunks
|
||||||
|
flowing through ``_extract_chunk_parts`` UNTOUCHED so the upstream
|
||||||
|
chunk-emission loop can still route them by ``index``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_idless_continuation_chunk_preserved_verbatim(self) -> None:
|
||||||
|
chunk = _FakeChunk(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out = _extract_chunk_parts(chunk)
|
||||||
|
assert len(out["tool_call_chunks"]) == 1
|
||||||
|
tcc = out["tool_call_chunks"][0]
|
||||||
|
assert tcc.get("id") is None
|
||||||
|
assert tcc.get("name") is None
|
||||||
|
assert tcc.get("args") == '_path":"/x"}'
|
||||||
|
assert tcc.get("index") == 0
|
||||||
|
|
||||||
|
def test_first_then_idless_sequence_preserves_index(self) -> None:
|
||||||
|
"""Both chunks for the same call share an ``index`` key — the
|
||||||
|
index-routing loop in ``stream_new_chat`` depends on it."""
|
||||||
|
first = _FakeChunk(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
cont = _FakeChunk(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
out_first = _extract_chunk_parts(first)
|
||||||
|
out_cont = _extract_chunk_parts(cont)
|
||||||
|
assert out_first["tool_call_chunks"][0]["index"] == 0
|
||||||
|
assert out_cont["tool_call_chunks"][0]["index"] == 0
|
||||||
|
assert out_cont["tool_call_chunks"][0].get("id") is None
|
||||||
|
|
@ -0,0 +1,527 @@
|
||||||
|
"""Unit tests for live tool-call argument streaming.
|
||||||
|
|
||||||
|
Pins the wire format that ``_stream_agent_events`` emits when
|
||||||
|
``SURFSENSE_ENABLE_STREAM_PARITY_V2=true``: ``tool-input-start`` →
|
||||||
|
``tool-input-delta``... → ``tool-input-available`` → ``tool-output-available``
|
||||||
|
all keyed by the same LangChain ``tool_call.id``.
|
||||||
|
|
||||||
|
Identity is tracked in ``index_to_meta`` (per-chunk ``index``) and
|
||||||
|
``ui_tool_call_id_by_run`` (LangGraph ``run_id``); both are private to
|
||||||
|
``_stream_agent_events`` so we exercise them via the public wire output.
|
||||||
|
|
||||||
|
These tests also lock in the legacy / parity_v2-OFF behaviour so the
|
||||||
|
synthetic ``call_<run_id>`` shape stays stable for older clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import app.tasks.chat.stream_new_chat as stream_module
|
||||||
|
from app.agents.new_chat.feature_flags import AgentFeatureFlags
|
||||||
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
|
from app.tasks.chat.stream_new_chat import (
|
||||||
|
StreamResult,
|
||||||
|
_legacy_match_lc_id,
|
||||||
|
_stream_agent_events,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeChunk:
|
||||||
|
"""Minimal stand-in for ``AIMessageChunk``."""
|
||||||
|
|
||||||
|
content: Any = ""
|
||||||
|
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
tool_call_chunks: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _FakeToolMessage:
|
||||||
|
"""Stand-in for ``ToolMessage`` returned by ``on_tool_end``."""
|
||||||
|
|
||||||
|
content: Any
|
||||||
|
tool_call_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAgentState:
|
||||||
|
"""Stand-in for ``StateSnapshot`` returned by ``aget_state``."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Empty values keeps the cloud-fallback safety-net branch a no-op,
|
||||||
|
# and an empty ``tasks`` list keeps the post-stream interrupt
|
||||||
|
# check a no-op too.
|
||||||
|
self.values: dict[str, Any] = {}
|
||||||
|
self.tasks: list[Any] = []
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeAgent:
|
||||||
|
"""Replays a list of ``astream_events`` events."""
|
||||||
|
|
||||||
|
def __init__(self, events: list[dict[str, Any]]) -> None:
|
||||||
|
self._events = events
|
||||||
|
|
||||||
|
async def astream_events( # type: ignore[no-untyped-def]
|
||||||
|
self, _input_data: Any, *, config: dict[str, Any], version: str
|
||||||
|
) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
|
del config, version # unused, contract-compatible
|
||||||
|
for ev in self._events:
|
||||||
|
yield ev
|
||||||
|
|
||||||
|
async def aget_state(self, _config: dict[str, Any]) -> _FakeAgentState:
|
||||||
|
# Called once after astream_events drains so the cloud-fallback
|
||||||
|
# safety net can inspect staged filesystem work. The fake stays
|
||||||
|
# empty so the safety net is a no-op.
|
||||||
|
return _FakeAgentState()
|
||||||
|
|
||||||
|
|
||||||
|
def _model_stream(
|
||||||
|
*,
|
||||||
|
text: str = "",
|
||||||
|
reasoning: str = "",
|
||||||
|
tool_call_chunks: list[dict[str, Any]] | None = None,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return (
|
||||||
|
{
|
||||||
|
"event": "on_chat_model_stream",
|
||||||
|
"tags": tags or [],
|
||||||
|
"data": {
|
||||||
|
"chunk": _FakeChunk(
|
||||||
|
content=text,
|
||||||
|
tool_call_chunks=list(tool_call_chunks or []),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
# reasoning piggybacks via additional_kwargs path; if needed,
|
||||||
|
# override content to a typed-block list. Most tests just check
|
||||||
|
# tool_call_chunks routing so this is fine.
|
||||||
|
}
|
||||||
|
if not reasoning
|
||||||
|
else {
|
||||||
|
"event": "on_chat_model_stream",
|
||||||
|
"tags": tags or [],
|
||||||
|
"data": {
|
||||||
|
"chunk": _FakeChunk(
|
||||||
|
content=text,
|
||||||
|
additional_kwargs={"reasoning_content": reasoning},
|
||||||
|
tool_call_chunks=list(tool_call_chunks or []),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_start(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
run_id: str,
|
||||||
|
input_payload: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"event": "on_tool_start",
|
||||||
|
"name": name,
|
||||||
|
"run_id": run_id,
|
||||||
|
"data": {"input": input_payload or {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_end(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
run_id: str,
|
||||||
|
tool_call_id: str | None = None,
|
||||||
|
output: Any = "ok",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"event": "on_tool_end",
|
||||||
|
"name": name,
|
||||||
|
"run_id": run_id,
|
||||||
|
"data": {
|
||||||
|
"output": _FakeToolMessage(
|
||||||
|
content=json.dumps(output) if not isinstance(output, str) else output,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def parity_v2_on(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
stream_module,
|
||||||
|
"get_flags",
|
||||||
|
lambda: AgentFeatureFlags(enable_stream_parity_v2=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def parity_v2_off(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
stream_module,
|
||||||
|
"get_flags",
|
||||||
|
lambda: AgentFeatureFlags(enable_stream_parity_v2=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _drain(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Run ``_stream_agent_events`` against a fake agent and return the
|
||||||
|
SSE payloads (parsed JSON) it yielded.
|
||||||
|
"""
|
||||||
|
agent = _FakeAgent(events)
|
||||||
|
service = VercelStreamingService()
|
||||||
|
result = StreamResult()
|
||||||
|
config = {"configurable": {"thread_id": "test-thread"}}
|
||||||
|
sse_lines: list[str] = []
|
||||||
|
async for sse in _stream_agent_events(
|
||||||
|
agent, config, {}, service, result, step_prefix="thinking"
|
||||||
|
):
|
||||||
|
sse_lines.append(sse)
|
||||||
|
|
||||||
|
parsed: list[dict[str, Any]] = []
|
||||||
|
for line in sse_lines:
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
body = line[len("data: ") :].rstrip("\n")
|
||||||
|
if not body or body == "[DONE]":
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
parsed.append(json.loads(body))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def _types(payloads: list[dict[str, Any]]) -> list[str]:
|
||||||
|
return [p.get("type", "?") for p in payloads]
|
||||||
|
|
||||||
|
|
||||||
|
def _of_type(payloads: list[dict[str, Any]], type_name: str) -> list[dict[str, Any]]:
|
||||||
|
return [p for p in payloads if p.get("type") == type_name]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: ``_legacy_match_lc_id`` is a pure refactor; assert behaviour.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLegacyMatch:
|
||||||
|
def test_pops_first_id_bearing_chunk_with_matching_name(self) -> None:
|
||||||
|
chunks: list[dict[str, Any]] = [
|
||||||
|
{"id": "x1", "name": "ls"},
|
||||||
|
{"id": "y1", "name": "write_file"},
|
||||||
|
]
|
||||||
|
runs: dict[str, str] = {}
|
||||||
|
result = _legacy_match_lc_id(chunks, "write_file", "run-1", runs)
|
||||||
|
assert result == "y1"
|
||||||
|
assert chunks == [{"id": "x1", "name": "ls"}]
|
||||||
|
assert runs == {"run-1": "y1"}
|
||||||
|
|
||||||
|
def test_falls_back_to_any_id_bearing_when_name_mismatches(self) -> None:
|
||||||
|
chunks: list[dict[str, Any]] = [{"id": "anon", "name": None}]
|
||||||
|
runs: dict[str, str] = {}
|
||||||
|
out = _legacy_match_lc_id(chunks, "ls", "run-2", runs)
|
||||||
|
assert out == "anon"
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
def test_returns_none_when_no_id_bearing_chunk(self) -> None:
|
||||||
|
chunks: list[dict[str, Any]] = [{"id": None, "name": None}]
|
||||||
|
runs: dict[str, str] = {}
|
||||||
|
assert _legacy_match_lc_id(chunks, "ls", "run-3", runs) is None
|
||||||
|
assert chunks == [{"id": None, "name": None}]
|
||||||
|
assert runs == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# parity_v2 wire format tests.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_idless_chunk_merging_by_index(parity_v2_on: None) -> None:
|
||||||
|
"""First chunk carries id+name; later idless chunks at the same
|
||||||
|
``index`` merge into the SAME ``tool-input-start`` ui id and emit
|
||||||
|
one ``tool-input-delta`` per chunk."""
|
||||||
|
events = [
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": "lc-1", "name": "write_file", "args": '{"file', "index": 0}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": None, "name": None, "args": '_path":"/x"}', "index": 0}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
_tool_start(
|
||||||
|
name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
|
||||||
|
),
|
||||||
|
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
|
||||||
|
]
|
||||||
|
|
||||||
|
payloads = await _drain(events)
|
||||||
|
|
||||||
|
starts = _of_type(payloads, "tool-input-start")
|
||||||
|
deltas = _of_type(payloads, "tool-input-delta")
|
||||||
|
available = _of_type(payloads, "tool-input-available")
|
||||||
|
output = _of_type(payloads, "tool-output-available")
|
||||||
|
|
||||||
|
assert len(starts) == 1
|
||||||
|
assert starts[0]["toolCallId"] == "lc-1"
|
||||||
|
assert starts[0]["toolName"] == "write_file"
|
||||||
|
assert starts[0]["langchainToolCallId"] == "lc-1"
|
||||||
|
|
||||||
|
assert [d["inputTextDelta"] for d in deltas] == ['{"file', '_path":"/x"}']
|
||||||
|
assert all(d["toolCallId"] == "lc-1" for d in deltas)
|
||||||
|
|
||||||
|
assert len(available) == 1
|
||||||
|
assert available[0]["toolCallId"] == "lc-1"
|
||||||
|
|
||||||
|
assert len(output) == 1
|
||||||
|
assert output[0]["toolCallId"] == "lc-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_two_interleaved_tool_calls_route_by_index(
|
||||||
|
parity_v2_on: None,
|
||||||
|
) -> None:
|
||||||
|
"""Two same-name calls with distinct indices keep their deltas
|
||||||
|
routed to the right card."""
|
||||||
|
events = [
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": "lc-A", "name": "write_file", "args": '{"a":1', "index": 0},
|
||||||
|
{"id": "lc-B", "name": "write_file", "args": '{"b":2', "index": 1},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": None, "name": None, "args": "}", "index": 0},
|
||||||
|
{"id": None, "name": None, "args": "}", "index": 1},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_tool_start(name="write_file", run_id="run-A", input_payload={"a": 1}),
|
||||||
|
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-A"),
|
||||||
|
_tool_start(name="write_file", run_id="run-B", input_payload={"b": 2}),
|
||||||
|
_tool_end(name="write_file", run_id="run-B", tool_call_id="lc-B"),
|
||||||
|
]
|
||||||
|
|
||||||
|
payloads = await _drain(events)
|
||||||
|
|
||||||
|
starts = _of_type(payloads, "tool-input-start")
|
||||||
|
deltas = _of_type(payloads, "tool-input-delta")
|
||||||
|
output = _of_type(payloads, "tool-output-available")
|
||||||
|
|
||||||
|
assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
|
||||||
|
|
||||||
|
by_id: dict[str, list[str]] = {"lc-A": [], "lc-B": []}
|
||||||
|
for d in deltas:
|
||||||
|
by_id[d["toolCallId"]].append(d["inputTextDelta"])
|
||||||
|
assert by_id["lc-A"] == ['{"a":1', "}"]
|
||||||
|
assert by_id["lc-B"] == ['{"b":2', "}"]
|
||||||
|
|
||||||
|
assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_identity_stable_across_lifecycle(parity_v2_on: None) -> None:
|
||||||
|
"""Whatever id ``tool-input-start`` chose must be the SAME id used
|
||||||
|
on ``tool-input-available`` AND ``tool-output-available``."""
|
||||||
|
events = [
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": "lc-9", "name": "ls", "args": '{"path":"/"}', "index": 0}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_tool_start(name="ls", run_id="run-X", input_payload={"path": "/"}),
|
||||||
|
_tool_end(name="ls", run_id="run-X", tool_call_id="lc-9"),
|
||||||
|
]
|
||||||
|
payloads = await _drain(events)
|
||||||
|
relevant = [
|
||||||
|
p
|
||||||
|
for p in payloads
|
||||||
|
if p.get("type")
|
||||||
|
in {"tool-input-start", "tool-input-available", "tool-output-available"}
|
||||||
|
]
|
||||||
|
assert {p["toolCallId"] for p in relevant} == {"lc-9"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_duplicate_tool_input_start(parity_v2_on: None) -> None:
|
||||||
|
"""When the chunk-emission loop already fired ``tool-input-start``
|
||||||
|
for this run, ``on_tool_start`` MUST NOT emit a second one."""
|
||||||
|
events = [
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_tool_start(name="write_file", run_id="run-A", input_payload={}),
|
||||||
|
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
|
||||||
|
]
|
||||||
|
payloads = await _drain(events)
|
||||||
|
starts = _of_type(payloads, "tool-input-start")
|
||||||
|
assert len(starts) == 1
|
||||||
|
assert starts[0]["toolCallId"] == "lc-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_active_text_closes_before_early_tool_input_start(
|
||||||
|
parity_v2_on: None,
|
||||||
|
) -> None:
|
||||||
|
"""Streaming a text-delta then a tool-call chunk in subsequent
|
||||||
|
chunks: the wire MUST contain ``text-end`` before the FIRST
|
||||||
|
``tool-input-start`` (clean part boundary on the frontend)."""
|
||||||
|
events = [
|
||||||
|
_model_stream(text="Working on it"),
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": "lc-1", "name": "write_file", "args": "{}", "index": 0}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_tool_start(name="write_file", run_id="run-A", input_payload={}),
|
||||||
|
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
|
||||||
|
]
|
||||||
|
types = _types(await _drain(events))
|
||||||
|
text_end_idx = types.index("text-end")
|
||||||
|
start_idx = types.index("tool-input-start")
|
||||||
|
assert text_end_idx < start_idx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mixed_text_and_tool_chunk_preserve_order(
|
||||||
|
parity_v2_on: None,
|
||||||
|
) -> None:
|
||||||
|
"""One AIMessageChunk that carries BOTH ``text`` content AND
|
||||||
|
``tool_call_chunks`` should emit the text delta FIRST, then close
|
||||||
|
text, then ``tool-input-start``+``tool-input-delta``."""
|
||||||
|
events = [
|
||||||
|
_model_stream(
|
||||||
|
text="I'll update it",
|
||||||
|
tool_call_chunks=[
|
||||||
|
{
|
||||||
|
"id": "lc-1",
|
||||||
|
"name": "write_file",
|
||||||
|
"args": '{"file_path":"/x"}',
|
||||||
|
"index": 0,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
),
|
||||||
|
_tool_start(
|
||||||
|
name="write_file", run_id="run-A", input_payload={"file_path": "/x"}
|
||||||
|
),
|
||||||
|
_tool_end(name="write_file", run_id="run-A", tool_call_id="lc-1"),
|
||||||
|
]
|
||||||
|
types = _types(await _drain(events))
|
||||||
|
# text-start … text-delta … text-end … tool-input-start … tool-input-delta
|
||||||
|
assert types.index("text-start") < types.index("text-delta")
|
||||||
|
assert types.index("text-delta") < types.index("text-end")
|
||||||
|
assert types.index("text-end") < types.index("tool-input-start")
|
||||||
|
assert types.index("tool-input-start") < types.index("tool-input-delta")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parity_v2_off_preserves_legacy_shape(
|
||||||
|
parity_v2_off: None,
|
||||||
|
) -> None:
|
||||||
|
"""When the flag is OFF, no deltas are emitted and the ``toolCallId``
|
||||||
|
is ``call_<run_id>`` (NOT the lc id)."""
|
||||||
|
events = [
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": "lc-1", "name": "ls", "args": '{"path":"/"}', "index": 0}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_tool_start(name="ls", run_id="run-A", input_payload={"path": "/"}),
|
||||||
|
_tool_end(name="ls", run_id="run-A", tool_call_id="lc-1"),
|
||||||
|
]
|
||||||
|
payloads = await _drain(events)
|
||||||
|
|
||||||
|
assert _of_type(payloads, "tool-input-delta") == []
|
||||||
|
starts = _of_type(payloads, "tool-input-start")
|
||||||
|
assert len(starts) == 1
|
||||||
|
assert starts[0]["toolCallId"].startswith("call_run-A")
|
||||||
|
# No ``langchainToolCallId`` propagation on ``tool-input-start`` in
|
||||||
|
# legacy mode (the start event fires before the ToolMessage is
|
||||||
|
# available, so we can't extract the authoritative LangChain id yet).
|
||||||
|
assert "langchainToolCallId" not in starts[0]
|
||||||
|
output = _of_type(payloads, "tool-output-available")
|
||||||
|
assert output[0]["toolCallId"].startswith("call_run-A")
|
||||||
|
# ``tool-output-available`` MUST carry ``langchainToolCallId`` even
|
||||||
|
# in legacy mode: the chat tool card uses it to backfill the
|
||||||
|
# LangChain id and join against the ``data-action-log`` SSE event
|
||||||
|
# (keyed by ``lc_tool_call_id``) so the inline Revert button can
|
||||||
|
# light up. Sourced from the returned ``ToolMessage.tool_call_id``,
|
||||||
|
# which is populated regardless of feature-flag state.
|
||||||
|
assert output[0]["langchainToolCallId"] == "lc-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skip_append_prevents_stale_id_reuse(
|
||||||
|
parity_v2_on: None,
|
||||||
|
) -> None:
|
||||||
|
"""Two same-name tools: the SECOND tool's ``langchainToolCallId``
|
||||||
|
must NOT come from the first tool's chunk (``pending_tool_call_chunks``
|
||||||
|
must stay empty for indexed-registered chunks)."""
|
||||||
|
events = [
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[
|
||||||
|
{"id": "lc-A", "name": "write_file", "args": "{}", "index": 0},
|
||||||
|
{"id": "lc-B", "name": "write_file", "args": "{}", "index": 1},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
_tool_start(name="write_file", run_id="run-1", input_payload={}),
|
||||||
|
_tool_end(name="write_file", run_id="run-1", tool_call_id="lc-A"),
|
||||||
|
_tool_start(name="write_file", run_id="run-2", input_payload={}),
|
||||||
|
_tool_end(name="write_file", run_id="run-2", tool_call_id="lc-B"),
|
||||||
|
]
|
||||||
|
payloads = await _drain(events)
|
||||||
|
|
||||||
|
starts = _of_type(payloads, "tool-input-start")
|
||||||
|
# Two distinct lc ids, each its own card.
|
||||||
|
assert {s["toolCallId"] for s in starts} == {"lc-A", "lc-B"}
|
||||||
|
# Each tool-output-available landed on its respective card.
|
||||||
|
output = _of_type(payloads, "tool-output-available")
|
||||||
|
assert {o["toolCallId"] for o in output} == {"lc-A", "lc-B"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_registration_waits_for_both_id_and_name(
|
||||||
|
parity_v2_on: None,
|
||||||
|
) -> None:
|
||||||
|
"""An id-only chunk (no name yet) must NOT emit ``tool-input-start``."""
|
||||||
|
events = [
|
||||||
|
_model_stream(
|
||||||
|
tool_call_chunks=[{"id": "lc-1", "name": None, "args": "", "index": 0}]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
payloads = await _drain(events)
|
||||||
|
assert _of_type(payloads, "tool-input-start") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unmatched_fallback_still_attaches_lc_id(
|
||||||
|
parity_v2_on: None,
|
||||||
|
) -> None:
|
||||||
|
"""parity_v2 ON, but the provider didn't include an ``index``: the
|
||||||
|
legacy fallback path must still emit ``tool-input-start`` with the
|
||||||
|
matching ``langchainToolCallId``."""
|
||||||
|
events = [
|
||||||
|
# No index on the chunk → not registered into index_to_meta;
|
||||||
|
# falls through to ``pending_tool_call_chunks`` so the legacy
|
||||||
|
# match path can pop it at on_tool_start.
|
||||||
|
_model_stream(tool_call_chunks=[{"id": "lc-orphan", "name": "ls", "args": ""}]),
|
||||||
|
_tool_start(name="ls", run_id="run-1", input_payload={"path": "/"}),
|
||||||
|
_tool_end(name="ls", run_id="run-1", tool_call_id="lc-orphan"),
|
||||||
|
]
|
||||||
|
payloads = await _drain(events)
|
||||||
|
starts = _of_type(payloads, "tool-input-start")
|
||||||
|
assert len(starts) == 1
|
||||||
|
assert starts[0]["toolCallId"].startswith("call_run-1")
|
||||||
|
assert starts[0]["langchainToolCallId"] == "lc-orphan"
|
||||||
|
|
@ -1,9 +1,21 @@
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import app.tasks.chat.stream_new_chat as stream_new_chat_module
|
||||||
|
from app.agents.new_chat.errors import BusyError
|
||||||
|
from app.agents.new_chat.middleware.busy_mutex import request_cancel, reset_cancel
|
||||||
from app.tasks.chat.stream_new_chat import (
|
from app.tasks.chat.stream_new_chat import (
|
||||||
StreamResult,
|
StreamResult,
|
||||||
|
_classify_stream_exception,
|
||||||
_contract_enforcement_active,
|
_contract_enforcement_active,
|
||||||
_evaluate_file_contract_outcome,
|
_evaluate_file_contract_outcome,
|
||||||
|
_extract_resolved_file_path,
|
||||||
|
_log_chat_stream_error,
|
||||||
_tool_output_has_error,
|
_tool_output_has_error,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -17,6 +29,39 @@ def test_tool_output_error_detection():
|
||||||
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
|
assert not _tool_output_has_error({"result": "Updated file /notes.md"})
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_resolved_file_path_prefers_structured_path():
|
||||||
|
assert (
|
||||||
|
_extract_resolved_file_path(
|
||||||
|
tool_name="write_file",
|
||||||
|
tool_output={"status": "completed", "path": "/docs/note.md"},
|
||||||
|
tool_input=None,
|
||||||
|
)
|
||||||
|
== "/docs/note.md"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_resolved_file_path_falls_back_to_tool_input():
|
||||||
|
assert (
|
||||||
|
_extract_resolved_file_path(
|
||||||
|
tool_name="edit_file",
|
||||||
|
tool_output={"status": "completed", "result": "updated"},
|
||||||
|
tool_input={"file_path": "/docs/edited.md"},
|
||||||
|
)
|
||||||
|
== "/docs/edited.md"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_resolved_file_path_does_not_parse_result_text():
|
||||||
|
assert (
|
||||||
|
_extract_resolved_file_path(
|
||||||
|
tool_name="write_file",
|
||||||
|
tool_output={"result": "Updated file /docs/from-text.md"},
|
||||||
|
tool_input=None,
|
||||||
|
)
|
||||||
|
is None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_file_write_contract_outcome_reasons():
|
def test_file_write_contract_outcome_reasons():
|
||||||
result = StreamResult(intent_detected="file_write")
|
result = StreamResult(intent_detected="file_write")
|
||||||
passed, reason = _evaluate_file_contract_outcome(result)
|
passed, reason = _evaluate_file_contract_outcome(result)
|
||||||
|
|
@ -45,3 +90,433 @@ def test_contract_enforcement_local_only():
|
||||||
|
|
||||||
result.filesystem_mode = "cloud"
|
result.filesystem_mode = "cloud"
|
||||||
assert not _contract_enforcement_active(result)
|
assert not _contract_enforcement_active(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_chat_stream_payload(record_message: str) -> dict:
|
||||||
|
prefix = "[chat_stream_error] "
|
||||||
|
assert record_message.startswith(prefix)
|
||||||
|
return json.loads(record_message[len(prefix) :])
|
||||||
|
|
||||||
|
|
||||||
|
def test_unified_chat_stream_error_log_schema(caplog):
|
||||||
|
with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
|
||||||
|
_log_chat_stream_error(
|
||||||
|
flow="new",
|
||||||
|
error_kind="server_error",
|
||||||
|
error_code="SERVER_ERROR",
|
||||||
|
severity="warn",
|
||||||
|
is_expected=False,
|
||||||
|
request_id="req-123",
|
||||||
|
thread_id=101,
|
||||||
|
search_space_id=202,
|
||||||
|
user_id="user-1",
|
||||||
|
message="Error during chat: boom",
|
||||||
|
)
|
||||||
|
|
||||||
|
record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
|
||||||
|
payload = _extract_chat_stream_payload(record.message)
|
||||||
|
|
||||||
|
required_keys = {
|
||||||
|
"event",
|
||||||
|
"flow",
|
||||||
|
"error_kind",
|
||||||
|
"error_code",
|
||||||
|
"severity",
|
||||||
|
"is_expected",
|
||||||
|
"request_id",
|
||||||
|
"thread_id",
|
||||||
|
"search_space_id",
|
||||||
|
"user_id",
|
||||||
|
"message",
|
||||||
|
}
|
||||||
|
assert required_keys.issubset(payload.keys())
|
||||||
|
assert payload["event"] == "chat_stream_error"
|
||||||
|
assert payload["flow"] == "new"
|
||||||
|
assert payload["error_code"] == "SERVER_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
def test_premium_quota_uses_unified_chat_stream_log_shape(caplog):
|
||||||
|
with caplog.at_level(logging.INFO, logger="app.tasks.chat.stream_new_chat"):
|
||||||
|
_log_chat_stream_error(
|
||||||
|
flow="resume",
|
||||||
|
error_kind="premium_quota_exhausted",
|
||||||
|
error_code="PREMIUM_QUOTA_EXHAUSTED",
|
||||||
|
severity="info",
|
||||||
|
is_expected=True,
|
||||||
|
request_id="req-premium",
|
||||||
|
thread_id=303,
|
||||||
|
search_space_id=404,
|
||||||
|
user_id="user-2",
|
||||||
|
message="Buy more tokens to continue with this model, or switch to a free model",
|
||||||
|
extra={"auto_fallback": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
record = next(r for r in caplog.records if "[chat_stream_error]" in r.message)
|
||||||
|
payload = _extract_chat_stream_payload(record.message)
|
||||||
|
assert payload["event"] == "chat_stream_error"
|
||||||
|
assert payload["error_kind"] == "premium_quota_exhausted"
|
||||||
|
assert payload["error_code"] == "PREMIUM_QUOTA_EXHAUSTED"
|
||||||
|
assert payload["flow"] == "resume"
|
||||||
|
assert payload["is_expected"] is True
|
||||||
|
assert payload["auto_fallback"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_error_emission_keeps_machine_error_codes():
|
||||||
|
source = inspect.getsource(stream_new_chat_module)
|
||||||
|
format_error_calls = re.findall(r"format_error\(", source)
|
||||||
|
emitted_error_codes = set(re.findall(r'error_code="([A-Z_]+)"', source))
|
||||||
|
|
||||||
|
# All stream paths should route through one shared terminal error emitter.
|
||||||
|
assert len(format_error_calls) == 1
|
||||||
|
assert {
|
||||||
|
"PREMIUM_QUOTA_EXHAUSTED",
|
||||||
|
"SERVER_ERROR",
|
||||||
|
}.issubset(emitted_error_codes)
|
||||||
|
assert 'flow: Literal["new", "regenerate"] = "new"' in source
|
||||||
|
assert "_emit_stream_terminal_error" in source
|
||||||
|
assert "flow=flow" in source
|
||||||
|
assert 'flow="resume"' in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_exception_classifies_rate_limited():
|
||||||
|
exc = Exception(
|
||||||
|
'{"error":{"type":"rate_limit_error","message":"Rate limited. Please try again later."}}'
|
||||||
|
)
|
||||||
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
exc, flow_label="chat"
|
||||||
|
)
|
||||||
|
assert kind == "rate_limited"
|
||||||
|
assert code == "RATE_LIMITED"
|
||||||
|
assert severity == "warn"
|
||||||
|
assert is_expected is True
|
||||||
|
assert "temporarily rate-limited" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_exception_classifies_openrouter_429_payload():
|
||||||
|
exc = Exception(
|
||||||
|
'OpenrouterException - {"error":{"message":"Provider returned error","code":429,'
|
||||||
|
'"metadata":{"raw":"foo is temporarily rate-limited upstream"}}}'
|
||||||
|
)
|
||||||
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
exc, flow_label="chat"
|
||||||
|
)
|
||||||
|
assert kind == "rate_limited"
|
||||||
|
assert code == "RATE_LIMITED"
|
||||||
|
assert severity == "warn"
|
||||||
|
assert is_expected is True
|
||||||
|
assert "temporarily rate-limited" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preflight_swallows_non_rate_limit_errors_and_re_raises_429(monkeypatch):
|
||||||
|
"""``_preflight_llm`` is best-effort.
|
||||||
|
|
||||||
|
- On rate-limit shaped exceptions (provider 429) it MUST re-raise so the
|
||||||
|
caller can drive the cooldown/repin branch.
|
||||||
|
- On any other transient failure it MUST swallow the error so the normal
|
||||||
|
stream path continues without surfacing preflight noise to the user.
|
||||||
|
"""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from app.tasks.chat.stream_new_chat import _preflight_llm
|
||||||
|
|
||||||
|
class _RateLimitedError(Exception):
|
||||||
|
"""Class-name carries 'RateLimit' so _is_provider_rate_limited triggers."""
|
||||||
|
|
||||||
|
rate_calls: list[dict] = []
|
||||||
|
other_calls: list[dict] = []
|
||||||
|
|
||||||
|
async def _fake_acompletion_429(**kwargs):
|
||||||
|
rate_calls.append(kwargs)
|
||||||
|
raise _RateLimitedError("simulated 429")
|
||||||
|
|
||||||
|
async def _fake_acompletion_other(**kwargs):
|
||||||
|
other_calls.append(kwargs)
|
||||||
|
raise RuntimeError("some unrelated transient failure")
|
||||||
|
|
||||||
|
fake_llm = SimpleNamespace(
|
||||||
|
model="openrouter/google/gemma-4-31b-it:free",
|
||||||
|
api_key="test",
|
||||||
|
api_base=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
import litellm # type: ignore[import-not-found]
|
||||||
|
|
||||||
|
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_429)
|
||||||
|
with pytest.raises(_RateLimitedError):
|
||||||
|
await _preflight_llm(fake_llm)
|
||||||
|
assert len(rate_calls) == 1
|
||||||
|
assert rate_calls[0]["max_tokens"] == 1
|
||||||
|
assert rate_calls[0]["stream"] is False
|
||||||
|
|
||||||
|
monkeypatch.setattr(litellm, "acompletion", _fake_acompletion_other)
|
||||||
|
# MUST NOT raise: non-rate-limit failures are swallowed.
|
||||||
|
await _preflight_llm(fake_llm)
|
||||||
|
assert len(other_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preflight_skipped_for_auto_router_model():
|
||||||
|
"""Router-mode ``model='auto'`` has no single deployment to ping; the
|
||||||
|
LiteLLM router itself owns per-deployment rate-limit accounting, so the
|
||||||
|
preflight helper must short-circuit instead of issuing a probe."""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from app.tasks.chat.stream_new_chat import _preflight_llm
|
||||||
|
|
||||||
|
fake_llm = SimpleNamespace(model="auto", api_key="x", api_base=None)
|
||||||
|
# Should return without raising or making any LiteLLM call.
|
||||||
|
await _preflight_llm(fake_llm)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_exception_classifies_thread_busy():
|
||||||
|
exc = BusyError(request_id="thread-123")
|
||||||
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
exc, flow_label="chat"
|
||||||
|
)
|
||||||
|
assert kind == "thread_busy"
|
||||||
|
assert code == "THREAD_BUSY"
|
||||||
|
assert severity == "warn"
|
||||||
|
assert is_expected is True
|
||||||
|
assert "still finishing for this thread" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_exception_classifies_thread_busy_from_message():
|
||||||
|
exc = Exception("Thread is busy with another request")
|
||||||
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
exc, flow_label="chat"
|
||||||
|
)
|
||||||
|
assert kind == "thread_busy"
|
||||||
|
assert code == "THREAD_BUSY"
|
||||||
|
assert severity == "warn"
|
||||||
|
assert is_expected is True
|
||||||
|
assert "still finishing for this thread" in user_message
|
||||||
|
assert extra is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_exception_classifies_turn_cancelling_when_cancel_requested():
|
||||||
|
thread_id = "thread-cancelling-1"
|
||||||
|
reset_cancel(thread_id)
|
||||||
|
request_cancel(thread_id)
|
||||||
|
exc = BusyError(request_id=thread_id)
|
||||||
|
kind, code, severity, is_expected, user_message, extra = _classify_stream_exception(
|
||||||
|
exc, flow_label="chat"
|
||||||
|
)
|
||||||
|
assert kind == "thread_busy"
|
||||||
|
assert code == "TURN_CANCELLING"
|
||||||
|
assert severity == "info"
|
||||||
|
assert is_expected is True
|
||||||
|
assert "stopping" in user_message
|
||||||
|
assert isinstance(extra, dict)
|
||||||
|
assert "retry_after_ms" in extra
|
||||||
|
|
||||||
|
|
||||||
|
def test_premium_classification_is_error_code_driven():
|
||||||
|
classifier_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_web/lib/chat/chat-error-classifier.ts"
|
||||||
|
)
|
||||||
|
source = classifier_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "PREMIUM_KEYWORDS" not in source
|
||||||
|
assert "RATE_LIMIT_KEYWORDS" not in source
|
||||||
|
assert "normalized.includes(" not in source
|
||||||
|
assert 'if (errorCode === "PREMIUM_QUOTA_EXHAUSTED") {' in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_terminal_error_handler_has_pre_accept_soft_rollback_hook():
|
||||||
|
page_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
||||||
|
)
|
||||||
|
source = page_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "onPreAcceptFailure?: () => Promise<void>;" in source
|
||||||
|
assert "if (!accepted) {" in source
|
||||||
|
assert "await onPreAcceptFailure?.();" in source
|
||||||
|
assert "await onAcceptedStreamError?.();" in source
|
||||||
|
assert "setMessages((prev) => prev.filter((m) => m.id !== userMsgId));" in source
|
||||||
|
assert "setMessageDocumentsMap((prev) => {" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_toast_only_pre_accept_policy_has_no_inline_failed_marker():
|
||||||
|
user_message_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_web/components/assistant-ui/user-message.tsx"
|
||||||
|
)
|
||||||
|
source = user_message_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "Not sent. Edit and retry." not in source
|
||||||
|
assert "failed_pre_accept" not in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_network_send_failures_use_unified_retry_toast_message():
|
||||||
|
classifier_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_web/lib/chat/chat-error-classifier.ts"
|
||||||
|
)
|
||||||
|
classifier_source = classifier_path.read_text(encoding="utf-8")
|
||||||
|
request_errors_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_web/lib/chat/chat-request-errors.ts"
|
||||||
|
)
|
||||||
|
request_errors_source = request_errors_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '"send_failed_pre_accept"' in classifier_source
|
||||||
|
assert 'errorCode === "SEND_FAILED_PRE_ACCEPT"' in classifier_source
|
||||||
|
assert 'errorCode === "TURN_CANCELLING"' in classifier_source
|
||||||
|
assert "if (withCode.code) return withCode.code;" in classifier_source
|
||||||
|
assert 'userMessage: "Message not sent. Please retry."' in classifier_source
|
||||||
|
assert 'userMessage: "Connection issue. Please try again."' in classifier_source
|
||||||
|
assert "const passthroughCodes = new Set([" in request_errors_source
|
||||||
|
assert '"PREMIUM_QUOTA_EXHAUSTED"' in request_errors_source
|
||||||
|
assert '"THREAD_BUSY"' in request_errors_source
|
||||||
|
assert '"TURN_CANCELLING"' in request_errors_source
|
||||||
|
assert '"AUTH_EXPIRED"' in request_errors_source
|
||||||
|
assert '"UNAUTHORIZED"' in request_errors_source
|
||||||
|
assert '"RATE_LIMITED"' in request_errors_source
|
||||||
|
assert '"NETWORK_ERROR"' in request_errors_source
|
||||||
|
assert '"STREAM_PARSE_ERROR"' in request_errors_source
|
||||||
|
assert '"TOOL_EXECUTION_ERROR"' in request_errors_source
|
||||||
|
assert '"PERSIST_MESSAGE_FAILED"' in request_errors_source
|
||||||
|
assert '"SERVER_ERROR"' in request_errors_source
|
||||||
|
assert "passthroughCodes.has(existingCode)" in request_errors_source
|
||||||
|
assert 'errorCode: "SEND_FAILED_PRE_ACCEPT"' in request_errors_source
|
||||||
|
assert 'errorCode: "NETWORK_ERROR"' not in request_errors_source
|
||||||
|
assert "Failed to start chat. Please try again." not in classifier_source
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_post_accept_abort_contract_exists_for_new_resume_regenerate_flows():
|
||||||
|
page_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx"
|
||||||
|
)
|
||||||
|
source = page_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# Each flow tracks accepted boundary and passes it into shared terminal handling.
|
||||||
|
assert "let newAccepted = false;" in source
|
||||||
|
assert "let resumeAccepted = false;" in source
|
||||||
|
assert "let regenerateAccepted = false;" in source
|
||||||
|
assert "accepted: newAccepted," in source
|
||||||
|
assert "accepted: resumeAccepted," in source
|
||||||
|
assert "accepted: regenerateAccepted," in source
|
||||||
|
|
||||||
|
# Pre-accept abort in resume/regenerate exits without persistence.
|
||||||
|
assert "if (!resumeAccepted) return;" in source
|
||||||
|
assert "if (!regenerateAccepted) return;" in source
|
||||||
|
|
||||||
|
# New flow persists only when accepted and not already persisted.
|
||||||
|
assert "if (newAccepted && !userPersisted) {" in source
|
||||||
|
assert "const fetchWithTurnCancellingRetry = useCallback(" in source
|
||||||
|
assert "computeFallbackTurnCancellingRetryDelay" in source
|
||||||
|
assert 'withMeta.errorCode === "TURN_CANCELLING"' in source
|
||||||
|
assert 'withMeta.errorCode === "THREAD_BUSY"' in source
|
||||||
|
assert "await fetchWithTurnCancellingRetry(() =>" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_active_turn_route_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '@router.post(\n "/threads/{thread_id}/cancel-active-turn",' in source
|
||||||
|
assert "response_model=CancelActiveTurnResponse" in source
|
||||||
|
assert 'status="cancelling",' in source
|
||||||
|
assert 'error_code="TURN_CANCELLING",' in source
|
||||||
|
assert "retry_after_ms=retry_after_ms if retry_after_ms > 0 else None," in source
|
||||||
|
assert "retry_after_at=" in source
|
||||||
|
assert 'status="idle",' in source
|
||||||
|
assert 'error_code="NO_ACTIVE_TURN",' in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_status_route_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '@router.get(\n "/threads/{thread_id}/turn-status",' in source
|
||||||
|
assert "response_model=TurnStatusResponse" in source
|
||||||
|
assert "_build_turn_status_payload(thread_id)" in source
|
||||||
|
assert "Permission.CHATS_READ.value" in source
|
||||||
|
assert "_raise_if_thread_busy_for_start(" in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_cancelling_retry_policy_contract_exists():
|
||||||
|
routes_path = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/routes/new_chat_routes.py"
|
||||||
|
)
|
||||||
|
source = routes_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "TURN_CANCELLING_INITIAL_DELAY_MS = 200" in source
|
||||||
|
assert "TURN_CANCELLING_BACKOFF_FACTOR = 2" in source
|
||||||
|
assert "TURN_CANCELLING_MAX_DELAY_MS = 1500" in source
|
||||||
|
assert "def _compute_turn_cancelling_retry_delay(" in source
|
||||||
|
assert "retry-after-ms" in source
|
||||||
|
assert '"Retry-After"' in source
|
||||||
|
assert '"errorCode": "TURN_CANCELLING"' in source
|
||||||
|
|
||||||
|
|
||||||
|
def test_turn_status_sse_contract_exists():
|
||||||
|
stream_source = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_backend/app/tasks/chat/stream_new_chat.py"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
state_source = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_web/lib/chat/streaming-state.ts"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
pipeline_source = (
|
||||||
|
Path(__file__).resolve().parents[3]
|
||||||
|
/ "surfsense_web/lib/chat/stream-pipeline.ts"
|
||||||
|
).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert '"turn-status"' in stream_source
|
||||||
|
assert '"status": "busy"' in stream_source
|
||||||
|
assert '"status": "idle"' in stream_source
|
||||||
|
assert 'type: "data-turn-status"' in state_source
|
||||||
|
assert 'case "data-turn-status":' in pipeline_source
|
||||||
|
assert "end_turn(str(chat_id))" in stream_source
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_deepagent_forwards_resolved_model_name_to_both_builders():
|
||||||
|
"""Regression guard: both system-prompt builders in chat_deepagent.py
|
||||||
|
must receive ``model_name=_resolve_prompt_model_name(...)`` so the
|
||||||
|
provider-variant dispatch can render the right ``<provider_hints>``
|
||||||
|
block. Without this the prompt silently falls back to the empty
|
||||||
|
``"default"`` variant — the original bug being fixed.
|
||||||
|
|
||||||
|
This test mirrors :func:`test_stream_error_emission_keeps_machine_error_codes`
|
||||||
|
in style: it inspects module source text + a regex to enforce the
|
||||||
|
call-site shape, not just the wrapper layer (the wrappers already
|
||||||
|
forward ``model_name`` correctly, so testing them would not catch
|
||||||
|
the actual missed plumbing).
|
||||||
|
"""
|
||||||
|
import app.agents.new_chat.chat_deepagent as chat_deepagent_module
|
||||||
|
|
||||||
|
source = inspect.getsource(chat_deepagent_module)
|
||||||
|
|
||||||
|
# Helper itself must be defined.
|
||||||
|
assert "def _resolve_prompt_model_name(" in source
|
||||||
|
|
||||||
|
# Both builder calls must forward the resolved model name. Match
|
||||||
|
# across newlines + whitespace because the kwargs are split over
|
||||||
|
# multiple lines.
|
||||||
|
pattern = re.compile(
|
||||||
|
r"build_(?:surfsense|configurable)_system_prompt\([^)]*"
|
||||||
|
r"model_name=_resolve_prompt_model_name\(",
|
||||||
|
re.DOTALL,
|
||||||
|
)
|
||||||
|
matches = pattern.findall(source)
|
||||||
|
assert len(matches) == 2, (
|
||||||
|
"Expected both system-prompt builder call sites to forward "
|
||||||
|
"`model_name=_resolve_prompt_model_name(...)`, found "
|
||||||
|
f"{len(matches)}"
|
||||||
|
)
|
||||||
|
|
|
||||||
160
surfsense_backend/uv.lock
generated
160
surfsense_backend/uv.lock
generated
|
|
@ -62,7 +62,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohttp"
|
name = "aiohttp"
|
||||||
version = "3.13.5"
|
version = "3.13.4"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiohappyeyeballs" },
|
{ name = "aiohappyeyeballs" },
|
||||||
|
|
@ -73,76 +73,76 @@ dependencies = [
|
||||||
{ name = "propcache" },
|
{ name = "propcache" },
|
||||||
{ name = "yarl" },
|
{ name = "yarl" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/77/9a/152096d4808df8e4268befa55fba462f440f14beab85e8ad9bf990516918/aiohttp-3.13.5.tar.gz", hash = "sha256:9d98cc980ecc96be6eb4c1994ce35d28d8b1f5e5208a23b421187d1209dbb7d1", size = 7858271 }
|
sdist = { url = "https://files.pythonhosted.org/packages/45/4a/064321452809dae953c1ed6e017504e72551a26b6f5708a5a80e4bf556ff/aiohttp-3.13.4.tar.gz", hash = "sha256:d97a6d09c66087890c2ab5d49069e1e570583f7ac0314ecf98294c1b6aaebd38", size = 7859748 }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/be/6f/353954c29e7dcce7cf00280a02c75f30e133c00793c7a2ed3776d7b2f426/aiohttp-3.13.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:023ecba036ddd840b0b19bf195bfae970083fd7024ce1ac22e9bba90464620e9", size = 748876 },
|
{ url = "https://files.pythonhosted.org/packages/1e/bd/ede278648914cabbabfdf95e436679b5d4156e417896a9b9f4587169e376/aiohttp-3.13.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ee62d4471ce86b108b19c3364db4b91180d13fe3510144872d6bad5401957360", size = 752158 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/f5/1b/428a7c64687b3b2e9cd293186695affc0e1e54a445d0361743b231f11066/aiohttp-3.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15c933ad7920b7d9a20de151efcd05a6e38302cbf0e10c9b2acb9a42210a2416", size = 499557 },
|
{ url = "https://files.pythonhosted.org/packages/90/de/581c053253c07b480b03785196ca5335e3c606a37dc73e95f6527f1591fe/aiohttp-3.13.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c0fd8f41b54b58636402eb493afd512c23580456f022c1ba2db0f810c959ed0d", size = 501037 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/29/47/7be41556bfbb6917069d6a6634bb7dd5e163ba445b783a90d40f5ac7e3a7/aiohttp-3.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ab2899f9fa2f9f741896ebb6fa07c4c883bfa5c7f2ddd8cf2aafa86fa981b2d2", size = 500258 },
|
{ url = "https://files.pythonhosted.org/packages/fa/f9/a5ede193c08f13cc42c0a5b50d1e246ecee9115e4cf6e900d8dbd8fd6acb/aiohttp-3.13.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4baa48ce49efd82d6b1a0be12d6a36b35e5594d1dd42f8bfba96ea9f8678b88c", size = 501556 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/67/84/c9ecc5828cb0b3695856c07c0a6817a99d51e2473400f705275a2b3d9239/aiohttp-3.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60eaa2d440cd4707696b52e40ed3e2b0f73f65be07fd0ef23b6b539c9c0b0b4", size = 1749199 },
|
{ url = "https://files.pythonhosted.org/packages/d6/10/88ff67cd48a6ec36335b63a640abe86135791544863e0cfe1f065d6cef7a/aiohttp-3.13.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d738ebab9f71ee652d9dbd0211057690022201b11197f9a7324fd4dba128aa97", size = 1757314 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/f0/d3/3c6d610e66b495657622edb6ae7c7fd31b2e9086b4ec50b47897ad6042a9/aiohttp-3.13.5-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55b3bdd3292283295774ab585160c4004f4f2f203946997f49aac032c84649e9", size = 1721013 },
|
{ url = "https://files.pythonhosted.org/packages/8b/15/fdb90a5cf5a1f52845c276e76298c75fbbcc0ac2b4a86551906d54529965/aiohttp-3.13.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0ce692c3468fa831af7dceed52edf51ac348cebfc8d3feb935927b63bd3e8576", size = 1731819 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/49/a0/24409c12217456df0bae7babe3b014e460b0b38a8e60753d6cb339f6556d/aiohttp-3.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2b2355dc094e5f7d45a7bb262fe7207aa0460b37a0d87027dcf21b5d890e7d5", size = 1781501 },
|
{ url = "https://files.pythonhosted.org/packages/ec/df/28146785a007f7820416be05d4f28cc207493efd1e8c6c1068e9bdc29198/aiohttp-3.13.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8e08abcfe752a454d2cb89ff0c08f2d1ecd057ae3e8cc6d84638de853530ebab", size = 1793279 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/98/9d/b65ec649adc5bccc008b0957a9a9c691070aeac4e41cea18559fef49958b/aiohttp-3.13.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b38765950832f7d728297689ad78f5f2cf79ff82487131c4d26fe6ceecdc5f8e", size = 1878981 },
|
{ url = "https://files.pythonhosted.org/packages/10/47/689c743abf62ea7a77774d5722f220e2c912a77d65d368b884d9779ef41b/aiohttp-3.13.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5977f701b3fff36367a11087f30ea73c212e686d41cd363c50c022d48b011d8d", size = 1891082 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/57/d8/8d44036d7eb7b6a8ec4c5494ea0c8c8b94fbc0ed3991c1a7adf230df03bf/aiohttp-3.13.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b18f31b80d5a33661e08c89e202edabf1986e9b49c42b4504371daeaa11b47c1", size = 1767934 },
|
{ url = "https://files.pythonhosted.org/packages/b0/b6/f7f4f318c7e58c23b761c9b13b9a3c9b394e0f9d5d76fbc6622fa98509f6/aiohttp-3.13.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:54203e10405c06f8b6020bd1e076ae0fe6c194adcee12a5a78af3ffa3c57025e", size = 1773938 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/31/04/d3f8211f273356f158e3464e9e45484d3fb8c4ce5eb2f6fe9405c3273983/aiohttp-3.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:33add2463dde55c4f2d9635c6ab33ce154e5ecf322bd26d09af95c5f81cfa286", size = 1566671 },
|
{ url = "https://files.pythonhosted.org/packages/aa/06/f207cb3121852c989586a6fc16ff854c4fcc8651b86c5d3bd1fc83057650/aiohttp-3.13.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:358a6af0145bc4dda037f13167bef3cce54b132087acc4c295c739d05d16b1c3", size = 1579548 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/41/db/073e4ebe00b78e2dfcacff734291651729a62953b48933d765dc513bf798/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:327cc432fdf1356fb4fbc6fe833ad4e9f6aacb71a8acaa5f1855e4b25910e4a9", size = 1705219 },
|
{ url = "https://files.pythonhosted.org/packages/6c/58/e1289661a32161e24c1fe479711d783067210d266842523752869cc1d9c2/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:898ea1850656d7d61832ef06aa9846ab3ddb1621b74f46de78fbc5e1a586ba83", size = 1714669 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/48/45/7dfba71a2f9fd97b15c95c06819de7eb38113d2cdb6319669195a7d64270/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7c35b0bf0b48a70b4cb4fc5d7bed9b932532728e124874355de1a0af8ec4bc88", size = 1743049 },
|
{ url = "https://files.pythonhosted.org/packages/96/0a/3e86d039438a74a86e6a948a9119b22540bae037d6ba317a042ae3c22711/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7bc30cceb710cf6a44e9617e43eebb6e3e43ad855a34da7b4b6a73537d8a6763", size = 1754175 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/18/71/901db0061e0f717d226386a7f471bb59b19566f2cae5f0d93874b017271f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:df23d57718f24badef8656c49743e11a89fd6f5358fa8a7b96e728fda2abf7d3", size = 1749557 },
|
{ url = "https://files.pythonhosted.org/packages/f4/30/e717fc5df83133ba467a560b6d8ef20197037b4bb5d7075b90037de1018e/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4a31c0c587a8a038f19a4c7e60654a6c899c9de9174593a13e7cc6e15ff271f9", size = 1762049 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/08/d5/41eebd16066e59cd43728fe74bce953d7402f2b4ddfdfef2c0e9f17ca274/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:02e048037a6501a5ec1f6fc9736135aec6eb8a004ce48838cb951c515f32c80b", size = 1558931 },
|
{ url = "https://files.pythonhosted.org/packages/e4/28/8f7a2d4492e336e40005151bdd94baf344880a4707573378579f833a64c1/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:2062f675f3fe6e06d6113eb74a157fb9df58953ffed0cdb4182554b116545758", size = 1570861 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/30/e6/4a799798bf05740e66c3a1161079bda7a3dd8e22ca392481d7a7f9af82a6/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:31cebae8b26f8a615d2b546fee45d5ffb76852ae6450e2a03f42c9102260d6fe", size = 1774125 },
|
{ url = "https://files.pythonhosted.org/packages/78/45/12e1a3d0645968b1c38de4b23fdf270b8637735ea057d4f84482ff918ad9/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d1ba8afb847ff80626d5e408c1fdc99f942acc877d0702fe137015903a220a9", size = 1790003 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/84/63/7749337c90f92bc2cb18f9560d67aa6258c7060d1397d21529b8004fcf6f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:888e78eb5ca55a615d285c3c09a7a91b42e9dd6fc699b166ebd5dee87c9ccf14", size = 1732427 },
|
{ url = "https://files.pythonhosted.org/packages/eb/0f/60374e18d590de16dcb39d6ff62f39c096c1b958e6f37727b5870026ea30/aiohttp-3.13.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b08149419994cdd4d5eecf7fd4bc5986b5a9380285bcd01ab4c0d6bfca47b79d", size = 1737289 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/98/de/cf2f44ff98d307e72fb97d5f5bbae3bfcb442f0ea9790c0bf5c5c2331404/aiohttp-3.13.5-cp312-cp312-win32.whl", hash = "sha256:8bd3ec6376e68a41f9f95f5ed170e2fcf22d4eb27a1f8cb361d0508f6e0557f3", size = 433534 },
|
{ url = "https://files.pythonhosted.org/packages/02/bf/535e58d886cfbc40a8b0013c974afad24ef7632d645bca0b678b70033a60/aiohttp-3.13.4-cp312-cp312-win32.whl", hash = "sha256:fc432f6a2c4f720180959bc19aa37259651c1a4ed8af8afc84dd41c60f15f791", size = 434185 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/aa/ca/eadf6f9c8fa5e31d40993e3db153fb5ed0b11008ad5d9de98a95045bed84/aiohttp-3.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:110e448e02c729bcebb18c60b9214a87ba33bac4a9fa5e9a5f139938b56c6cb1", size = 460446 },
|
{ url = "https://files.pythonhosted.org/packages/1e/1a/d92e3325134ebfff6f4069f270d3aac770d63320bd1fcd0eca023e74d9a8/aiohttp-3.13.4-cp312-cp312-win_amd64.whl", hash = "sha256:6148c9ae97a3e8bff9a1fc9c757fa164116f86c100468339730e717590a3fb77", size = 461285 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/78/e9/d76bf503005709e390122d34e15256b88f7008e246c4bdbe915cd4f1adce/aiohttp-3.13.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5029cc80718bbd545123cd8fe5d15025eccaaaace5d0eeec6bd556ad6163d61", size = 742930 },
|
{ url = "https://files.pythonhosted.org/packages/e3/ac/892f4162df9b115b4758d615f32ec63d00f3084c705ff5526630887b9b42/aiohttp-3.13.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:63dd5e5b1e43b8fb1e91b79b7ceba1feba588b317d1edff385084fcc7a0a4538", size = 745744 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/57/00/4b7b70223deaebd9bb85984d01a764b0d7bd6526fcdc73cca83bcbe7243e/aiohttp-3.13.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4bb6bf5811620003614076bdc807ef3b5e38244f9d25ca5fe888eaccea2a9832", size = 496927 },
|
{ url = "https://files.pythonhosted.org/packages/97/a9/c5b87e4443a2f0ea88cb3000c93a8fdad1ee63bffc9ded8d8c8e0d66efc6/aiohttp-3.13.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:746ac3cc00b5baea424dacddea3ec2c2702f9590de27d837aa67004db1eebc6e", size = 498178 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/9c/f5/0fb20fb49f8efdcdce6cd8127604ad2c503e754a8f139f5e02b01626523f/aiohttp-3.13.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a84792f8631bf5a94e52d9cc881c0b824ab42717165a5579c760b830d9392ac9", size = 497141 },
|
{ url = "https://files.pythonhosted.org/packages/94/42/07e1b543a61250783650df13da8ddcdc0d0a5538b2bd15cef6e042aefc61/aiohttp-3.13.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bda8f16ea99d6a6705e5946732e48487a448be874e54a4f73d514660ff7c05d3", size = 498331 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/3b/86/b7c870053e36a94e8951b803cb5b909bfbc9b90ca941527f5fcafbf6b0fa/aiohttp-3.13.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:57653eac22c6a4c13eb22ecf4d673d64a12f266e72785ab1c8b8e5940d0e8090", size = 1732476 },
|
{ url = "https://files.pythonhosted.org/packages/20/d6/492f46bf0328534124772d0cf58570acae5b286ea25006900650f69dae0e/aiohttp-3.13.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b061e7b5f840391e3f64d0ddf672973e45c4cfff7a0feea425ea24e51530fc2", size = 1744414 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/b5/e5/4e161f84f98d80c03a238671b4136e6530453d65262867d989bbe78244d0/aiohttp-3.13.5-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5e5f7debc7a57af53fdf5c5009f9391d9f4c12867049d509bf7bb164a6e295b", size = 1706507 },
|
{ url = "https://files.pythonhosted.org/packages/e2/4d/e02627b2683f68051246215d2d62b2d2f249ff7a285e7a858dc47d6b6a14/aiohttp-3.13.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b252e8d5cd66184b570d0d010de742736e8a4fab22c58299772b0c5a466d4b21", size = 1719226 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/d4/56/ea11a9f01518bd5a2a2fcee869d248c4b8a0cfa0bb13401574fa31adf4d4/aiohttp-3.13.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c719f65bebcdf6716f10e9eff80d27567f7892d8988c06de12bbbd39307c6e3a", size = 1773465 },
|
{ url = "https://files.pythonhosted.org/packages/7b/6c/5d0a3394dd2b9f9aeba6e1b6065d0439e4b75d41f1fb09a3ec010b43552b/aiohttp-3.13.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20af8aad61d1803ff11152a26146d8d81c266aa8c5aa9b4504432abb965c36a0", size = 1782110 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/eb/40/333ca27fb74b0383f17c90570c748f7582501507307350a79d9f9f3c6eb1/aiohttp-3.13.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d97f93fdae594d886c5a866636397e2bcab146fd7a132fd6bb9ce182224452f8", size = 1873523 },
|
{ url = "https://files.pythonhosted.org/packages/0d/2d/c20791e3437700a7441a7edfb59731150322424f5aadf635602d1d326101/aiohttp-3.13.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:13a5cc924b59859ad2adb1478e31f410a7ed46e92a2a619d6d1dd1a63c1a855e", size = 1884809 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/f0/d2/e2f77eef1acb7111405433c707dc735e63f67a56e176e72e9e7a2cd3f493/aiohttp-3.13.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3df334e39d4c2f899a914f1dba283c1aadc311790733f705182998c6f7cae665", size = 1754113 },
|
{ url = "https://files.pythonhosted.org/packages/c8/94/d99dbfbd1924a87ef643833932eb2a3d9e5eee87656efea7d78058539eff/aiohttp-3.13.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:534913dfb0a644d537aebb4123e7d466d94e3be5549205e6a31f72368980a81a", size = 1764938 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/fb/56/3f653d7f53c89669301ec9e42c95233e2a0c0a6dd051269e6e678db4fdb0/aiohttp-3.13.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe6970addfea9e5e081401bcbadf865d2b6da045472f58af08427e108d618540", size = 1562351 },
|
{ url = "https://files.pythonhosted.org/packages/49/61/3ce326a1538781deb89f6cf5e094e2029cd308ed1e21b2ba2278b08426f6/aiohttp-3.13.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:320e40192a2dcc1cf4b5576936e9652981ab596bf81eb309535db7e2f5b5672f", size = 1570697 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/ec/a6/9b3e91eb8ae791cce4ee736da02211c85c6f835f1bdfac0594a8a3b7018c/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7becdf835feff2f4f335d7477f121af787e3504b48b449ff737afb35869ba7bb", size = 1693205 },
|
{ url = "https://files.pythonhosted.org/packages/b6/77/4ab5a546857bb3028fbaf34d6eea180267bdab022ee8b1168b1fcde4bfdd/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9e587fcfce2bcf06526a43cb705bdee21ac089096f2e271d75de9c339db3100c", size = 1702258 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/98/fc/bfb437a99a2fcebd6b6eaec609571954de2ed424f01c352f4b5504371dd3/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:676e5651705ad5d8a70aeb8eb6936c436d8ebbd56e63436cb7dd9bb36d2a9a46", size = 1730618 },
|
{ url = "https://files.pythonhosted.org/packages/79/63/d8f29021e39bc5af8e5d5e9da1b07976fb9846487a784e11e4f4eeda4666/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:9eb9c2eea7278206b5c6c1441fdd9dc420c278ead3f3b2cc87f9b693698cc500", size = 1740287 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/e4/b6/c8534862126191a034f68153194c389addc285a0f1347d85096d349bbc15/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:9b16c653d38eb1a611cc898c41e76859ca27f119d25b53c12875fd0474ae31a8", size = 1745185 },
|
{ url = "https://files.pythonhosted.org/packages/55/3a/cbc6b3b124859a11bc8055d3682c26999b393531ef926754a3445b99dfef/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:29be00c51972b04bf9d5c8f2d7f7314f48f96070ca40a873a53056e652e805f7", size = 1753011 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/0b/93/4ca8ee2ef5236e2707e0fd5fecb10ce214aee1ff4ab307af9c558bda3b37/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:999802d5fa0389f58decd24b537c54aa63c01c3219ce17d1214cbda3c2b22d2d", size = 1557311 },
|
{ url = "https://files.pythonhosted.org/packages/e0/30/836278675205d58c1368b21520eab9572457cf19afd23759216c04483048/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:90c06228a6c3a7c9f776fe4fc0b7ff647fffd3bed93779a6913c804ae00c1073", size = 1566359 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/57/ae/76177b15f18c5f5d094f19901d284025db28eccc5ae374d1d254181d33f4/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:ec707059ee75732b1ba130ed5f9580fe10ff75180c812bc267ded039db5128c6", size = 1773147 },
|
{ url = "https://files.pythonhosted.org/packages/50/b4/8032cc9b82d17e4277704ba30509eaccb39329dc18d6a35f05e424439e32/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:a533ec132f05fd9a1d959e7f34184cd7d5e8511584848dab85faefbaac573069", size = 1785537 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/01/a4/62f05a0a98d88af59d93b7fcac564e5f18f513cb7471696ac286db970d6a/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d6d44a5b48132053c2f6cd5c8cb14bc67e99a63594e336b0f2af81e94d5530c", size = 1730356 },
|
{ url = "https://files.pythonhosted.org/packages/17/7d/5873e98230bde59f493bf1f7c3e327486a4b5653fa401144704df5d00211/aiohttp-3.13.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1c946f10f413836f82ea4cfb90200d2a59578c549f00857e03111cf45ad01ca5", size = 1740752 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/e4/85/fc8601f59dfa8c9523808281f2da571f8b4699685f9809a228adcc90838d/aiohttp-3.13.5-cp313-cp313-win32.whl", hash = "sha256:329f292ed14d38a6c4c435e465f48bebb47479fd676a0411936cc371643225cc", size = 432637 },
|
{ url = "https://files.pythonhosted.org/packages/7b/f2/13e46e0df051494d7d3c68b7f72d071f48c384c12716fc294f75d5b1a064/aiohttp-3.13.4-cp313-cp313-win32.whl", hash = "sha256:48708e2706106da6967eff5908c78ca3943f005ed6bcb75da2a7e4da94ef8c70", size = 433187 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c0/1b/ac685a8882896acf0f6b31d689e3792199cfe7aba37969fa91da63a7fa27/aiohttp-3.13.5-cp313-cp313-win_amd64.whl", hash = "sha256:69f571de7500e0557801c0b51f4780482c0ec5fe2ac851af5a92cfce1af1cb83", size = 458896 },
|
{ url = "https://files.pythonhosted.org/packages/ea/c0/649856ee655a843c8f8664592cfccb73ac80ede6a8c8db33a25d810c12db/aiohttp-3.13.4-cp313-cp313-win_amd64.whl", hash = "sha256:74a2eb058da44fa3a877a49e2095b591d4913308bb424c418b77beb160c55ce3", size = 459778 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/5d/ce/46572759afc859e867a5bc8ec3487315869013f59281ce61764f76d879de/aiohttp-3.13.5-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:eb4639f32fd4a9904ab8fb45bf3383ba71137f3d9d4ba25b3b3f3109977c5b8c", size = 745721 },
|
{ url = "https://files.pythonhosted.org/packages/6d/29/6657cc37ae04cacc2dbf53fb730a06b6091cc4cbe745028e047c53e6d840/aiohttp-3.13.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:e0a2c961fc92abeff61d6444f2ce6ad35bb982db9fc8ff8a47455beacf454a57", size = 749363 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/13/fe/8a2efd7626dbe6049b2ef8ace18ffda8a4dfcbe1bcff3ac30c0c7575c20b/aiohttp-3.13.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:7e5dc4311bd5ac493886c63cbf76ab579dbe4641268e7c74e48e774c74b6f2be", size = 497663 },
|
{ url = "https://files.pythonhosted.org/packages/90/7f/30ccdf67ca3d24b610067dc63d64dcb91e5d88e27667811640644aa4a85d/aiohttp-3.13.4-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:153274535985a0ff2bff1fb6c104ed547cec898a09213d21b0f791a44b14d933", size = 499317 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/9b/91/cc8cc78a111826c54743d88651e1687008133c37e5ee615fee9b57990fac/aiohttp-3.13.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:756c3c304d394977519824449600adaf2be0ccee76d206ee339c5e76b70ded25", size = 499094 },
|
{ url = "https://files.pythonhosted.org/packages/93/13/e372dd4e68ad04ee25dafb050c7f98b0d91ea643f7352757e87231102555/aiohttp-3.13.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:351f3171e2458da3d731ce83f9e6b9619e325c45cbd534c7759750cabf453ad7", size = 500477 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/0a/33/a8362cb15cf16a3af7e86ed11962d5cd7d59b449202dc576cdc731310bde/aiohttp-3.13.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecc26751323224cf8186efcf7fbcbc30f4e1d8c7970659daf25ad995e4032a56", size = 1726701 },
|
{ url = "https://files.pythonhosted.org/packages/e5/fe/ee6298e8e586096fb6f5eddd31393d8544f33ae0792c71ecbb4c2bef98ac/aiohttp-3.13.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f989ac8bc5595ff761a5ccd32bdb0768a117f36dd1504b1c2c074ed5d3f4df9c", size = 1737227 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/45/0c/c091ac5c3a17114bd76cbf85d674650969ddf93387876cf67f754204bd77/aiohttp-3.13.5-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10a75acfcf794edf9d8db50e5a7ec5fc818b2a8d3f591ce93bc7b1210df016d2", size = 1683360 },
|
{ url = "https://files.pythonhosted.org/packages/b0/b9/a7a0463a09e1a3fe35100f74324f23644bfc3383ac5fd5effe0722a5f0b7/aiohttp-3.13.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d36fc1709110ec1e87a229b201dd3ddc32aa01e98e7868083a794609b081c349", size = 1694036 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/23/73/bcee1c2b79bc275e964d1446c55c54441a461938e70267c86afaae6fba27/aiohttp-3.13.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0f7a18f258d124cd678c5fe072fe4432a4d5232b0657fca7c1847f599233c83a", size = 1773023 },
|
{ url = "https://files.pythonhosted.org/packages/57/7c/8972ae3fb7be00a91aee6b644b2a6a909aedb2c425269a3bfd90115e6f8f/aiohttp-3.13.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:42adaeea83cbdf069ab94f5103ce0787c21fb1a0153270da76b59d5578302329", size = 1786814 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c7/ef/720e639df03004fee2d869f771799d8c23046dec47d5b81e396c7cda583a/aiohttp-3.13.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:df6104c009713d3a89621096f3e3e88cc323fd269dbd7c20afe18535094320be", size = 1853795 },
|
{ url = "https://files.pythonhosted.org/packages/93/01/c81e97e85c774decbaf0d577de7d848934e8166a3a14ad9f8aa5be329d28/aiohttp-3.13.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:92deb95469928cc41fd4b42a95d8012fa6df93f6b1c0a83af0ffbc4a5e218cde", size = 1866676 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/bd/c9/989f4034fb46841208de7aeeac2c6d8300745ab4f28c42f629ba77c2d916/aiohttp-3.13.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:241a94f7de7c0c3b616627aaad530fe2cb620084a8b144d3be7b6ecfe95bae3b", size = 1730405 },
|
{ url = "https://files.pythonhosted.org/packages/5a/5f/5b46fe8694a639ddea2cd035bf5729e4677ea882cb251396637e2ef1590d/aiohttp-3.13.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0c0c7c07c4257ef3a1df355f840bc62d133bcdef5c1c5ba75add3c08553e2eed", size = 1740842 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/ce/75/ee1fd286ca7dc599d824b5651dad7b3be7ff8d9a7e7b3fe9820d9180f7db/aiohttp-3.13.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c974fb66180e58709b6fc402846f13791240d180b74de81d23913abe48e96d94", size = 1558082 },
|
{ url = "https://files.pythonhosted.org/packages/20/a2/0d4b03d011cca6b6b0acba8433193c1e484efa8d705ea58295590fe24203/aiohttp-3.13.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f062c45de8a1098cb137a1898819796a2491aec4e637a06b03f149315dff4d8f", size = 1566508 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c3/20/1e9e6650dfc436340116b7aa89ff8cb2bbdf0abc11dfaceaad8f74273a10/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:6e27ea05d184afac78aabbac667450c75e54e35f62238d44463131bd3f96753d", size = 1692346 },
|
{ url = "https://files.pythonhosted.org/packages/98/17/e689fd500da52488ec5f889effd6404dece6a59de301e380f3c64f167beb/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:76093107c531517001114f0ebdb4f46858ce818590363e3e99a4a2280334454a", size = 1700569 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/d8/40/8ebc6658d48ea630ac7903912fe0dd4e262f0e16825aa4c833c56c9f1f56/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:a79a6d399cef33a11b6f004c67bb07741d91f2be01b8d712d52c75711b1e07c7", size = 1698891 },
|
{ url = "https://files.pythonhosted.org/packages/d8/0d/66402894dbcf470ef7db99449e436105ea862c24f7ea4c95c683e635af35/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:6f6ec32162d293b82f8b63a16edc80769662fbd5ae6fbd4936d3206a2c2cc63b", size = 1707407 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/d8/78/ea0ae5ec8ba7a5c10bdd6e318f1ba5e76fcde17db8275188772afc7917a4/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c632ce9c0b534fbe25b52c974515ed674937c5b99f549a92127c85f771a78772", size = 1742113 },
|
{ url = "https://files.pythonhosted.org/packages/2f/eb/af0ab1a3650092cbd8e14ef29e4ab0209e1460e1c299996c3f8288b3f1ff/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:5903e2db3d202a00ad9f0ec35a122c005e85d90c9836ab4cda628f01edf425e2", size = 1752214 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/8a/66/9d308ed71e3f2491be1acb8769d96c6f0c47d92099f3bc9119cada27b357/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:fceedde51fbd67ee2bcc8c0b33d0126cc8b51ef3bbde2f86662bd6d5a6f10ec5", size = 1553088 },
|
{ url = "https://files.pythonhosted.org/packages/5a/bf/72326f8a98e4c666f292f03c385545963cc65e358835d2a7375037a97b57/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2d5bea57be7aca98dbbac8da046d99b5557c5cf4e28538c4c786313078aca09e", size = 1562162 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/da/a6/6cc25ed8dfc6e00c90f5c6d126a98e2cf28957ad06fa1036bd34b6f24a2c/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:f92995dfec9420bb69ae629abf422e516923ba79ba4403bc750d94fb4a6c68c1", size = 1757976 },
|
{ url = "https://files.pythonhosted.org/packages/67/9f/13b72435f99151dd9a5469c96b3b5f86aa29b7e785ca7f35cf5e538f74c0/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:bcf0c9902085976edc0232b75006ef38f89686901249ce14226b6877f88464fb", size = 1768904 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c1/2b/cce5b0ffe0de99c83e5e36d8f828e4161e415660a9f3e58339d07cce3006/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20ae0ff08b1f2c8788d6fb85afcb798654ae6ba0b747575f8562de738078457b", size = 1712444 },
|
{ url = "https://files.pythonhosted.org/packages/18/bc/28d4970e7d5452ac7776cdb5431a1164a0d9cf8bd2fffd67b4fb463aa56d/aiohttp-3.13.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3295f98bfeed2e867cab588f2a146a9db37a85e3ae9062abf46ba062bd29165", size = 1723378 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/6c/cf/9e1795b4160c58d29421eafd1a69c6ce351e2f7c8d3c6b7e4ca44aea1a5b/aiohttp-3.13.5-cp314-cp314-win32.whl", hash = "sha256:b20df693de16f42b2472a9c485e1c948ee55524786a0a34345511afdd22246f3", size = 438128 },
|
{ url = "https://files.pythonhosted.org/packages/53/74/b32458ca1a7f34d65bdee7aef2036adbe0438123d3d53e2b083c453c24dd/aiohttp-3.13.4-cp314-cp314-win32.whl", hash = "sha256:a598a5c5767e1369d8f5b08695cab1d8160040f796c4416af76fd773d229b3c9", size = 438711 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/22/4d/eaedff67fc805aeba4ba746aec891b4b24cebb1a7d078084b6300f79d063/aiohttp-3.13.5-cp314-cp314-win_amd64.whl", hash = "sha256:f85c6f327bf0b8c29da7d93b1cabb6363fb5e4e160a32fa241ed2dce21b73162", size = 464029 },
|
{ url = "https://files.pythonhosted.org/packages/40/b2/54b487316c2df3e03a8f3435e9636f8a81a42a69d942164830d193beb56a/aiohttp-3.13.4-cp314-cp314-win_amd64.whl", hash = "sha256:c555db4bc7a264bead5a7d63d92d41a1122fcd39cc62a4db815f45ad46f9c2c8", size = 464977 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/79/11/c27d9332ee20d68dd164dc12a6ecdef2e2e35ecc97ed6cf0d2442844624b/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:1efb06900858bb618ff5cee184ae2de5828896c448403d51fb633f09e109be0a", size = 778758 },
|
{ url = "https://files.pythonhosted.org/packages/47/fb/e41b63c6ce71b07a59243bb8f3b457ee0c3402a619acb9d2c0d21ef0e647/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:45abbbf09a129825d13c18c7d3182fecd46d9da3cfc383756145394013604ac1", size = 781549 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/04/fb/377aead2e0a3ba5f09b7624f702a964bdf4f08b5b6728a9799830c80041e/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:fee86b7c4bd29bdaf0d53d14739b08a106fdda809ca5fe032a15f52fae5fe254", size = 512883 },
|
{ url = "https://files.pythonhosted.org/packages/97/53/532b8d28df1e17e44c4d9a9368b78dcb6bf0b51037522136eced13afa9e8/aiohttp-3.13.4-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:74c80b2bc2c2adb7b3d1941b2b60701ee2af8296fc8aad8b8bc48bc25767266c", size = 514383 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/bb/a6/aa109a33671f7a5d3bd78b46da9d852797c5e665bfda7d6b373f56bff2ec/aiohttp-3.13.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:20058e23909b9e65f9da62b396b77dfa95965cbe840f8def6e572538b1d32e36", size = 516668 },
|
{ url = "https://files.pythonhosted.org/packages/1b/1f/62e5d400603e8468cd635812d99cb81cfdc08127a3dc474c647615f31339/aiohttp-3.13.4-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c97989ae40a9746650fa196894f317dafc12227c808c774929dda0ff873a5954", size = 518304 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/79/b3/ca078f9f2fa9563c36fb8ef89053ea2bb146d6f792c5104574d49d8acb63/aiohttp-3.13.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cf20a8d6868cb15a73cab329ffc07291ba8c22b1b88176026106ae39aa6df0f", size = 1883461 },
|
{ url = "https://files.pythonhosted.org/packages/90/57/2326b37b10896447e3c6e0cbef4fe2486d30913639a5cfd1332b5d870f82/aiohttp-3.13.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dae86be9811493f9990ef44fff1685f5c1a3192e9061a71a109d527944eed551", size = 1893433 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/b7/e3/a7ad633ca1ca497b852233a3cce6906a56c3225fb6d9217b5e5e60b7419d/aiohttp-3.13.5-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:330f5da04c987f1d5bdb8ae189137c77139f36bd1cb23779ca1a354a4b027800", size = 1747661 },
|
{ url = "https://files.pythonhosted.org/packages/d2/b4/a24d82112c304afdb650167ef2fe190957d81cbddac7460bedd245f765aa/aiohttp-3.13.4-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:1db491abe852ca2fa6cc48a3341985b0174b3741838e1341b82ac82c8bd9e871", size = 1755901 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/33/b9/cd6fe579bed34a906d3d783fe60f2fa297ef55b27bb4538438ee49d4dc41/aiohttp-3.13.5-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f1cbf0c7926d315c3c26c2da41fd2b5d2fe01ac0e157b78caefc51a782196cf", size = 1863800 },
|
{ url = "https://files.pythonhosted.org/packages/9e/2d/0883ef9d878d7846287f036c162a951968f22aabeef3ac97b0bea6f76d5d/aiohttp-3.13.4-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0e5d701c0aad02a7dce72eef6b93226cf3734330f1a31d69ebbf69f33b86666e", size = 1876093 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c0/3f/2c1e2f5144cefa889c8afd5cf431994c32f3b29da9961698ff4e3811b79a/aiohttp-3.13.5-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:53fc049ed6390d05423ba33103ded7281fe897cf97878f369a527070bd95795b", size = 1958382 },
|
{ url = "https://files.pythonhosted.org/packages/ad/52/9204bb59c014869b71971addad6778f005daa72a96eed652c496789d7468/aiohttp-3.13.4-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8ac32a189081ae0a10ba18993f10f338ec94341f0d5df8fff348043962f3c6f8", size = 1970815 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/66/1d/f31ec3f1013723b3babe3609e7f119c2c2fb6ef33da90061a705ef3e1bc8/aiohttp-3.13.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:898703aa2667e3c5ca4c54ca36cd73f58b7a38ef87a5606414799ebce4d3fd3a", size = 1803724 },
|
{ url = "https://files.pythonhosted.org/packages/d6/b5/e4eb20275a866dde0f570f411b36c6b48f7b53edfe4f4071aa1b0728098a/aiohttp-3.13.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:98e968cdaba43e45c73c3f306fca418c8009a957733bac85937c9f9cf3f4de27", size = 1816223 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/0e/b4/57712dfc6f1542f067daa81eb61da282fab3e6f1966fca25db06c4fc62d5/aiohttp-3.13.5-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0494a01ca9584eea1e5fbd6d748e61ecff218c51b576ee1999c23db7066417d8", size = 1640027 },
|
{ url = "https://files.pythonhosted.org/packages/d8/23/e98075c5bb146aa61a1239ee1ac7714c85e814838d6cebbe37d3fe19214a/aiohttp-3.13.4-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca114790c9144c335d538852612d3e43ea0f075288f4849cf4b05d6cd2238ce7", size = 1649145 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/25/3c/734c878fb43ec083d8e31bf029daae1beafeae582d1b35da234739e82ee7/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6cf81fe010b8c17b09495cbd15c1d35afbc8fb405c0c9cf4738e5ae3af1d65be", size = 1806644 },
|
{ url = "https://files.pythonhosted.org/packages/d6/c1/7bad8be33bb06c2bb224b6468874346026092762cbec388c3bdb65a368ee/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ea2e071661ba9cfe11eabbc81ac5376eaeb3061f6e72ec4cc86d7cdd1ffbdbbb", size = 1816562 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/20/a5/f671e5cbec1c21d044ff3078223f949748f3a7f86b14e34a365d74a5d21f/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:c564dd5f09ddc9d8f2c2d0a301cd30a79a2cc1b46dd1a73bef8f0038863d016b", size = 1791630 },
|
{ url = "https://files.pythonhosted.org/packages/5c/10/c00323348695e9a5e316825969c88463dcc24c7e9d443244b8a2c9cf2eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:34e89912b6c20e0fd80e07fa401fd218a410aa1ce9f1c2f1dad6db1bd0ce0927", size = 1800333 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/0b/63/fb8d0ad63a0b8a99be97deac8c04dacf0785721c158bdf23d679a87aa99e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:2994be9f6e51046c4f864598fd9abeb4fba6e88f0b2152422c9666dcd4aea9c6", size = 1809403 },
|
{ url = "https://files.pythonhosted.org/packages/84/43/9b2147a1df3559f49bd723e22905b46a46c068a53adb54abdca32c4de180/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:0e217cf9f6a42908c52b46e42c568bd57adc39c9286ced31aaace614b6087965", size = 1820617 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/59/0c/bfed7f30662fcf12206481c2aac57dedee43fe1c49275e85b3a1e1742294/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:157826e2fa245d2ef46c83ea8a5faf77ca19355d278d425c29fda0beb3318037", size = 1634924 },
|
{ url = "https://files.pythonhosted.org/packages/a9/7f/b3481a81e7a586d02e99387b18c6dafff41285f6efd3daa2124c01f87eae/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:0c296f1221e21ba979f5ac1964c3b78cfde15c5c5f855ffd2caab337e9cd9182", size = 1643417 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/17/d6/fd518d668a09fd5a3319ae5e984d4d80b9a4b3df4e21c52f02251ef5a32e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:a8aca50daa9493e9e13c0f566201a9006f080e7c50e5e90d0b06f53146a54500", size = 1836119 },
|
{ url = "https://files.pythonhosted.org/packages/8f/72/07181226bc99ce1124e0f89280f5221a82d3ae6a6d9d1973ce429d48e52b/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:d99a9d168ebaffb74f36d011750e490085ac418f4db926cce3989c8fe6cb6b1b", size = 1849286 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/78/b7/15fb7a9d52e112a25b621c67b69c167805cb1f2ab8f1708a5c490d1b52fe/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3b13560160d07e047a93f23aaa30718606493036253d5430887514715b67c9d9", size = 1772072 },
|
{ url = "https://files.pythonhosted.org/packages/1a/e6/1b3566e103eca6da5be4ae6713e112a053725c584e96574caf117568ffef/aiohttp-3.13.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cb19177205d93b881f3f89e6081593676043a6828f59c78c17a0fd6c1fbed2ba", size = 1782635 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/7e/df/57ba7f0c4a553fc2bd8b6321df236870ec6fd64a2a473a8a13d4f733214e/aiohttp-3.13.5-cp314-cp314t-win32.whl", hash = "sha256:9a0f4474b6ea6818b41f82172d799e4b3d29e22c2c520ce4357856fced9af2f8", size = 471819 },
|
{ url = "https://files.pythonhosted.org/packages/37/58/1b11c71904b8d079eb0c39fe664180dd1e14bebe5608e235d8bfbadc8929/aiohttp-3.13.4-cp314-cp314t-win32.whl", hash = "sha256:c606aa5656dab6552e52ca368e43869c916338346bfaf6304e15c58fb113ea30", size = 472537 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441 },
|
{ url = "https://files.pythonhosted.org/packages/bc/8f/87c56a1a1977d7dddea5b31e12189665a140fdb48a71e9038ff90bb564ec/aiohttp-3.13.4-cp314-cp314t-win_amd64.whl", hash = "sha256:014dcc10ec8ab8db681f0d68e939d1e9286a5aa2b993cbbdb0db130853e02144", size = 506381 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -3723,7 +3723,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "litellm"
|
name = "litellm"
|
||||||
version = "1.83.4"
|
version = "1.83.14"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiohttp" },
|
{ name = "aiohttp" },
|
||||||
|
|
@ -3739,9 +3739,9 @@ dependencies = [
|
||||||
{ name = "tiktoken" },
|
{ name = "tiktoken" },
|
||||||
{ name = "tokenizers" },
|
{ name = "tokenizers" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/03/c4/30469c06ae7437a4406bc11e3c433cfd380a6771068cca15ea918dcd158f/litellm-1.83.4.tar.gz", hash = "sha256:6458d2030a41229460b321adee00517a91dbd8e63213cc953d355cb41d16f2d4", size = 17733899 }
|
sdist = { url = "https://files.pythonhosted.org/packages/8d/7c/c095649380adc96c8630273c1768c2ad1e74aa2ee1dd8dd05d218a60569f/litellm-1.83.14.tar.gz", hash = "sha256:24aef9b47cdc424c833e32f3727f411741c690832cd1fe4405e0077144fe09c9", size = 14836599 }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/b8/bd/df19d3f8f6654535ee343a341fd921f81c411abf601a53e3eaef58129b02/litellm-1.83.4-py3-none-any.whl", hash = "sha256:17d7b4d48d47aca988ea4f762ddda5e7bd72cda3270192b22813d0330869d7b4", size = 16015555 },
|
{ url = "https://files.pythonhosted.org/packages/7f/5c/1b5691575420135e90578543b2bf219497caa33cfd0af64cb38f30288450/litellm-1.83.14-py3-none-any.whl", hash = "sha256:92b11ba2a32cf80707ddf388d18526696c7999a21b418c5e3b6eda1243d2cfdb", size = 16457054 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -5124,7 +5124,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "openai"
|
name = "openai"
|
||||||
version = "2.30.0"
|
version = "2.24.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "anyio" },
|
{ name = "anyio" },
|
||||||
|
|
@ -5136,9 +5136,9 @@ dependencies = [
|
||||||
{ name = "tqdm" },
|
{ name = "tqdm" },
|
||||||
{ name = "typing-extensions" },
|
{ name = "typing-extensions" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/88/15/52580c8fbc16d0675d516e8749806eda679b16de1e4434ea06fb6feaa610/openai-2.30.0.tar.gz", hash = "sha256:92f7661c990bda4b22a941806c83eabe4896c3094465030dd882a71abe80c885", size = 676084 }
|
sdist = { url = "https://files.pythonhosted.org/packages/55/13/17e87641b89b74552ed408a92b231283786523edddc95f3545809fab673c/openai-2.24.0.tar.gz", hash = "sha256:1e5769f540dbd01cb33bc4716a23e67b9d695161a734aff9c5f925e2bf99a673", size = 658717 }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/2a/9e/5bfa2270f902d5b92ab7d41ce0475b8630572e71e349b2a4996d14bdda93/openai-2.30.0-py3-none-any.whl", hash = "sha256:9a5ae616888eb2748ec5e0c5b955a51592e0b201a11f4262db920f2a78c5231d", size = 1146656 },
|
{ url = "https://files.pythonhosted.org/packages/c9/30/844dc675ee6902579b8eef01ed23917cc9319a1c9c0c14ec6e39340c96d0/openai-2.24.0-py3-none-any.whl", hash = "sha256:fed30480d7d6c884303287bde864980a4b137b60553ffbcf9ab4a233b7a73d94", size = 1120122 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -6780,11 +6780,11 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dotenv"
|
name = "python-dotenv"
|
||||||
version = "1.0.1"
|
version = "1.2.2"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 }
|
sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135 }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 },
|
{ url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -8070,7 +8070,7 @@ requires-dist = [
|
||||||
{ name = "langgraph", specifier = ">=1.1.3" },
|
{ name = "langgraph", specifier = ">=1.1.3" },
|
||||||
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
|
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
|
||||||
{ name = "linkup-sdk", specifier = ">=0.2.4" },
|
{ name = "linkup-sdk", specifier = ">=0.2.4" },
|
||||||
{ name = "litellm", specifier = ">=1.83.4" },
|
{ name = "litellm", specifier = ">=1.83.7" },
|
||||||
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
|
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
|
||||||
{ name = "markdown", specifier = ">=3.7" },
|
{ name = "markdown", specifier = ">=3.7" },
|
||||||
{ name = "markdownify", specifier = ">=0.14.1" },
|
{ name = "markdownify", specifier = ">=0.14.1" },
|
||||||
|
|
|
||||||
35
surfsense_desktop/build/entitlements.mac.plist
Normal file
35
surfsense_desktop/build/entitlements.mac.plist
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||||
|
<plist version="1.0">
|
||||||
|
<dict>
|
||||||
|
<!-- Required for Electron's V8 JIT under hardened runtime -->
|
||||||
|
<key>com.apple.security.cs.allow-jit</key>
|
||||||
|
<true/>
|
||||||
|
<key>com.apple.security.cs.allow-unsigned-executable-memory</key>
|
||||||
|
<true/>
|
||||||
|
|
||||||
|
<!-- node-mac-permissions and other native deps load dylibs at runtime -->
|
||||||
|
<key>com.apple.security.cs.allow-dyld-environment-variables</key>
|
||||||
|
<true/>
|
||||||
|
<key>com.apple.security.cs.disable-library-validation</key>
|
||||||
|
<true/>
|
||||||
|
|
||||||
|
<!-- Networking (OAuth, API calls, auto-updater, deep links) -->
|
||||||
|
<key>com.apple.security.network.client</key>
|
||||||
|
<true/>
|
||||||
|
<key>com.apple.security.network.server</key>
|
||||||
|
<true/>
|
||||||
|
|
||||||
|
<!-- Screen Capture / Screenshot Assist -->
|
||||||
|
<key>com.apple.security.device.camera</key>
|
||||||
|
<true/>
|
||||||
|
|
||||||
|
<!-- Accessibility / Apple Events used by general-assist -->
|
||||||
|
<key>com.apple.security.automation.apple-events</key>
|
||||||
|
<true/>
|
||||||
|
|
||||||
|
<!-- File access for folder watcher / agent filesystem features -->
|
||||||
|
<key>com.apple.security.files.user-selected.read-write</key>
|
||||||
|
<true/>
|
||||||
|
</dict>
|
||||||
|
</plist>
|
||||||
|
|
@ -46,8 +46,11 @@ mac:
|
||||||
icon: assets/icon.icns
|
icon: assets/icon.icns
|
||||||
category: public.app-category.productivity
|
category: public.app-category.productivity
|
||||||
artifactName: "${productName}-${version}-${arch}.${ext}"
|
artifactName: "${productName}-${version}-${arch}.${ext}"
|
||||||
hardenedRuntime: false
|
hardenedRuntime: true
|
||||||
gatekeeperAssess: false
|
gatekeeperAssess: false
|
||||||
|
entitlements: build/entitlements.mac.plist
|
||||||
|
entitlementsInherit: build/entitlements.mac.plist
|
||||||
|
notarize: true
|
||||||
extendInfo:
|
extendInfo:
|
||||||
NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists."
|
NSAccessibilityUsageDescription: "SurfSense uses accessibility features to bring the app to the foreground and interact with the active application when you use desktop assists."
|
||||||
NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer."
|
NSScreenCaptureUsageDescription: "SurfSense uses screen capture so you can attach a selected region to chat (Screenshot Assist) or capture the full screen from the composer."
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,11 +1,8 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
|
||||||
import { CheckCircle2 } from "lucide-react";
|
import { CheckCircle2 } from "lucide-react";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { useParams } from "next/navigation";
|
import { useParams } from "next/navigation";
|
||||||
import { useEffect } from "react";
|
|
||||||
import { USER_QUERY_KEY } from "@/atoms/user/user-query.atoms";
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import {
|
import {
|
||||||
Card,
|
Card,
|
||||||
|
|
@ -18,14 +15,8 @@ import {
|
||||||
|
|
||||||
export default function PurchaseSuccessPage() {
|
export default function PurchaseSuccessPage() {
|
||||||
const params = useParams();
|
const params = useParams();
|
||||||
const queryClient = useQueryClient();
|
|
||||||
const searchSpaceId = String(params.search_space_id ?? "");
|
const searchSpaceId = String(params.search_space_id ?? "");
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
void queryClient.invalidateQueries({ queryKey: USER_QUERY_KEY });
|
|
||||||
void queryClient.invalidateQueries({ queryKey: ["token-status"] });
|
|
||||||
}, [queryClient]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8">
|
<div className="flex min-h-[calc(100vh-64px)] items-center justify-center px-4 py-8">
|
||||||
<Card className="w-full max-w-lg">
|
<Card className="w-full max-w-lg">
|
||||||
|
|
|
||||||
|
|
@ -132,8 +132,8 @@ export default function DesktopPermissionsPage() {
|
||||||
<div className="space-y-1">
|
<div className="space-y-1">
|
||||||
<h1 className="text-2xl font-semibold tracking-tight">System Permissions</h1>
|
<h1 className="text-2xl font-semibold tracking-tight">System Permissions</h1>
|
||||||
<p className="text-sm text-muted-foreground">
|
<p className="text-sm text-muted-foreground">
|
||||||
SurfSense needs two macOS permissions for Screenshot Assist and for desktop features that
|
SurfSense needs two macOS permissions for Screenshot Assist and for desktop features
|
||||||
require focusing the app or the active application.
|
that require focusing the app or the active application.
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,14 @@ export const setThreadVisibilityAtom = atom(null, (get, set, newVisibility: Chat
|
||||||
|
|
||||||
export const resetCurrentThreadAtom = atom(null, (_, set) => {
|
export const resetCurrentThreadAtom = atom(null, (_, set) => {
|
||||||
set(currentThreadAtom, initialState);
|
set(currentThreadAtom, initialState);
|
||||||
set(reportPanelAtom, { isOpen: false, reportId: null, title: null, wordCount: null });
|
set(reportPanelAtom, {
|
||||||
|
isOpen: false,
|
||||||
|
reportId: null,
|
||||||
|
title: null,
|
||||||
|
wordCount: null,
|
||||||
|
shareToken: null,
|
||||||
|
contentType: "markdown",
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
/** Target comment ID to scroll to (from URL navigation or inbox click) */
|
/** Target comment ID to scroll to (from URL navigation or inbox click) */
|
||||||
|
|
|
||||||
45
surfsense_web/atoms/chat/premium-alert.atom.ts
Normal file
45
surfsense_web/atoms/chat/premium-alert.atom.ts
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
import { atom } from "jotai";
|
||||||
|
|
||||||
|
export type PremiumAlertState = {
|
||||||
|
message: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const premiumAlertByThreadAtom = atom<Record<number, PremiumAlertState>>({});
|
||||||
|
|
||||||
|
export const setPremiumAlertForThreadAtom = atom(
|
||||||
|
null,
|
||||||
|
(
|
||||||
|
get,
|
||||||
|
set,
|
||||||
|
payload: {
|
||||||
|
threadId: number;
|
||||||
|
message: string;
|
||||||
|
userId?: string | null;
|
||||||
|
}
|
||||||
|
) => {
|
||||||
|
const storageKey = `surfsense-premium-alert-seen-v1:${payload.userId ?? "anonymous"}`;
|
||||||
|
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
const hasSeen = localStorage.getItem(storageKey) === "true";
|
||||||
|
if (hasSeen) return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const current = get(premiumAlertByThreadAtom);
|
||||||
|
set(premiumAlertByThreadAtom, {
|
||||||
|
...current,
|
||||||
|
[payload.threadId]: { message: payload.message },
|
||||||
|
});
|
||||||
|
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
localStorage.setItem(storageKey, "true");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
export const clearPremiumAlertForThreadAtom = atom(null, (get, set, threadId: number) => {
|
||||||
|
const current = get(premiumAlertByThreadAtom);
|
||||||
|
if (!(threadId in current)) return;
|
||||||
|
const next = { ...current };
|
||||||
|
delete next[threadId];
|
||||||
|
set(premiumAlertByThreadAtom, next);
|
||||||
|
});
|
||||||
|
|
@ -8,7 +8,10 @@ const userQueryFn = () => userApiService.getMe();
|
||||||
export const currentUserAtom = atomWithQuery(() => {
|
export const currentUserAtom = atomWithQuery(() => {
|
||||||
return {
|
return {
|
||||||
queryKey: USER_QUERY_KEY,
|
queryKey: USER_QUERY_KEY,
|
||||||
staleTime: 5 * 60 * 1000,
|
// Live-changing numeric fields (pages_*, premium_tokens_*) are now
|
||||||
|
// pushed via Zero (queries.user.me()), so /users/me only needs to
|
||||||
|
// fire once per session for the static profile fields.
|
||||||
|
staleTime: Infinity,
|
||||||
enabled: !!getBearerToken(),
|
enabled: !!getBearerToken(),
|
||||||
queryFn: userQueryFn,
|
queryFn: userQueryFn,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -17,16 +17,12 @@ import {
|
||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Separator } from "@/components/ui/separator";
|
import { Separator } from "@/components/ui/separator";
|
||||||
import { getToolIcon } from "@/contracts/enums/toolIcons";
|
import { getToolDisplayName, getToolIcon } from "@/contracts/enums/toolIcons";
|
||||||
import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
import { type AgentAction, agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
||||||
import { AppError } from "@/lib/error";
|
import { AppError } from "@/lib/error";
|
||||||
import { formatRelativeDate } from "@/lib/format-date";
|
import { formatRelativeDate } from "@/lib/format-date";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
function formatToolName(name: string): string {
|
|
||||||
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ActionLogItemProps {
|
interface ActionLogItemProps {
|
||||||
action: AgentAction;
|
action: AgentAction;
|
||||||
threadId: number;
|
threadId: number;
|
||||||
|
|
@ -43,7 +39,7 @@ export function ActionLogItem({ action, threadId, onRevertSuccess }: ActionLogIt
|
||||||
const hasError = action.error !== null && action.error !== undefined;
|
const hasError = action.error !== null && action.error !== undefined;
|
||||||
|
|
||||||
const Icon = getToolIcon(action.tool_name);
|
const Icon = getToolIcon(action.tool_name);
|
||||||
const displayName = formatToolName(action.tool_name);
|
const displayName = getToolDisplayName(action.tool_name);
|
||||||
|
|
||||||
const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null;
|
const argsPreview = action.args ? JSON.stringify(action.args, null, 2) : null;
|
||||||
const truncatedArgs =
|
const truncatedArgs =
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { useAtom, useAtomValue } from "jotai";
|
import { useAtom, useAtomValue } from "jotai";
|
||||||
import { Activity, RefreshCcw } from "lucide-react";
|
import { Activity, RefreshCcw } from "lucide-react";
|
||||||
import { useCallback, useMemo } from "react";
|
import { useCallback } from "react";
|
||||||
import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom";
|
import { actionLogSheetAtom } from "@/atoms/agent/action-log-sheet.atom";
|
||||||
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
|
import { agentFlagsAtom } from "@/atoms/agent/agent-flags-query.atom";
|
||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
|
@ -17,15 +17,9 @@ import {
|
||||||
SheetTitle,
|
SheetTitle,
|
||||||
} from "@/components/ui/sheet";
|
} from "@/components/ui/sheet";
|
||||||
import { Skeleton } from "@/components/ui/skeleton";
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
import { agentActionsQueryKey, useAgentActionsQuery } from "@/hooks/use-agent-actions-query";
|
||||||
import { ActionLogItem } from "./action-log-item";
|
import { ActionLogItem } from "./action-log-item";
|
||||||
|
|
||||||
const ACTION_LOG_PAGE_SIZE = 50;
|
|
||||||
|
|
||||||
function actionLogQueryKey(threadId: number) {
|
|
||||||
return ["agent-actions", threadId] as const;
|
|
||||||
}
|
|
||||||
|
|
||||||
function EmptyState() {
|
function EmptyState() {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center">
|
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-6 text-center">
|
||||||
|
|
@ -85,25 +79,17 @@ export function ActionLogSheet() {
|
||||||
|
|
||||||
const threadId = state.threadId;
|
const threadId = state.threadId;
|
||||||
|
|
||||||
const { data, isLoading, isFetching, isError, error, refetch } = useQuery({
|
const { data, items, isLoading, isFetching, isError, error, refetch } = useAgentActionsQuery(
|
||||||
queryKey: threadId !== null ? actionLogQueryKey(threadId) : ["agent-actions", "none"],
|
threadId,
|
||||||
queryFn: () =>
|
{ enabled: state.open && actionLogEnabled }
|
||||||
agentActionsApiService.listForThread(threadId as number, {
|
);
|
||||||
page: 0,
|
|
||||||
pageSize: ACTION_LOG_PAGE_SIZE,
|
|
||||||
}),
|
|
||||||
enabled: state.open && threadId !== null && actionLogEnabled,
|
|
||||||
staleTime: 15 * 1000,
|
|
||||||
});
|
|
||||||
|
|
||||||
const handleRevertSuccess = useCallback(() => {
|
const handleRevertSuccess = useCallback(() => {
|
||||||
if (threadId !== null) {
|
if (threadId !== null) {
|
||||||
queryClient.invalidateQueries({ queryKey: actionLogQueryKey(threadId) });
|
queryClient.invalidateQueries({ queryKey: agentActionsQueryKey(threadId) });
|
||||||
}
|
}
|
||||||
}, [queryClient, threadId]);
|
}, [queryClient, threadId]);
|
||||||
|
|
||||||
const items = useMemo(() => data?.items ?? [], [data]);
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}>
|
<Sheet open={state.open} onOpenChange={(open) => setState((s) => ({ ...s, open }))}>
|
||||||
<SheetContent
|
<SheetContent
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,8 @@ import {
|
||||||
useAllCitationMetadata,
|
useAllCitationMetadata,
|
||||||
} from "@/components/assistant-ui/citation-metadata-context";
|
} from "@/components/assistant-ui/citation-metadata-context";
|
||||||
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
|
||||||
|
import { ReasoningMessagePart } from "@/components/assistant-ui/reasoning-message-part";
|
||||||
|
import { RevertTurnButton } from "@/components/assistant-ui/revert-turn-button";
|
||||||
import { useTokenUsage } from "@/components/assistant-ui/token-usage-context";
|
import { useTokenUsage } from "@/components/assistant-ui/token-usage-context";
|
||||||
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
|
||||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||||
|
|
@ -491,6 +493,7 @@ const AssistantMessageInner: FC = () => {
|
||||||
<MessagePrimitive.Parts
|
<MessagePrimitive.Parts
|
||||||
components={{
|
components={{
|
||||||
Text: MarkdownText,
|
Text: MarkdownText,
|
||||||
|
Reasoning: ReasoningMessagePart,
|
||||||
tools: {
|
tools: {
|
||||||
by_name: {
|
by_name: {
|
||||||
generate_report: GenerateReportToolUI,
|
generate_report: GenerateReportToolUI,
|
||||||
|
|
@ -545,9 +548,11 @@ const AssistantMessageInner: FC = () => {
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 flex items-center gap-2">
|
<div className="aui-assistant-message-footer mt-3 mb-5 ml-2 h-6">
|
||||||
|
<div className="h-full opacity-100 transition-opacity">
|
||||||
<AssistantActionBar />
|
<AssistantActionBar />
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
</CitationMetadataProvider>
|
</CitationMetadataProvider>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
@ -639,17 +644,24 @@ export const AssistantMessage: FC = () => {
|
||||||
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
className="aui-assistant-message-root group fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
|
||||||
data-role="assistant"
|
data-role="assistant"
|
||||||
>
|
>
|
||||||
{/* Comment trigger — right-aligned, just below user query on all screen sizes */}
|
{/* Fixed trigger slot prevents any vertical reflow when visibility changes */}
|
||||||
{showCommentTrigger && (
|
<div className="mr-2 mb-1 flex h-7 justify-end">
|
||||||
<div className="mr-2 mb-1 flex justify-end">
|
|
||||||
<button
|
<button
|
||||||
ref={isDesktop ? commentTriggerRef : undefined}
|
ref={isDesktop ? commentTriggerRef : undefined}
|
||||||
type="button"
|
type="button"
|
||||||
onClick={
|
onClick={
|
||||||
isDesktop ? () => setIsInlineOpen((prev) => !prev) : () => setIsSheetOpen(true)
|
showCommentTrigger
|
||||||
|
? isDesktop
|
||||||
|
? () => setIsInlineOpen((prev) => !prev)
|
||||||
|
: () => setIsSheetOpen(true)
|
||||||
|
: undefined
|
||||||
}
|
}
|
||||||
|
aria-hidden={!showCommentTrigger}
|
||||||
|
tabIndex={showCommentTrigger ? 0 : -1}
|
||||||
className={cn(
|
className={cn(
|
||||||
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
|
"flex items-center gap-1.5 rounded-full px-3 py-1 text-sm transition-colors",
|
||||||
|
"opacity-0 pointer-events-none",
|
||||||
|
showCommentTrigger && "opacity-100 pointer-events-auto",
|
||||||
isDesktop && isInlineOpen
|
isDesktop && isInlineOpen
|
||||||
? "bg-primary/10 text-primary"
|
? "bg-primary/10 text-primary"
|
||||||
: hasComments
|
: hasComments
|
||||||
|
|
@ -667,7 +679,6 @@ export const AssistantMessage: FC = () => {
|
||||||
)}
|
)}
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Desktop floating comment panel — overlays on top of chat content */}
|
{/* Desktop floating comment panel — overlays on top of chat content */}
|
||||||
{showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && (
|
{showCommentTrigger && isDesktop && isInlineOpen && dbMessageId && (
|
||||||
|
|
@ -699,6 +710,13 @@ const AssistantActionBar: FC = () => {
|
||||||
const isLast = useAuiState((s) => s.message.isLast);
|
const isLast = useAuiState((s) => s.message.isLast);
|
||||||
const aui = useAui();
|
const aui = useAui();
|
||||||
const api = useElectronAPI();
|
const api = useElectronAPI();
|
||||||
|
// Surface the persisted ``chat_turn_id`` so the per-turn revert
|
||||||
|
// affordance can scope to just this message's actions. Streamed
|
||||||
|
// turns get their id once the assistant message is hydrated/finalised.
|
||||||
|
const chatTurnId = useAuiState(({ message }) => {
|
||||||
|
const meta = message?.metadata as { custom?: { chatTurnId?: string | null } } | undefined;
|
||||||
|
return meta?.custom?.chatTurnId ?? null;
|
||||||
|
});
|
||||||
|
|
||||||
const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW;
|
const isQuickAssist = !!api?.replaceText && IS_QUICK_ASSIST_WINDOW;
|
||||||
|
|
||||||
|
|
@ -743,6 +761,9 @@ const AssistantActionBar: FC = () => {
|
||||||
</TooltipIconButton>
|
</TooltipIconButton>
|
||||||
)}
|
)}
|
||||||
<MessageInfoDropdown />
|
<MessageInfoDropdown />
|
||||||
|
<div className="ml-auto">
|
||||||
|
<RevertTurnButton chatTurnId={chatTurnId} />
|
||||||
|
</div>
|
||||||
</ActionBarPrimitive.Root>
|
</ActionBarPrimitive.Root>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
|
||||||
52
surfsense_web/components/assistant-ui/chat-viewport.tsx
Normal file
52
surfsense_web/components/assistant-ui/chat-viewport.tsx
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { ThreadPrimitive } from "@assistant-ui/react";
|
||||||
|
import { ArrowDownIcon } from "lucide-react";
|
||||||
|
import type { FC, ReactNode } from "react";
|
||||||
|
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||||
|
|
||||||
|
const ChatScrollToBottom: FC = () => (
|
||||||
|
<ThreadPrimitive.ScrollToBottom asChild>
|
||||||
|
<TooltipIconButton
|
||||||
|
tooltip="Scroll to bottom"
|
||||||
|
variant="outline"
|
||||||
|
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
||||||
|
>
|
||||||
|
<ArrowDownIcon />
|
||||||
|
</TooltipIconButton>
|
||||||
|
</ThreadPrimitive.ScrollToBottom>
|
||||||
|
);
|
||||||
|
|
||||||
|
export interface ChatViewportProps {
|
||||||
|
children: ReactNode;
|
||||||
|
footer?: ReactNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ChatViewport: FC<ChatViewportProps> = ({ children, footer }) => (
|
||||||
|
<ThreadPrimitive.Viewport
|
||||||
|
turnAnchor="top"
|
||||||
|
autoScroll
|
||||||
|
scrollToBottomOnRunStart
|
||||||
|
scrollToBottomOnInitialize
|
||||||
|
scrollToBottomOnThreadSwitch
|
||||||
|
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 scroll-smooth"
|
||||||
|
style={{ scrollbarGutter: "stable" }}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
aria-hidden
|
||||||
|
className="aui-chat-viewport-top-fade pointer-events-none sticky top-0 z-10 -mx-4 h-2 shrink-0 bg-gradient-to-b from-main-panel from-20% to-transparent"
|
||||||
|
/>
|
||||||
|
{children}
|
||||||
|
{footer ? (
|
||||||
|
<ThreadPrimitive.ViewportFooter
|
||||||
|
className="aui-chat-composer-footer sticky bottom-0 z-20 -mx-4 mt-auto flex flex-col items-stretch bg-gradient-to-t from-main-panel from-60% to-transparent px-4 pt-6"
|
||||||
|
style={{ paddingBottom: "max(0.5rem, env(safe-area-inset-bottom))" }}
|
||||||
|
>
|
||||||
|
<div className="aui-chat-composer-area relative mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-3 overflow-visible">
|
||||||
|
<ChatScrollToBottom />
|
||||||
|
{footer}
|
||||||
|
</div>
|
||||||
|
</ThreadPrimitive.ViewportFooter>
|
||||||
|
) : null}
|
||||||
|
</ThreadPrimitive.Viewport>
|
||||||
|
);
|
||||||
106
surfsense_web/components/assistant-ui/edit-message-dialog.tsx
Normal file
106
surfsense_web/components/assistant-ui/edit-message-dialog.tsx
Normal file
|
|
@ -0,0 +1,106 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Confirmation dialog shown when the user edits a message that has
|
||||||
|
* reversible downstream actions. Three buttons:
|
||||||
|
*
|
||||||
|
* • "Revert all & resubmit" — POST regenerate with revert_actions=true
|
||||||
|
* • "Continue without revert" — POST regenerate with revert_actions=false
|
||||||
|
* • "Cancel" — abort the edit entirely
|
||||||
|
*
|
||||||
|
* The dialog is auto-skipped when zero reversible downstream actions
|
||||||
|
* exist (the caller checks first via ``downstreamReversibleCount``).
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useEffect, useRef, useState } from "react";
|
||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
|
||||||
|
export type EditMessageDialogChoice = "revert" | "continue" | "cancel";
|
||||||
|
|
||||||
|
export interface EditMessageDialogProps {
|
||||||
|
open: boolean;
|
||||||
|
onOpenChange: (open: boolean) => void;
|
||||||
|
downstreamReversibleCount: number;
|
||||||
|
downstreamTotalCount: number;
|
||||||
|
onChoose: (choice: EditMessageDialogChoice) => void | Promise<void>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function EditMessageDialog({
|
||||||
|
open,
|
||||||
|
onOpenChange,
|
||||||
|
downstreamReversibleCount,
|
||||||
|
downstreamTotalCount,
|
||||||
|
onChoose,
|
||||||
|
}: EditMessageDialogProps) {
|
||||||
|
const [busy, setBusy] = useState<EditMessageDialogChoice | null>(null);
|
||||||
|
|
||||||
|
// The parent's ``handleEditDialogChoice`` calls
|
||||||
|
// ``setEditDialogState(null)`` BEFORE awaiting ``handleRegenerate``.
|
||||||
|
// That collapses the dialog (Radix unmounts it) while ``onChoose``
|
||||||
|
// is still awaiting the long-running stream. Without this guard,
|
||||||
|
// the ``finally { setBusy(null) }`` below ran after unmount and
|
||||||
|
// produced a "state update on unmounted component" dev warning.
|
||||||
|
const mountedRef = useRef(true);
|
||||||
|
useEffect(() => {
|
||||||
|
mountedRef.current = true;
|
||||||
|
return () => {
|
||||||
|
mountedRef.current = false;
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handle = async (choice: EditMessageDialogChoice) => {
|
||||||
|
setBusy(choice);
|
||||||
|
try {
|
||||||
|
await onChoose(choice);
|
||||||
|
} finally {
|
||||||
|
if (mountedRef.current) {
|
||||||
|
setBusy(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AlertDialog open={open} onOpenChange={onOpenChange}>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader>
|
||||||
|
<AlertDialogTitle>Edit this message?</AlertDialogTitle>
|
||||||
|
<AlertDialogDescription>
|
||||||
|
This edit drops {downstreamTotalCount} downstream message
|
||||||
|
{downstreamTotalCount === 1 ? "" : "s"} from the thread. {downstreamReversibleCount}{" "}
|
||||||
|
action
|
||||||
|
{downstreamReversibleCount === 1 ? "" : "s"} (e.g. file writes, connector changes) can
|
||||||
|
be rolled back. Pick how to handle them before regenerating.
|
||||||
|
</AlertDialogDescription>
|
||||||
|
</AlertDialogHeader>
|
||||||
|
|
||||||
|
<div className="grid gap-2">
|
||||||
|
<Button variant="default" disabled={busy !== null} onClick={() => handle("revert")}>
|
||||||
|
{busy === "revert"
|
||||||
|
? "Reverting & resubmitting…"
|
||||||
|
: `Revert ${downstreamReversibleCount} action${
|
||||||
|
downstreamReversibleCount === 1 ? "" : "s"
|
||||||
|
} & resubmit`}
|
||||||
|
</Button>
|
||||||
|
<Button variant="outline" disabled={busy !== null} onClick={() => handle("continue")}>
|
||||||
|
{busy === "continue" ? "Resubmitting…" : "Continue without reverting"}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<AlertDialogFooter className="sm:justify-start">
|
||||||
|
<AlertDialogCancel disabled={busy !== null} onClick={() => handle("cancel")}>
|
||||||
|
Cancel
|
||||||
|
</AlertDialogCancel>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialog>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
@ -3,11 +3,11 @@
|
||||||
import { useQuery } from "@tanstack/react-query";
|
import { useQuery } from "@tanstack/react-query";
|
||||||
import { useSetAtom } from "jotai";
|
import { useSetAtom } from "jotai";
|
||||||
import { ExternalLink, FileText } from "lucide-react";
|
import { ExternalLink, FileText } from "lucide-react";
|
||||||
|
import dynamic from "next/dynamic";
|
||||||
import type { FC } from "react";
|
import type { FC } from "react";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { openCitationPanelAtom } from "@/atoms/citation/citation-panel.atom";
|
import { openCitationPanelAtom } from "@/atoms/citation/citation-panel.atom";
|
||||||
import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context";
|
import { useCitationMetadata } from "@/components/assistant-ui/citation-metadata-context";
|
||||||
import { MarkdownViewer } from "@/components/markdown-viewer";
|
|
||||||
import { Citation } from "@/components/tool-ui/citation";
|
import { Citation } from "@/components/tool-ui/citation";
|
||||||
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
|
|
@ -15,6 +15,16 @@ import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip
|
||||||
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||||
|
|
||||||
|
// Lazily load MarkdownViewer here to break the static import cycle:
|
||||||
|
// `markdown-viewer.tsx` → `citation-renderer.tsx` → `inline-citation.tsx`
|
||||||
|
// would otherwise pull `markdown-viewer.tsx` back in at module-init time.
|
||||||
|
// Only `SurfsenseDocCitation` (popover body) ever renders this viewer, so
|
||||||
|
// the lazy boundary is invisible to most call paths.
|
||||||
|
const MarkdownViewer = dynamic(
|
||||||
|
() => import("@/components/markdown-viewer").then((m) => m.MarkdownViewer),
|
||||||
|
{ ssr: false, loading: () => <Spinner size="xs" /> }
|
||||||
|
);
|
||||||
|
|
||||||
interface InlineCitationProps {
|
interface InlineCitationProps {
|
||||||
chunkId: number;
|
chunkId: number;
|
||||||
isDocsChunk?: boolean;
|
isDocsChunk?: boolean;
|
||||||
|
|
@ -172,7 +182,7 @@ const SurfsenseDocCitation: FC<{ chunkId: number }> = ({ chunkId }) => {
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
{!isLoading && !error && citedChunk?.content && (
|
{!isLoading && !error && citedChunk?.content && (
|
||||||
<MarkdownViewer content={citedChunk.content} maxLength={1500} />
|
<MarkdownViewer content={citedChunk.content} maxLength={1500} enableCitations />
|
||||||
)}
|
)}
|
||||||
{!isLoading && !error && !citedChunk?.content && (
|
{!isLoading && !error && !citedChunk?.content && (
|
||||||
<p className="py-4 text-xs text-muted-foreground">No content available.</p>
|
<p className="py-4 text-xs text-muted-foreground">No content available.</p>
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -12,14 +12,15 @@ import { ExternalLinkIcon } from "lucide-react";
|
||||||
import dynamic from "next/dynamic";
|
import dynamic from "next/dynamic";
|
||||||
import { useParams } from "next/navigation";
|
import { useParams } from "next/navigation";
|
||||||
import { useTheme } from "next-themes";
|
import { useTheme } from "next-themes";
|
||||||
import { memo, type ReactNode } from "react";
|
import { createContext, memo, type ReactNode, useCallback, useContext, useRef } from "react";
|
||||||
import rehypeKatex from "rehype-katex";
|
import rehypeKatex from "rehype-katex";
|
||||||
import remarkGfm from "remark-gfm";
|
import remarkGfm from "remark-gfm";
|
||||||
import remarkMath from "remark-math";
|
import remarkMath from "remark-math";
|
||||||
import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
|
import { openEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
|
||||||
import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image";
|
import { ImagePreview, ImageRoot, ImageZoom } from "@/components/assistant-ui/image";
|
||||||
import "katex/dist/katex.min.css";
|
import "katex/dist/katex.min.css";
|
||||||
import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation";
|
import { toast } from "sonner";
|
||||||
|
import { processChildrenWithCitations } from "@/components/citations/citation-renderer";
|
||||||
import { Skeleton } from "@/components/ui/skeleton";
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
import {
|
import {
|
||||||
Table,
|
Table,
|
||||||
|
|
@ -30,6 +31,8 @@ import {
|
||||||
TableRow,
|
TableRow,
|
||||||
} from "@/components/ui/table";
|
} from "@/components/ui/table";
|
||||||
import { useElectronAPI } from "@/hooks/use-platform";
|
import { useElectronAPI } from "@/hooks/use-platform";
|
||||||
|
import { documentsApiService } from "@/lib/apis/documents-api.service";
|
||||||
|
import { type CitationUrlMap, preprocessCitationMarkdown } from "@/lib/citations/citation-parser";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
function MarkdownCodeBlockSkeleton() {
|
function MarkdownCodeBlockSkeleton() {
|
||||||
|
|
@ -59,31 +62,30 @@ const LazyMarkdownCodeBlock = dynamic(
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
// Storage for URL citations replaced during preprocess to avoid GFM autolink interference.
|
// Per-render URL placeholder map propagated to component overrides via
|
||||||
// Populated in preprocessMarkdown, consumed in parseTextWithCitations.
|
// React Context. Replaces the previous module-level `_pendingUrlCitations`
|
||||||
let _pendingUrlCitations = new Map<string, string>();
|
// state, which was unsafe under concurrent renders / SSR.
|
||||||
let _urlCiteIdx = 0;
|
type CitationUrlMapRef = { current: CitationUrlMap };
|
||||||
|
const EMPTY_URL_MAP: CitationUrlMap = new Map();
|
||||||
|
const CitationUrlMapContext = createContext<CitationUrlMapRef>({ current: EMPTY_URL_MAP });
|
||||||
|
|
||||||
|
function useCitationUrlMap(): CitationUrlMap {
|
||||||
|
return useContext(CitationUrlMapContext).current;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Preprocess raw markdown before it reaches the remark/rehype pipeline.
|
* Preprocess raw markdown before it reaches the remark/rehype pipeline.
|
||||||
* - Replaces URL-based citations with safe placeholders (prevents GFM autolinks)
|
* - Replaces URL-based citations with safe placeholders (prevents GFM autolinks)
|
||||||
* - Normalises LaTeX delimiters to dollar-sign syntax for remark-math
|
* - Normalises LaTeX delimiters to dollar-sign syntax for remark-math
|
||||||
*/
|
*/
|
||||||
function preprocessMarkdown(content: string): string {
|
function preprocessMarkdown(content: string, urlMapRef: CitationUrlMapRef): string {
|
||||||
// Replace URL-based citations with safe placeholders BEFORE markdown parsing.
|
// Replace URL-based citations with safe placeholders BEFORE markdown parsing.
|
||||||
// GFM autolinks would otherwise convert the https://... inside [citation:URL]
|
// GFM autolinks would otherwise convert the https://... inside [citation:URL]
|
||||||
// into an <a> element, splitting the text and preventing our citation regex
|
// into an <a> element, splitting the text and preventing our citation regex
|
||||||
// from matching the full pattern.
|
// from matching the full pattern.
|
||||||
_pendingUrlCitations = new Map();
|
const { content: rewritten, urlMap } = preprocessCitationMarkdown(content);
|
||||||
_urlCiteIdx = 0;
|
urlMapRef.current = urlMap;
|
||||||
content = content.replace(
|
content = rewritten;
|
||||||
/[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+)\s*\u200B?[\]】]/g,
|
|
||||||
(_, url) => {
|
|
||||||
const key = `urlcite${_urlCiteIdx++}`;
|
|
||||||
_pendingUrlCitations.set(key, url.trim());
|
|
||||||
return `[citation:${key}]`;
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
// All math forms are normalised to $$...$$ so we can disable single-dollar
|
// All math forms are normalised to $$...$$ so we can disable single-dollar
|
||||||
// inline math in remark-math (otherwise currency like "$3,120.00 and $0.00"
|
// inline math in remark-math (otherwise currency like "$3,120.00 and $0.00"
|
||||||
|
|
@ -116,113 +118,25 @@ function preprocessMarkdown(content: string): string {
|
||||||
return content;
|
return content;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matches [citation:...] with numeric IDs (incl. negative, doc- prefix, comma-separated),
|
|
||||||
// URL-based IDs from live web search, or urlciteN placeholders from preprocess.
|
|
||||||
// Also matches Chinese brackets 【】 and handles zero-width spaces that LLM sometimes inserts.
|
|
||||||
const CITATION_REGEX =
|
|
||||||
/[[【]\u200B?citation:\s*(https?:\/\/[^\]】\u200B]+|urlcite\d+|(?:doc-)?-?\d+(?:\s*,\s*(?:doc-)?-?\d+)*)\s*\u200B?[\]】]/g;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parses text and replaces [citation:XXX] patterns with citation components.
|
|
||||||
* Supports:
|
|
||||||
* - Numeric chunk IDs: [citation:123]
|
|
||||||
* - Doc-prefixed IDs: [citation:doc-123]
|
|
||||||
* - Comma-separated IDs: [citation:4149, 4150, 4151]
|
|
||||||
* - URL-based citations from live search: [citation:https://example.com/page]
|
|
||||||
*/
|
|
||||||
function parseTextWithCitations(text: string): ReactNode[] {
|
|
||||||
const parts: ReactNode[] = [];
|
|
||||||
let lastIndex = 0;
|
|
||||||
let match: RegExpExecArray | null;
|
|
||||||
let instanceIndex = 0;
|
|
||||||
|
|
||||||
CITATION_REGEX.lastIndex = 0;
|
|
||||||
|
|
||||||
match = CITATION_REGEX.exec(text);
|
|
||||||
while (match !== null) {
|
|
||||||
if (match.index > lastIndex) {
|
|
||||||
parts.push(text.substring(lastIndex, match.index));
|
|
||||||
}
|
|
||||||
|
|
||||||
const captured = match[1];
|
|
||||||
|
|
||||||
if (captured.startsWith("http://") || captured.startsWith("https://")) {
|
|
||||||
parts.push(<UrlCitation key={`citation-url-${instanceIndex}`} url={captured.trim()} />);
|
|
||||||
instanceIndex++;
|
|
||||||
} else if (captured.startsWith("urlcite")) {
|
|
||||||
const url = _pendingUrlCitations.get(captured);
|
|
||||||
if (url) {
|
|
||||||
parts.push(<UrlCitation key={`citation-url-${instanceIndex}`} url={url} />);
|
|
||||||
}
|
|
||||||
instanceIndex++;
|
|
||||||
} else {
|
|
||||||
const rawIds = captured.split(",").map((s) => s.trim());
|
|
||||||
for (const rawId of rawIds) {
|
|
||||||
const isDocsChunk = rawId.startsWith("doc-");
|
|
||||||
const chunkId = Number.parseInt(isDocsChunk ? rawId.slice(4) : rawId, 10);
|
|
||||||
parts.push(
|
|
||||||
<InlineCitation
|
|
||||||
key={`citation-${isDocsChunk ? "doc-" : ""}${chunkId}-${instanceIndex}`}
|
|
||||||
chunkId={chunkId}
|
|
||||||
isDocsChunk={isDocsChunk}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
instanceIndex++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
lastIndex = match.index + match[0].length;
|
|
||||||
match = CITATION_REGEX.exec(text);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lastIndex < text.length) {
|
|
||||||
parts.push(text.substring(lastIndex));
|
|
||||||
}
|
|
||||||
|
|
||||||
return parts.length > 0 ? parts : [text];
|
|
||||||
}
|
|
||||||
|
|
||||||
const MarkdownTextImpl = () => {
|
const MarkdownTextImpl = () => {
|
||||||
|
const urlMapRef = useRef<CitationUrlMap>(EMPTY_URL_MAP);
|
||||||
|
const preprocess = useCallback((content: string) => preprocessMarkdown(content, urlMapRef), []);
|
||||||
return (
|
return (
|
||||||
|
<CitationUrlMapContext.Provider value={urlMapRef}>
|
||||||
<MarkdownTextPrimitive
|
<MarkdownTextPrimitive
|
||||||
smooth={false}
|
smooth={false}
|
||||||
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
|
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
|
||||||
rehypePlugins={[rehypeKatex]}
|
rehypePlugins={[rehypeKatex]}
|
||||||
className="aui-md"
|
className="aui-md"
|
||||||
components={defaultComponents}
|
components={defaultComponents}
|
||||||
preprocess={preprocessMarkdown}
|
preprocess={preprocess}
|
||||||
/>
|
/>
|
||||||
|
</CitationUrlMapContext.Provider>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export const MarkdownText = memo(MarkdownTextImpl);
|
export const MarkdownText = memo(MarkdownTextImpl);
|
||||||
|
|
||||||
/**
|
|
||||||
* Helper to process children and replace citation patterns with components
|
|
||||||
*/
|
|
||||||
function processChildrenWithCitations(children: ReactNode): ReactNode {
|
|
||||||
if (typeof children === "string") {
|
|
||||||
const parsed = parseTextWithCitations(children);
|
|
||||||
return parsed.length === 1 && typeof parsed[0] === "string" ? children : parsed;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Array.isArray(children)) {
|
|
||||||
return children.map((child) => {
|
|
||||||
if (typeof child === "string") {
|
|
||||||
const parsed = parseTextWithCitations(child);
|
|
||||||
return parsed.length === 1 && typeof parsed[0] === "string" ? (
|
|
||||||
child
|
|
||||||
) : (
|
|
||||||
<span key={child}>{parsed}</span>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return child;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return children;
|
|
||||||
}
|
|
||||||
|
|
||||||
function extractDomain(url: string): string {
|
function extractDomain(url: string): string {
|
||||||
try {
|
try {
|
||||||
const parsed = new URL(url);
|
const parsed = new URL(url);
|
||||||
|
|
@ -282,6 +196,85 @@ function isVirtualFilePathToken(value: string): boolean {
|
||||||
return segments.length >= 2;
|
return segments.length >= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function isStandaloneDocumentsPathText(node: ReactNode): string | null {
|
||||||
|
if (typeof node !== "string") return null;
|
||||||
|
const value = node.trim();
|
||||||
|
if (!value.startsWith("/documents/")) return null;
|
||||||
|
if (value.includes(" ")) return null;
|
||||||
|
const normalized = value.replace(/\/+$/, "");
|
||||||
|
const leaf = normalized.split("/").filter(Boolean).at(-1) ?? "";
|
||||||
|
if (!leaf || !leaf.includes(".")) return null;
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
function FilePathLink({ path, className }: { path: string; className?: string }) {
|
||||||
|
const openEditorPanel = useSetAtom(openEditorPanelAtom);
|
||||||
|
const params = useParams();
|
||||||
|
const electronAPI = useElectronAPI();
|
||||||
|
const searchSpaceIdParam = params?.search_space_id;
|
||||||
|
const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam)
|
||||||
|
? Number(searchSpaceIdParam[0])
|
||||||
|
: Number(searchSpaceIdParam);
|
||||||
|
const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId)
|
||||||
|
? parsedSearchSpaceId
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
className={cn(
|
||||||
|
"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80",
|
||||||
|
className
|
||||||
|
)}
|
||||||
|
onClick={(event) => {
|
||||||
|
event.preventDefault();
|
||||||
|
event.stopPropagation();
|
||||||
|
void (async () => {
|
||||||
|
if (electronAPI) {
|
||||||
|
let resolvedLocalPath = path;
|
||||||
|
if (electronAPI.getAgentFilesystemMounts) {
|
||||||
|
try {
|
||||||
|
const mounts = (await electronAPI.getAgentFilesystemMounts(
|
||||||
|
resolvedSearchSpaceId
|
||||||
|
)) as AgentFilesystemMount[];
|
||||||
|
resolvedLocalPath = normalizeLocalVirtualPathForEditor(path, mounts);
|
||||||
|
} catch {
|
||||||
|
// Fall back to the raw path if mount lookup fails.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
openEditorPanel({
|
||||||
|
kind: "local_file",
|
||||||
|
localFilePath: resolvedLocalPath,
|
||||||
|
title: resolvedLocalPath.split("/").pop() || resolvedLocalPath,
|
||||||
|
searchSpaceId: resolvedSearchSpaceId,
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!resolvedSearchSpaceId || !path.startsWith("/documents/")) return;
|
||||||
|
try {
|
||||||
|
const doc = await documentsApiService.getDocumentByVirtualPath({
|
||||||
|
search_space_id: resolvedSearchSpaceId,
|
||||||
|
virtual_path: path,
|
||||||
|
});
|
||||||
|
openEditorPanel({
|
||||||
|
kind: "document",
|
||||||
|
documentId: doc.id,
|
||||||
|
searchSpaceId: resolvedSearchSpaceId,
|
||||||
|
title: doc.title,
|
||||||
|
});
|
||||||
|
} catch {
|
||||||
|
toast.error("Document not found in knowledge base.");
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
}}
|
||||||
|
title="Open in editor panel"
|
||||||
|
>
|
||||||
|
{path}
|
||||||
|
</button>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function MarkdownImage({ src, alt }: { src?: string; alt?: string }) {
|
function MarkdownImage({ src, alt }: { src?: string; alt?: string }) {
|
||||||
if (!src) return null;
|
if (!src) return null;
|
||||||
|
|
||||||
|
|
@ -322,7 +315,9 @@ function MarkdownImage({ src, alt }: { src?: string; alt?: string }) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultComponents = memoizeMarkdownComponents({
|
const defaultComponents = memoizeMarkdownComponents({
|
||||||
h1: ({ className, children, ...props }) => (
|
h1: function H1({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<h1
|
<h1
|
||||||
className={cn(
|
className={cn(
|
||||||
"aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0",
|
"aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0",
|
||||||
|
|
@ -330,10 +325,13 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
)}
|
)}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</h1>
|
</h1>
|
||||||
),
|
);
|
||||||
h2: ({ className, children, ...props }) => (
|
},
|
||||||
|
h2: function H2({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<h2
|
<h2
|
||||||
className={cn(
|
className={cn(
|
||||||
"aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0",
|
"aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0",
|
||||||
|
|
@ -341,10 +339,13 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
)}
|
)}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</h2>
|
</h2>
|
||||||
),
|
);
|
||||||
h3: ({ className, children, ...props }) => (
|
},
|
||||||
|
h3: function H3({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<h3
|
<h3
|
||||||
className={cn(
|
className={cn(
|
||||||
"aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0",
|
"aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0",
|
||||||
|
|
@ -352,10 +353,13 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
)}
|
)}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</h3>
|
</h3>
|
||||||
),
|
);
|
||||||
h4: ({ className, children, ...props }) => (
|
},
|
||||||
|
h4: function H4({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<h4
|
<h4
|
||||||
className={cn(
|
className={cn(
|
||||||
"aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0",
|
"aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0",
|
||||||
|
|
@ -363,51 +367,75 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
)}
|
)}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</h4>
|
</h4>
|
||||||
),
|
);
|
||||||
h5: ({ className, children, ...props }) => (
|
},
|
||||||
|
h5: function H5({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<h5
|
<h5
|
||||||
className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)}
|
className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</h5>
|
</h5>
|
||||||
),
|
);
|
||||||
h6: ({ className, children, ...props }) => (
|
},
|
||||||
|
h6: function H6({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}>
|
<h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</h6>
|
</h6>
|
||||||
),
|
);
|
||||||
p: ({ className, children, ...props }) => (
|
},
|
||||||
|
p: function P({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
const standalonePath = isStandaloneDocumentsPathText(children);
|
||||||
|
return (
|
||||||
<p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}>
|
<p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}>
|
||||||
{processChildrenWithCitations(children)}
|
{standalonePath ? (
|
||||||
|
<FilePathLink path={standalonePath} />
|
||||||
|
) : (
|
||||||
|
processChildrenWithCitations(children, urlMap)
|
||||||
|
)}
|
||||||
</p>
|
</p>
|
||||||
),
|
);
|
||||||
a: ({ className, children, ...props }) => (
|
},
|
||||||
|
a: function A({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<a
|
<a
|
||||||
className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)}
|
className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</a>
|
</a>
|
||||||
),
|
);
|
||||||
blockquote: ({ className, children, ...props }) => (
|
},
|
||||||
|
blockquote: function Blockquote({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}>
|
<blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</blockquote>
|
</blockquote>
|
||||||
),
|
);
|
||||||
|
},
|
||||||
ul: ({ className, ...props }) => (
|
ul: ({ className, ...props }) => (
|
||||||
<ul className={cn("aui-md-ul my-5 ml-6 list-disc [&>li]:mt-2", className)} {...props} />
|
<ul className={cn("aui-md-ul my-5 ml-6 list-disc [&>li]:mt-2", className)} {...props} />
|
||||||
),
|
),
|
||||||
ol: ({ className, ...props }) => (
|
ol: ({ className, ...props }) => (
|
||||||
<ol className={cn("aui-md-ol my-5 ml-6 list-decimal [&>li]:mt-2", className)} {...props} />
|
<ol className={cn("aui-md-ol my-5 ml-6 list-decimal [&>li]:mt-2", className)} {...props} />
|
||||||
),
|
),
|
||||||
li: ({ className, children, ...props }) => (
|
li: function Li({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<li className={cn("aui-md-li", className)} {...props}>
|
<li className={cn("aui-md-li", className)} {...props}>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</li>
|
</li>
|
||||||
),
|
);
|
||||||
|
},
|
||||||
hr: ({ className, ...props }) => (
|
hr: ({ className, ...props }) => (
|
||||||
<hr className={cn("aui-md-hr my-5 border-b", className)} {...props} />
|
<hr className={cn("aui-md-hr my-5 border-b", className)} {...props} />
|
||||||
),
|
),
|
||||||
|
|
@ -422,7 +450,9 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
tbody: ({ className, ...props }) => (
|
tbody: ({ className, ...props }) => (
|
||||||
<TableBody className={cn("aui-md-tbody", className)} {...props} />
|
<TableBody className={cn("aui-md-tbody", className)} {...props} />
|
||||||
),
|
),
|
||||||
th: ({ className, children, ...props }) => (
|
th: function Th({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<TableHead
|
<TableHead
|
||||||
className={cn(
|
className={cn(
|
||||||
"aui-md-th bg-muted/50 whitespace-normal [[align=center]]:text-center [[align=right]]:text-right",
|
"aui-md-th bg-muted/50 whitespace-normal [[align=center]]:text-center [[align=right]]:text-right",
|
||||||
|
|
@ -430,10 +460,13 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
)}
|
)}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</TableHead>
|
</TableHead>
|
||||||
),
|
);
|
||||||
td: ({ className, children, ...props }) => (
|
},
|
||||||
|
td: function Td({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<TableCell
|
<TableCell
|
||||||
className={cn(
|
className={cn(
|
||||||
"aui-md-td whitespace-normal [[align=center]]:text-center [[align=right]]:text-right",
|
"aui-md-td whitespace-normal [[align=center]]:text-center [[align=right]]:text-right",
|
||||||
|
|
@ -441,9 +474,10 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
)}
|
)}
|
||||||
{...props}
|
{...props}
|
||||||
>
|
>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</TableCell>
|
</TableCell>
|
||||||
),
|
);
|
||||||
|
},
|
||||||
tr: ({ className, ...props }) => <TableRow className={cn("aui-md-tr", className)} {...props} />,
|
tr: ({ className, ...props }) => <TableRow className={cn("aui-md-tr", className)} {...props} />,
|
||||||
sup: ({ className, ...props }) => (
|
sup: ({ className, ...props }) => (
|
||||||
<sup className={cn("aui-md-sup [&>a]:text-xs [&>a]:no-underline", className)} {...props} />
|
<sup className={cn("aui-md-sup [&>a]:text-xs [&>a]:no-underline", className)} {...props} />
|
||||||
|
|
@ -452,8 +486,6 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
code: function Code({ className, children, ...props }) {
|
code: function Code({ className, children, ...props }) {
|
||||||
const isCodeBlock = useIsMarkdownCodeBlock();
|
const isCodeBlock = useIsMarkdownCodeBlock();
|
||||||
const { resolvedTheme } = useTheme();
|
const { resolvedTheme } = useTheme();
|
||||||
const openEditorPanel = useSetAtom(openEditorPanelAtom);
|
|
||||||
const params = useParams();
|
|
||||||
const electronAPI = useElectronAPI();
|
const electronAPI = useElectronAPI();
|
||||||
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
|
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
|
||||||
const codeString = String(children).replace(/\n$/, "");
|
const codeString = String(children).replace(/\n$/, "");
|
||||||
|
|
@ -470,53 +502,17 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
const isLikelyFolder =
|
const isLikelyFolder =
|
||||||
inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes(".");
|
inlineValue.endsWith("/") || !leafSegment || !leafSegment.includes(".");
|
||||||
const isLocalPath =
|
const isLocalPath =
|
||||||
!!electronAPI &&
|
(isVirtualFilePathToken(inlineValue) &&
|
||||||
isVirtualFilePathToken(inlineValue) &&
|
|
||||||
!inlineValue.startsWith("//") &&
|
!inlineValue.startsWith("//") &&
|
||||||
!isLikelyFolder;
|
!isLikelyFolder &&
|
||||||
const displayLocalPath = inlineValue.replace(/^\/+/, "");
|
!!electronAPI) ||
|
||||||
const searchSpaceIdParam = params?.search_space_id;
|
(isVirtualFilePathToken(inlineValue) &&
|
||||||
const parsedSearchSpaceId = Array.isArray(searchSpaceIdParam)
|
!inlineValue.startsWith("//") &&
|
||||||
? Number(searchSpaceIdParam[0])
|
!isLikelyFolder &&
|
||||||
: Number(searchSpaceIdParam);
|
!electronAPI &&
|
||||||
|
inlineValue.startsWith("/documents/"));
|
||||||
if (isLocalPath) {
|
if (isLocalPath) {
|
||||||
return (
|
return <FilePathLink path={inlineValue} className="text-[0.9em]" />;
|
||||||
<button
|
|
||||||
type="button"
|
|
||||||
className={cn(
|
|
||||||
"cursor-pointer font-mono text-[0.9em] font-medium text-primary underline underline-offset-4 transition-colors hover:text-primary/80"
|
|
||||||
)}
|
|
||||||
onClick={(event) => {
|
|
||||||
event.preventDefault();
|
|
||||||
event.stopPropagation();
|
|
||||||
void (async () => {
|
|
||||||
let resolvedLocalPath = inlineValue;
|
|
||||||
const resolvedSearchSpaceId = Number.isFinite(parsedSearchSpaceId)
|
|
||||||
? parsedSearchSpaceId
|
|
||||||
: undefined;
|
|
||||||
if (electronAPI?.getAgentFilesystemMounts) {
|
|
||||||
try {
|
|
||||||
const mounts = (await electronAPI.getAgentFilesystemMounts(
|
|
||||||
resolvedSearchSpaceId
|
|
||||||
)) as AgentFilesystemMount[];
|
|
||||||
resolvedLocalPath = normalizeLocalVirtualPathForEditor(inlineValue, mounts);
|
|
||||||
} catch {
|
|
||||||
// Fall back to the raw inline path if mount lookup fails.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
openEditorPanel({
|
|
||||||
kind: "local_file",
|
|
||||||
localFilePath: resolvedLocalPath,
|
|
||||||
title: resolvedLocalPath.split("/").pop() || resolvedLocalPath,
|
|
||||||
searchSpaceId: resolvedSearchSpaceId,
|
|
||||||
});
|
|
||||||
})();
|
|
||||||
}}
|
|
||||||
title="Open in editor panel"
|
|
||||||
>
|
|
||||||
{displayLocalPath}
|
|
||||||
</button>
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<code
|
<code
|
||||||
|
|
@ -552,16 +548,22 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
strong: ({ className, children, ...props }) => (
|
strong: function Strong({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<strong className={cn("aui-md-strong font-semibold", className)} {...props}>
|
<strong className={cn("aui-md-strong font-semibold", className)} {...props}>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</strong>
|
</strong>
|
||||||
),
|
);
|
||||||
em: ({ className, children, ...props }) => (
|
},
|
||||||
|
em: function Em({ className, children, ...props }) {
|
||||||
|
const urlMap = useCitationUrlMap();
|
||||||
|
return (
|
||||||
<em className={cn("aui-md-em", className)} {...props}>
|
<em className={cn("aui-md-em", className)} {...props}>
|
||||||
{processChildrenWithCitations(children)}
|
{processChildrenWithCitations(children, urlMap)}
|
||||||
</em>
|
</em>
|
||||||
),
|
);
|
||||||
|
},
|
||||||
img: ({ src, alt }) => (
|
img: ({ src, alt }) => (
|
||||||
<MarkdownImage src={typeof src === "string" ? src : undefined} alt={alt} />
|
<MarkdownImage src={typeof src === "string" ? src : undefined} alt={alt} />
|
||||||
),
|
),
|
||||||
|
|
|
||||||
24
surfsense_web/components/assistant-ui/nested-scroll.tsx
Normal file
24
surfsense_web/components/assistant-ui/nested-scroll.tsx
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { type ComponentPropsWithoutRef, forwardRef, type WheelEvent } from "react";
|
||||||
|
|
||||||
|
export type NestedScrollProps = ComponentPropsWithoutRef<"div">;
|
||||||
|
|
||||||
|
export const NestedScroll = forwardRef<HTMLDivElement, NestedScrollProps>(
|
||||||
|
({ onWheel, ...props }, ref) => {
|
||||||
|
const handleWheel = (event: WheelEvent<HTMLDivElement>) => {
|
||||||
|
const el = event.currentTarget;
|
||||||
|
const canScrollUp = el.scrollTop > 0;
|
||||||
|
const canScrollDown = el.scrollTop < el.scrollHeight - el.clientHeight - 1;
|
||||||
|
const goingUp = event.deltaY < 0;
|
||||||
|
const goingDown = event.deltaY > 0;
|
||||||
|
if ((goingUp && canScrollUp) || (goingDown && canScrollDown)) {
|
||||||
|
event.stopPropagation();
|
||||||
|
}
|
||||||
|
onWheel?.(event);
|
||||||
|
};
|
||||||
|
return <div ref={ref} onWheel={handleWheel} {...props} />;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
NestedScroll.displayName = "NestedScroll";
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import type { ReasoningMessagePartComponent } from "@assistant-ui/react";
|
||||||
|
import { ChevronRightIcon } from "lucide-react";
|
||||||
|
import { useEffect, useMemo, useState } from "react";
|
||||||
|
import { TextShimmerLoader } from "@/components/prompt-kit/loader";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Renders the structured `reasoning` part emitted by the backend's
|
||||||
|
* stream-parity v2 path (A1).
|
||||||
|
*
|
||||||
|
* Behaviour mirrors the existing `ThinkingStepsDisplay`:
|
||||||
|
* - collapsed by default;
|
||||||
|
* - auto-expanded while the part is still `running`;
|
||||||
|
* - auto-collapsed once status flips to `complete`.
|
||||||
|
*
|
||||||
|
* The component is registered via the `Reasoning` slot on
|
||||||
|
* `MessagePrimitive.Parts` in `assistant-message.tsx` so it lives at the
|
||||||
|
* exact ordinal position of the reasoning block in the message content
|
||||||
|
* array (i.e. above the assistant text that follows it).
|
||||||
|
*/
|
||||||
|
export const ReasoningMessagePart: ReasoningMessagePartComponent = ({ text, status }) => {
|
||||||
|
const isRunning = status?.type === "running";
|
||||||
|
const [isOpen, setIsOpen] = useState(() => isRunning);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (isRunning) {
|
||||||
|
setIsOpen(true);
|
||||||
|
} else if (status?.type === "complete") {
|
||||||
|
setIsOpen(false);
|
||||||
|
}
|
||||||
|
}, [isRunning, status?.type]);
|
||||||
|
|
||||||
|
const headerLabel = useMemo(() => {
|
||||||
|
if (isRunning) return "Thinking";
|
||||||
|
if (status?.type === "incomplete") return "Thinking interrupted";
|
||||||
|
return "Thought";
|
||||||
|
}, [isRunning, status?.type]);
|
||||||
|
|
||||||
|
if (!text || text.length === 0) {
|
||||||
|
if (!isRunning) return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2">
|
||||||
|
<div className="rounded-lg">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setIsOpen((prev) => !prev)}
|
||||||
|
className={cn(
|
||||||
|
"flex w-full items-center gap-1.5 text-left text-sm transition-colors",
|
||||||
|
"text-muted-foreground hover:text-foreground"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{isRunning ? (
|
||||||
|
<TextShimmerLoader text={headerLabel} size="sm" />
|
||||||
|
) : (
|
||||||
|
<span>{headerLabel}</span>
|
||||||
|
)}
|
||||||
|
<ChevronRightIcon
|
||||||
|
className={cn("size-4 transition-transform duration-200", isOpen && "rotate-90")}
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"grid transition-[grid-template-rows] duration-300 ease-out",
|
||||||
|
isOpen ? "grid-rows-[1fr]" : "grid-rows-[0fr]"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="overflow-hidden">
|
||||||
|
<div className="mt-2 border-l border-muted-foreground/30 pl-3 text-sm leading-relaxed text-muted-foreground whitespace-pre-wrap wrap-break-word">
|
||||||
|
{text}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
213
surfsense_web/components/assistant-ui/revert-turn-button.tsx
Normal file
213
surfsense_web/components/assistant-ui/revert-turn-button.tsx
Normal file
|
|
@ -0,0 +1,213 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* "Revert turn" button rendered at the bottom of every completed
|
||||||
|
* assistant turn that has at least one reversible action.
|
||||||
|
*
|
||||||
|
* The button reads from the unified ``useAgentActionsQuery`` cache
|
||||||
|
* (the SAME react-query cache the agent-actions sheet and the inline
|
||||||
|
* Revert button consume) filtered by ``chat_turn_id``. It shows a
|
||||||
|
* confirmation dialog summarising "N reversible / M total" and, on
|
||||||
|
* confirm, calls ``POST /threads/{id}/revert-turn/{chat_turn_id}``.
|
||||||
|
*
|
||||||
|
* The route returns a per-action result list and never collapses the
|
||||||
|
* batch into a 4xx — so we render any failed/not_reversible rows inline
|
||||||
|
* with their messages.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
|
import { useAtomValue } from "jotai";
|
||||||
|
import { CheckIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
||||||
|
import { useMemo, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogAction,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
AlertDialogTrigger,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
|
||||||
|
import {
|
||||||
|
applyRevertTurnResultsToCache,
|
||||||
|
useAgentActionsQuery,
|
||||||
|
} from "@/hooks/use-agent-actions-query";
|
||||||
|
import {
|
||||||
|
agentActionsApiService,
|
||||||
|
type RevertTurnActionResult,
|
||||||
|
} from "@/lib/apis/agent-actions-api.service";
|
||||||
|
import { AppError } from "@/lib/error";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
interface RevertTurnButtonProps {
|
||||||
|
chatTurnId: string | null | undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function RevertTurnButton({ chatTurnId }: RevertTurnButtonProps) {
|
||||||
|
const session = useAtomValue(chatSessionStateAtom);
|
||||||
|
const threadId = session?.threadId ?? null;
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
const { findByChatTurnId } = useAgentActionsQuery(threadId);
|
||||||
|
const [isReverting, setIsReverting] = useState(false);
|
||||||
|
const [confirmOpen, setConfirmOpen] = useState(false);
|
||||||
|
const [resultsOpen, setResultsOpen] = useState(false);
|
||||||
|
const [results, setResults] = useState<RevertTurnActionResult[]>([]);
|
||||||
|
|
||||||
|
const actions = useMemo(() => findByChatTurnId(chatTurnId), [findByChatTurnId, chatTurnId]);
|
||||||
|
|
||||||
|
const reversibleCount = useMemo(
|
||||||
|
() =>
|
||||||
|
actions.filter(
|
||||||
|
(a) =>
|
||||||
|
a.reversible &&
|
||||||
|
(a.reverted_by_action_id === null || a.reverted_by_action_id === undefined) &&
|
||||||
|
!a.is_revert_action &&
|
||||||
|
(a.error === null || a.error === undefined)
|
||||||
|
).length,
|
||||||
|
[actions]
|
||||||
|
);
|
||||||
|
const totalCount = useMemo(() => actions.filter((a) => !a.is_revert_action).length, [actions]);
|
||||||
|
|
||||||
|
if (!chatTurnId) return null;
|
||||||
|
if (reversibleCount === 0) return null;
|
||||||
|
if (!threadId) return null;
|
||||||
|
|
||||||
|
const handleRevertTurn = async () => {
|
||||||
|
setIsReverting(true);
|
||||||
|
try {
|
||||||
|
const response = await agentActionsApiService.revertTurn(threadId, chatTurnId);
|
||||||
|
setResults(response.results);
|
||||||
|
const revertedEntries = response.results
|
||||||
|
.filter((r) => r.status === "reverted" || r.status === "already_reverted")
|
||||||
|
.map((r) => ({ id: r.action_id, newActionId: r.new_action_id ?? null }));
|
||||||
|
if (revertedEntries.length > 0) {
|
||||||
|
applyRevertTurnResultsToCache(queryClient, threadId, revertedEntries);
|
||||||
|
}
|
||||||
|
if (response.status === "ok") {
|
||||||
|
toast.success(
|
||||||
|
response.reverted === 1 ? "Reverted 1 action." : `Reverted ${response.reverted} actions.`
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Every "not undone" bucket counts as a failure for the
|
||||||
|
// user-facing summary. ``skipped`` rows are batch
|
||||||
|
// artefacts (revert rows themselves) and intentionally
|
||||||
|
// excluded from the failure tally.
|
||||||
|
const failureCount =
|
||||||
|
response.failed + response.not_reversible + (response.permission_denied ?? 0);
|
||||||
|
toast.warning(
|
||||||
|
`Reverted ${response.reverted} of ${response.total}. ${failureCount} could not be undone.`
|
||||||
|
);
|
||||||
|
setResultsOpen(true);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof AppError && err.status === 503) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const message =
|
||||||
|
err instanceof AppError
|
||||||
|
? err.message
|
||||||
|
: err instanceof Error
|
||||||
|
? err.message
|
||||||
|
: "Failed to revert turn.";
|
||||||
|
toast.error(message);
|
||||||
|
} finally {
|
||||||
|
setIsReverting(false);
|
||||||
|
setConfirmOpen(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
|
||||||
|
<AlertDialogTrigger asChild>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="ghost"
|
||||||
|
className="text-muted-foreground hover:text-foreground gap-1.5"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
setConfirmOpen(true);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<RotateCcw className="size-3.5" />
|
||||||
|
<span>Revert turn</span>
|
||||||
|
<span className="text-xs tabular-nums opacity-70">
|
||||||
|
{reversibleCount}/{totalCount}
|
||||||
|
</span>
|
||||||
|
</Button>
|
||||||
|
</AlertDialogTrigger>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader>
|
||||||
|
<AlertDialogTitle>Revert this turn?</AlertDialogTitle>
|
||||||
|
<AlertDialogDescription>
|
||||||
|
This will undo {reversibleCount} of {totalCount} action
|
||||||
|
{totalCount === 1 ? "" : "s"} from this turn in reverse order. The chat history and
|
||||||
|
any read-only actions are preserved. Some rows may not be reversible — partial success
|
||||||
|
is normal.
|
||||||
|
</AlertDialogDescription>
|
||||||
|
</AlertDialogHeader>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
|
||||||
|
<AlertDialogAction
|
||||||
|
onClick={(e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
handleRevertTurn();
|
||||||
|
}}
|
||||||
|
disabled={isReverting}
|
||||||
|
>
|
||||||
|
{isReverting ? "Reverting…" : "Revert turn"}
|
||||||
|
</AlertDialogAction>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialog>
|
||||||
|
|
||||||
|
<AlertDialog open={resultsOpen} onOpenChange={setResultsOpen}>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader>
|
||||||
|
<AlertDialogTitle>Revert results</AlertDialogTitle>
|
||||||
|
<AlertDialogDescription>
|
||||||
|
Some actions could not be reverted. Review per-row outcomes below.
|
||||||
|
</AlertDialogDescription>
|
||||||
|
</AlertDialogHeader>
|
||||||
|
<ul className="max-h-72 overflow-y-auto space-y-2 text-sm">
|
||||||
|
{results.map((r) => (
|
||||||
|
<RevertResultRow key={r.action_id} result={r} />
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<AlertDialogAction onClick={() => setResultsOpen(false)}>Close</AlertDialogAction>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialog>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function RevertResultRow({ result }: { result: RevertTurnActionResult }) {
|
||||||
|
const isOk = result.status === "reverted" || result.status === "already_reverted";
|
||||||
|
const Icon = isOk ? CheckIcon : XCircleIcon;
|
||||||
|
return (
|
||||||
|
<li className="flex items-start gap-2 rounded-md border bg-muted/30 px-3 py-2">
|
||||||
|
<Icon
|
||||||
|
className={cn("size-4 mt-0.5 shrink-0", isOk ? "text-emerald-500" : "text-destructive")}
|
||||||
|
/>
|
||||||
|
<div className="min-w-0 flex-1">
|
||||||
|
<p className="font-medium truncate">
|
||||||
|
{getToolDisplayName(result.tool_name)}{" "}
|
||||||
|
<span className="ml-1 text-xs text-muted-foreground">
|
||||||
|
{result.status.replace(/_/g, " ")}
|
||||||
|
</span>
|
||||||
|
</p>
|
||||||
|
{(result.message || result.error) && (
|
||||||
|
<p className="text-xs text-muted-foreground mt-0.5">{result.error ?? result.message}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</li>
|
||||||
|
);
|
||||||
|
}
|
||||||
27
surfsense_web/components/assistant-ui/step-separator.tsx
Normal file
27
surfsense_web/components/assistant-ui/step-separator.tsx
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { makeAssistantDataUI } from "@assistant-ui/react";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Renders a thin horizontal divider between model steps within a single
|
||||||
|
* assistant turn. The data part is pushed by `addStepSeparator` in
|
||||||
|
* `streaming-state.ts` whenever a `start-step` SSE event arrives after
|
||||||
|
* the message already has non-step content.
|
||||||
|
*
|
||||||
|
* Today the backend emits one `start-step` / `finish-step` pair per turn,
|
||||||
|
* so most messages won't contain a separator. The renderer is wired up so
|
||||||
|
* the planned per-model-step refactor (A2 follow-up) can light up without
|
||||||
|
* touching the persistence path.
|
||||||
|
*/
|
||||||
|
function StepSeparatorDataRenderer() {
|
||||||
|
return (
|
||||||
|
<div className="mx-auto my-3 w-full max-w-(--thread-max-width) px-2">
|
||||||
|
<div className="border-t border-border/60" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export const StepSeparatorDataUI = makeAssistantDataUI({
|
||||||
|
name: "step-separator",
|
||||||
|
render: StepSeparatorDataRenderer,
|
||||||
|
});
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
import { ThreadPrimitive } from "@assistant-ui/react";
|
|
||||||
import { ArrowDownIcon } from "lucide-react";
|
|
||||||
import type { FC } from "react";
|
|
||||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
|
||||||
|
|
||||||
export const ThreadScrollToBottom: FC = () => {
|
|
||||||
return (
|
|
||||||
<ThreadPrimitive.ScrollToBottom asChild>
|
|
||||||
<TooltipIconButton
|
|
||||||
tooltip="Scroll to bottom"
|
|
||||||
variant="outline"
|
|
||||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
|
||||||
>
|
|
||||||
<ArrowDownIcon />
|
|
||||||
</TooltipIconButton>
|
|
||||||
</ThreadPrimitive.ScrollToBottom>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
@ -5,12 +5,10 @@ import {
|
||||||
ThreadPrimitive,
|
ThreadPrimitive,
|
||||||
useAui,
|
useAui,
|
||||||
useAuiState,
|
useAuiState,
|
||||||
useThreadViewportStore,
|
|
||||||
} from "@assistant-ui/react";
|
} from "@assistant-ui/react";
|
||||||
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
||||||
import {
|
import {
|
||||||
AlertCircle,
|
AlertCircle,
|
||||||
ArrowDownIcon,
|
|
||||||
ArrowUpIcon,
|
ArrowUpIcon,
|
||||||
Camera,
|
Camera,
|
||||||
ChevronDown,
|
ChevronDown,
|
||||||
|
|
@ -37,10 +35,13 @@ import {
|
||||||
toggleToolAtom,
|
toggleToolAtom,
|
||||||
} from "@/atoms/agent-tools/agent-tools.atoms";
|
} from "@/atoms/agent-tools/agent-tools.atoms";
|
||||||
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||||
import {
|
import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
||||||
mentionedDocumentsAtom,
|
import { mentionedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||||
} from "@/atoms/chat/mentioned-documents.atom";
|
|
||||||
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
|
import { pendingUserImageDataUrlsAtom } from "@/atoms/chat/pending-user-images.atom";
|
||||||
|
import {
|
||||||
|
clearPremiumAlertForThreadAtom,
|
||||||
|
premiumAlertByThreadAtom,
|
||||||
|
} from "@/atoms/chat/premium-alert.atom";
|
||||||
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
|
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
|
||||||
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
||||||
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
||||||
|
|
@ -52,6 +53,7 @@ import {
|
||||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||||
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
import { AssistantMessage } from "@/components/assistant-ui/assistant-message";
|
||||||
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
|
import { ChatSessionStatus } from "@/components/assistant-ui/chat-session-status";
|
||||||
|
import { ChatViewport } from "@/components/assistant-ui/chat-viewport";
|
||||||
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
|
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
|
||||||
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
|
import { useDocumentUploadDialog } from "@/components/assistant-ui/document-upload-popup";
|
||||||
import {
|
import {
|
||||||
|
|
@ -82,6 +84,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||||
import {
|
import {
|
||||||
CONNECTOR_ICON_TO_TYPES,
|
CONNECTOR_ICON_TO_TYPES,
|
||||||
CONNECTOR_TOOL_ICON_PATHS,
|
CONNECTOR_TOOL_ICON_PATHS,
|
||||||
|
getToolDisplayName,
|
||||||
getToolIcon,
|
getToolIcon,
|
||||||
} from "@/contracts/enums/toolIcons";
|
} from "@/contracts/enums/toolIcons";
|
||||||
import type { Document } from "@/contracts/types/document.types";
|
import type { Document } from "@/contracts/types/document.types";
|
||||||
|
|
@ -89,8 +92,8 @@ import { useBatchCommentsPreload } from "@/hooks/use-comments";
|
||||||
import { useCommentsSync } from "@/hooks/use-comments-sync";
|
import { useCommentsSync } from "@/hooks/use-comments-sync";
|
||||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||||
import { useElectronAPI } from "@/hooks/use-platform";
|
import { useElectronAPI } from "@/hooks/use-platform";
|
||||||
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
|
||||||
import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture";
|
import { captureDisplayToPngDataUrl } from "@/lib/chat/display-media-capture";
|
||||||
|
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||||
import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events";
|
import { SLIDEOUT_PANEL_OPENED_EVENT } from "@/lib/layout-events";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
|
@ -108,10 +111,13 @@ const ThreadContent: FC = () => {
|
||||||
["--thread-max-width" as string]: "44rem",
|
["--thread-max-width" as string]: "44rem",
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ThreadPrimitive.Viewport
|
<ChatViewport
|
||||||
turnAnchor="top"
|
footer={
|
||||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
||||||
style={{ scrollbarGutter: "stable" }}
|
<PremiumQuotaPinnedAlert />
|
||||||
|
<Composer />
|
||||||
|
</AuiIf>
|
||||||
|
}
|
||||||
>
|
>
|
||||||
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
||||||
<ThreadWelcome />
|
<ThreadWelcome />
|
||||||
|
|
@ -124,36 +130,39 @@ const ThreadContent: FC = () => {
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
</ChatViewport>
|
||||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
|
||||||
<div className="grow" />
|
|
||||||
</AuiIf>
|
|
||||||
|
|
||||||
<ThreadPrimitive.ViewportFooter
|
|
||||||
className="aui-thread-viewport-footer sticky bottom-0 z-10 mx-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-2xl bg-main-panel pb-4 md:pb-6"
|
|
||||||
style={{ paddingBottom: "max(1rem, env(safe-area-inset-bottom))" }}
|
|
||||||
>
|
|
||||||
<ThreadScrollToBottom />
|
|
||||||
<AuiIf condition={({ thread }) => !thread.isEmpty}>
|
|
||||||
<Composer />
|
|
||||||
</AuiIf>
|
|
||||||
</ThreadPrimitive.ViewportFooter>
|
|
||||||
</ThreadPrimitive.Viewport>
|
|
||||||
</ThreadPrimitive.Root>
|
</ThreadPrimitive.Root>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const ThreadScrollToBottom: FC = () => {
|
const PremiumQuotaPinnedAlert: FC = () => {
|
||||||
|
const currentThreadState = useAtomValue(currentThreadAtom);
|
||||||
|
const alertsByThread = useAtomValue(premiumAlertByThreadAtom);
|
||||||
|
const clearPremiumAlertForThread = useSetAtom(clearPremiumAlertForThreadAtom);
|
||||||
|
|
||||||
|
const currentThreadId = currentThreadState?.id;
|
||||||
|
if (!currentThreadId) return null;
|
||||||
|
|
||||||
|
const alert = alertsByThread[currentThreadId];
|
||||||
|
if (!alert) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ThreadPrimitive.ScrollToBottom asChild>
|
<div className="mx-0 overflow-hidden rounded-2xl border-input bg-muted px-4 py-4 text-foreground select-none">
|
||||||
<TooltipIconButton
|
<div className="flex items-center gap-2">
|
||||||
tooltip="Scroll to bottom"
|
<AlertCircle className="size-4 shrink-0 text-muted-foreground" />
|
||||||
variant="outline"
|
<div className="min-w-0 flex-1">
|
||||||
className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-main-panel dark:hover:bg-accent"
|
<p className="text-sm">{alert.message}</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
className="inline-flex size-6 items-center justify-center text-muted-foreground transition-colors hover:text-foreground"
|
||||||
|
aria-label="Dismiss premium quota alert"
|
||||||
|
onClick={() => clearPremiumAlertForThread(currentThreadId)}
|
||||||
>
|
>
|
||||||
<ArrowDownIcon />
|
<X className="size-4" />
|
||||||
</TooltipIconButton>
|
</button>
|
||||||
</ThreadPrimitive.ScrollToBottom>
|
</div>
|
||||||
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -373,23 +382,9 @@ const Composer: FC = () => {
|
||||||
>(new Map());
|
>(new Map());
|
||||||
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
const documentPickerRef = useRef<DocumentMentionPickerRef>(null);
|
||||||
const promptPickerRef = useRef<PromptPickerRef>(null);
|
const promptPickerRef = useRef<PromptPickerRef>(null);
|
||||||
const viewportRef = useRef<Element | null>(null);
|
|
||||||
const { search_space_id, chat_id } = useParams();
|
const { search_space_id, chat_id } = useParams();
|
||||||
const aui = useAui();
|
const aui = useAui();
|
||||||
const threadViewportStore = useThreadViewportStore();
|
|
||||||
const hasAutoFocusedRef = useRef(false);
|
const hasAutoFocusedRef = useRef(false);
|
||||||
const submitCleanupRef = useRef<(() => void) | null>(null);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
return () => {
|
|
||||||
submitCleanupRef.current?.();
|
|
||||||
};
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
// Store viewport element reference on mount
|
|
||||||
useEffect(() => {
|
|
||||||
viewportRef.current = document.querySelector(".aui-thread-viewport");
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const electronAPI = useElectronAPI();
|
const electronAPI = useElectronAPI();
|
||||||
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
|
const [clipboardInitialText, setClipboardInitialText] = useState<string | undefined>();
|
||||||
|
|
@ -588,7 +583,6 @@ const Composer: FC = () => {
|
||||||
[showDocumentPopover, showPromptPicker]
|
[showDocumentPopover, showPromptPicker]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Submit message (blocked during streaming, document picker open, or AI responding to another user)
|
|
||||||
const handleSubmit = useCallback(() => {
|
const handleSubmit = useCallback(() => {
|
||||||
if (isThreadRunning || isBlockedByOtherUser) return;
|
if (isThreadRunning || isBlockedByOtherUser) return;
|
||||||
if (showDocumentPopover || showPromptPicker) return;
|
if (showDocumentPopover || showPromptPicker) return;
|
||||||
|
|
@ -600,50 +594,9 @@ const Composer: FC = () => {
|
||||||
setClipboardInitialText(undefined);
|
setClipboardInitialText(undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
const viewportEl = viewportRef.current;
|
|
||||||
const heightBefore = viewportEl?.scrollHeight ?? 0;
|
|
||||||
|
|
||||||
aui.composer().send();
|
aui.composer().send();
|
||||||
editorRef.current?.clear();
|
editorRef.current?.clear();
|
||||||
setMentionedDocuments([]);
|
setMentionedDocuments([]);
|
||||||
|
|
||||||
// With turnAnchor="top", ViewportSlack adds min-height to the last
|
|
||||||
// assistant message so that scrolling-to-bottom actually positions the
|
|
||||||
// user message at the TOP of the viewport. That slack height is
|
|
||||||
// calculated asynchronously (ResizeObserver → style → layout).
|
|
||||||
// Poll via rAF for ~500ms, re-scrolling whenever scrollHeight changes.
|
|
||||||
const scrollToBottom = () =>
|
|
||||||
threadViewportStore.getState().scrollToBottom({ behavior: "instant" });
|
|
||||||
|
|
||||||
let lastHeight = heightBefore;
|
|
||||||
let frames = 0;
|
|
||||||
let cancelled = false;
|
|
||||||
const POLL_FRAMES = 30;
|
|
||||||
|
|
||||||
const pollAndScroll = () => {
|
|
||||||
if (cancelled) return;
|
|
||||||
const el = viewportRef.current;
|
|
||||||
if (el) {
|
|
||||||
const h = el.scrollHeight;
|
|
||||||
if (h !== lastHeight) {
|
|
||||||
lastHeight = h;
|
|
||||||
scrollToBottom();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (++frames < POLL_FRAMES) {
|
|
||||||
requestAnimationFrame(pollAndScroll);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
requestAnimationFrame(pollAndScroll);
|
|
||||||
|
|
||||||
const t1 = setTimeout(scrollToBottom, 100);
|
|
||||||
const t2 = setTimeout(scrollToBottom, 300);
|
|
||||||
|
|
||||||
submitCleanupRef.current = () => {
|
|
||||||
cancelled = true;
|
|
||||||
clearTimeout(t1);
|
|
||||||
clearTimeout(t2);
|
|
||||||
};
|
|
||||||
}, [
|
}, [
|
||||||
showDocumentPopover,
|
showDocumentPopover,
|
||||||
showPromptPicker,
|
showPromptPicker,
|
||||||
|
|
@ -652,7 +605,6 @@ const Composer: FC = () => {
|
||||||
clipboardInitialText,
|
clipboardInitialText,
|
||||||
aui,
|
aui,
|
||||||
setMentionedDocuments,
|
setMentionedDocuments,
|
||||||
threadViewportStore,
|
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const handleDocumentRemove = useCallback(
|
const handleDocumentRemove = useCallback(
|
||||||
|
|
@ -1317,12 +1269,14 @@ const ComposerAction: FC<ComposerActionProps> = ({ isBlockedByOtherUser = false
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Convert snake_case tool names to human-readable labels */
|
/**
|
||||||
|
* Friendly tool name for display in the chat UI. Delegates to the
|
||||||
|
* shared map in ``contracts/enums/toolIcons`` so unix-style identifiers
|
||||||
|
* (``rm``, ``ls``, ``grep`` …) and snake_cased function names render as
|
||||||
|
* plain English (e.g. "Delete file", "List files", "Search in files").
|
||||||
|
*/
|
||||||
function formatToolName(name: string): string {
|
function formatToolName(name: string): string {
|
||||||
return name
|
return getToolDisplayName(name);
|
||||||
.split("_")
|
|
||||||
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
|
|
||||||
.join(" ");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface ToolGroup {
|
interface ToolGroup {
|
||||||
|
|
|
||||||
|
|
@ -1,30 +1,277 @@
|
||||||
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
|
import { type ToolCallMessagePartComponent, useAuiState } from "@assistant-ui/react";
|
||||||
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { useMemo, useState } from "react";
|
import { useAtomValue } from "jotai";
|
||||||
|
import { CheckIcon, ChevronDownIcon, RotateCcw, XCircleIcon } from "lucide-react";
|
||||||
|
import { useEffect, useMemo, useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
import { chatSessionStateAtom } from "@/atoms/chat/chat-session-state.atom";
|
||||||
|
import { NestedScroll } from "@/components/assistant-ui/nested-scroll";
|
||||||
import {
|
import {
|
||||||
DoomLoopApprovalToolUI,
|
DoomLoopApprovalToolUI,
|
||||||
isDoomLoopInterrupt,
|
isDoomLoopInterrupt,
|
||||||
} from "@/components/tool-ui/doom-loop-approval";
|
} from "@/components/tool-ui/doom-loop-approval";
|
||||||
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
import { GenericHitlApprovalToolUI } from "@/components/tool-ui/generic-hitl-approval";
|
||||||
import { getToolIcon } from "@/contracts/enums/toolIcons";
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogAction,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
AlertDialogTrigger,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Card } from "@/components/ui/card";
|
||||||
|
import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible";
|
||||||
|
import { Separator } from "@/components/ui/separator";
|
||||||
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
|
import { getToolDisplayName } from "@/contracts/enums/toolIcons";
|
||||||
|
import { markActionRevertedInCache, useAgentActionsQuery } from "@/hooks/use-agent-actions-query";
|
||||||
|
import { agentActionsApiService } from "@/lib/apis/agent-actions-api.service";
|
||||||
|
import { AppError } from "@/lib/error";
|
||||||
import { isInterruptResult } from "@/lib/hitl";
|
import { isInterruptResult } from "@/lib/hitl";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
function formatToolName(name: string): string {
|
/**
|
||||||
return name.replace(/_/g, " ").replace(/\b\w/g, (c) => c.toUpperCase());
|
* Inline Revert button rendered on a tool card when the matching
|
||||||
|
* ``AgentActionLog`` row is reversible and hasn't been reverted yet.
|
||||||
|
*
|
||||||
|
* Reads from the unified ``useAgentActionsQuery`` cache — the SAME
|
||||||
|
* react-query cache the agent-actions sheet consumes. SSE events
|
||||||
|
* (``data-action-log`` / ``data-action-log-updated``) and
|
||||||
|
* ``POST /threads/{id}/revert/{id}`` responses both flow through the
|
||||||
|
* cache via ``setQueryData`` helpers, so the card and the sheet stay
|
||||||
|
* in lockstep on every code path: page reload, navigation, live
|
||||||
|
* stream, post-stream reversibility flip, and explicit revert clicks.
|
||||||
|
*
|
||||||
|
* Match key (in priority order):
|
||||||
|
* 1. ``a.tool_call_id === toolCallId`` — direct hit in parity_v2 when
|
||||||
|
* the model streamed ``tool_call_chunks`` so the card's synthetic
|
||||||
|
* id IS the LangChain id.
|
||||||
|
* 2. ``a.tool_call_id === langchainToolCallId`` — legacy mode (or
|
||||||
|
* parity_v2 with provider-side chunk emission) where the card's
|
||||||
|
* synthetic id is ``call_<run_id>`` and the LangChain id is
|
||||||
|
* backfilled onto the part by ``tool-output-available``.
|
||||||
|
* 3. ``(chat_turn_id, tool_name, position-within-turn)`` — fallback
|
||||||
|
* for cards whose synthetic id is ``call_<run_id>`` AND whose
|
||||||
|
* ``langchainToolCallId`` never got backfilled (provider emitted
|
||||||
|
* the tool_call as a single payload with no chunks AND streaming
|
||||||
|
* pre-dated the ``tool-output-available langchainToolCallId``
|
||||||
|
* backfill, e.g. older threads). Reads the parent message's
|
||||||
|
* ``chatTurnId`` and ``content`` via ``useAuiState`` so we can
|
||||||
|
* match position-by-tool-name within the turn against the
|
||||||
|
* action_log rows the server returned in ``created_at`` order.
|
||||||
|
*/
|
||||||
|
function ToolCardRevertButton({
|
||||||
|
toolCallId,
|
||||||
|
toolName,
|
||||||
|
langchainToolCallId,
|
||||||
|
}: {
|
||||||
|
toolCallId: string;
|
||||||
|
toolName: string;
|
||||||
|
langchainToolCallId?: string;
|
||||||
|
}) {
|
||||||
|
const session = useAtomValue(chatSessionStateAtom);
|
||||||
|
const threadId = session?.threadId ?? null;
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
const { findByToolCallId, findByChatTurnAndTool } = useAgentActionsQuery(threadId);
|
||||||
|
|
||||||
|
// Parent message metadata, read via the narrowest possible
|
||||||
|
// selectors so this card doesn't re-render on every text-delta of
|
||||||
|
// every other part in the same message during streaming.
|
||||||
|
//
|
||||||
|
// IMPORTANT — ``useAuiState`` re-renders the component whenever the
|
||||||
|
// returned slice's identity changes. Returning ``message?.content``
|
||||||
|
// (an array) would re-render on every token because the runtime
|
||||||
|
// rebuilds the parts array. Returning a PRIMITIVE (the position
|
||||||
|
// number) lets ``useAuiState``'s ``Object.is`` check short-circuit
|
||||||
|
// when the position hasn't actually moved — which is the common
|
||||||
|
// case during text streaming, when only ``text``/``reasoning``
|
||||||
|
// parts are mutating and the same-toolName tool-call ordering is
|
||||||
|
// stable. (See Vercel React rule ``rerender-defer-reads``.)
|
||||||
|
const chatTurnId = useAuiState(({ message }) => {
|
||||||
|
const meta = message?.metadata as { custom?: { chatTurnId?: string } } | undefined;
|
||||||
|
return meta?.custom?.chatTurnId ?? null;
|
||||||
|
});
|
||||||
|
const positionInTurn = useAuiState(({ message }) => {
|
||||||
|
const content = message?.content;
|
||||||
|
if (!Array.isArray(content)) return -1;
|
||||||
|
let n = -1;
|
||||||
|
for (const part of content) {
|
||||||
|
if (
|
||||||
|
part &&
|
||||||
|
typeof part === "object" &&
|
||||||
|
(part as { type?: string }).type === "tool-call" &&
|
||||||
|
(part as { toolName?: string }).toolName === toolName
|
||||||
|
) {
|
||||||
|
n += 1;
|
||||||
|
if ((part as { toolCallId?: string }).toolCallId === toolCallId) return n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
});
|
||||||
|
|
||||||
|
const action = useMemo(() => {
|
||||||
|
// Tier 1 + 2: O(1) Map-backed direct id match. Covers
|
||||||
|
// ~all parity_v2 streams and any legacy stream that backfilled
|
||||||
|
// ``langchainToolCallId`` via ``tool-output-available``.
|
||||||
|
const direct = findByToolCallId(toolCallId) ?? findByToolCallId(langchainToolCallId);
|
||||||
|
if (direct) return direct;
|
||||||
|
// Tier 3: position-within-turn fallback. Only kicks in when the
|
||||||
|
// card has a synthetic ``call_<run_id>`` id AND no
|
||||||
|
// ``langchainToolCallId`` was ever backfilled — i.e. the tool
|
||||||
|
// was emitted as a single non-chunked payload AND streaming
|
||||||
|
// pre-dated the on_tool_end backfill.
|
||||||
|
if (!chatTurnId || positionInTurn < 0) return null;
|
||||||
|
const turnSameTool = findByChatTurnAndTool(chatTurnId, toolName);
|
||||||
|
return turnSameTool[positionInTurn] ?? null;
|
||||||
|
}, [
|
||||||
|
findByToolCallId,
|
||||||
|
findByChatTurnAndTool,
|
||||||
|
toolCallId,
|
||||||
|
langchainToolCallId,
|
||||||
|
chatTurnId,
|
||||||
|
toolName,
|
||||||
|
positionInTurn,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const [isReverting, setIsReverting] = useState(false);
|
||||||
|
const [confirmOpen, setConfirmOpen] = useState(false);
|
||||||
|
|
||||||
|
if (!action) return null;
|
||||||
|
if (!action.reversible) return null;
|
||||||
|
if (action.reverted_by_action_id !== null && action.reverted_by_action_id !== undefined)
|
||||||
|
return null;
|
||||||
|
if (action.is_revert_action) return null;
|
||||||
|
if (action.error !== null && action.error !== undefined) return null;
|
||||||
|
if (!threadId) return null;
|
||||||
|
|
||||||
|
const handleRevert = async () => {
|
||||||
|
setIsReverting(true);
|
||||||
|
try {
|
||||||
|
const response = await agentActionsApiService.revert(threadId, action.id);
|
||||||
|
markActionRevertedInCache(queryClient, threadId, action.id, response.new_action_id ?? null);
|
||||||
|
toast.success(response.message || "Action reverted.");
|
||||||
|
} catch (err) {
|
||||||
|
// 503 means revert is gated off on this deployment — hide the
|
||||||
|
// button silently rather than nagging the user. Any other error
|
||||||
|
// is surfaced as a toast so the operator can investigate.
|
||||||
|
if (err instanceof AppError && err.status === 503) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const message =
|
||||||
|
err instanceof AppError
|
||||||
|
? err.message
|
||||||
|
: err instanceof Error
|
||||||
|
? err.message
|
||||||
|
: "Failed to revert action.";
|
||||||
|
toast.error(message);
|
||||||
|
} finally {
|
||||||
|
setIsReverting(false);
|
||||||
|
setConfirmOpen(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AlertDialog open={confirmOpen} onOpenChange={setConfirmOpen}>
|
||||||
|
<AlertDialogTrigger asChild>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="outline"
|
||||||
|
className="gap-1.5"
|
||||||
|
onClick={(e) => {
|
||||||
|
e.stopPropagation();
|
||||||
|
setConfirmOpen(true);
|
||||||
|
}}
|
||||||
|
disabled={isReverting}
|
||||||
|
>
|
||||||
|
{isReverting ? (
|
||||||
|
// Spinner's typed props don't accept ``data-icon`` and
|
||||||
|
// it renders an <output>, not an <svg>, so Button's
|
||||||
|
// auto-sizing rule doesn't apply. Bare spinner +
|
||||||
|
// Button's gap handle layout.
|
||||||
|
<Spinner size="xs" />
|
||||||
|
) : (
|
||||||
|
<RotateCcw data-icon="inline-start" />
|
||||||
|
)}
|
||||||
|
Revert
|
||||||
|
</Button>
|
||||||
|
</AlertDialogTrigger>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader>
|
||||||
|
<AlertDialogTitle>Revert this action?</AlertDialogTitle>
|
||||||
|
<AlertDialogDescription>
|
||||||
|
This will undo{" "}
|
||||||
|
<span className="font-medium">{getToolDisplayName(action.tool_name)}</span> and add a
|
||||||
|
new entry to the history. Your chat is preserved — only the changes the agent made to
|
||||||
|
your knowledge base or connected apps will be rolled back where possible.
|
||||||
|
</AlertDialogDescription>
|
||||||
|
</AlertDialogHeader>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<AlertDialogCancel disabled={isReverting}>Cancel</AlertDialogCancel>
|
||||||
|
<AlertDialogAction
|
||||||
|
onClick={(e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
handleRevert();
|
||||||
|
}}
|
||||||
|
disabled={isReverting}
|
||||||
|
className="gap-1.5"
|
||||||
|
>
|
||||||
|
{isReverting && <Spinner size="xs" />}
|
||||||
|
Revert
|
||||||
|
</AlertDialogAction>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialog>
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
/**
|
||||||
toolName,
|
* Compact tool-call card.
|
||||||
argsText,
|
*
|
||||||
result,
|
* shadcn composition note: we intentionally use ``Card`` as a visual
|
||||||
status,
|
* frame WITHOUT ``CardHeader / CardContent``. The full composition's
|
||||||
}) => {
|
* ``p-6`` padding doesn't fit a compact collapsible header that IS the
|
||||||
const [isExpanded, setIsExpanded] = useState(false);
|
* trigger; using ``Card`` alone preserves the rounded border, shadow,
|
||||||
|
* and ``bg-card`` token (semantic colors) without forcing a layout
|
||||||
|
* that doesn't fit. All status colors use semantic tokens — no manual
|
||||||
|
* dark-mode overrides, no raw hex.
|
||||||
|
*/
|
||||||
|
const DefaultToolFallbackInner: ToolCallMessagePartComponent = (props) => {
|
||||||
|
const { toolCallId, toolName, argsText, result, status } = props;
|
||||||
|
// ``langchainToolCallId`` is a SurfSense-specific extension the
|
||||||
|
// streaming pipeline attaches to the tool-call content part so
|
||||||
|
// the Revert button can resolve its ``AgentActionLog`` row even
|
||||||
|
// when only the LC id is known. assistant-ui's
|
||||||
|
// ``ToolCallMessagePartProps`` doesn't list it, but the runtime
|
||||||
|
// spreads ``{...part}`` so the prop reaches us at runtime.
|
||||||
|
const langchainToolCallId = (props as { langchainToolCallId?: string }).langchainToolCallId;
|
||||||
|
|
||||||
const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
|
const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
|
||||||
const isError = status?.type === "incomplete" && status.reason === "error";
|
const isError = status?.type === "incomplete" && status.reason === "error";
|
||||||
const isRunning = status?.type === "running" || status?.type === "requires-action";
|
const isRunning = status?.type === "running" || status?.type === "requires-action";
|
||||||
|
|
||||||
|
/*
|
||||||
|
Per-card expansion state. Initial value is ``isRunning`` so a
|
||||||
|
card streaming in mounts already-expanded (no flash of
|
||||||
|
collapsed → expanded on first paint), while a card loaded from
|
||||||
|
history (status="complete") mounts collapsed. The useEffect
|
||||||
|
below keeps this in lockstep with this card's own ``isRunning``
|
||||||
|
when it transitions: false → true auto-expands (e.g. a tool
|
||||||
|
that re-runs after edit), true → false auto-collapses once the
|
||||||
|
tool finishes. Because the dep is per-card ``isRunning`` and
|
||||||
|
not the chat-level streaming flag, sibling cards on the same
|
||||||
|
assistant turn each manage their own expansion independently.
|
||||||
|
Once ``isRunning`` is false the user controls expansion via
|
||||||
|
``onOpenChange``.
|
||||||
|
*/
|
||||||
|
const [isExpanded, setIsExpanded] = useState(isRunning);
|
||||||
|
useEffect(() => {
|
||||||
|
setIsExpanded(isRunning);
|
||||||
|
}, [isRunning]);
|
||||||
const errorData = status?.type === "incomplete" ? status.error : undefined;
|
const errorData = status?.type === "incomplete" ? status.error : undefined;
|
||||||
const serializedError = useMemo(
|
const serializedError = useMemo(
|
||||||
() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
|
() => (errorData && typeof errorData !== "string" ? JSON.stringify(errorData) : null),
|
||||||
|
|
@ -50,21 +297,72 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
||||||
: serializedError
|
: serializedError
|
||||||
: null;
|
: null;
|
||||||
|
|
||||||
const Icon = getToolIcon(toolName);
|
const displayName = getToolDisplayName(toolName);
|
||||||
const displayName = formatToolName(toolName);
|
const subtitle = errorReason ?? cancelledReason;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<Card
|
||||||
className={cn(
|
className={cn(
|
||||||
"my-4 max-w-lg overflow-hidden rounded-2xl border bg-muted/30 select-none",
|
"my-4 max-w-lg overflow-hidden",
|
||||||
isCancelled && "opacity-60",
|
isCancelled && "opacity-60",
|
||||||
isError && "border-destructive/20 bg-destructive/5"
|
isError && "border-destructive/30"
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
|
{/*
|
||||||
|
``group`` lets the chevron (rendered as a sibling of the
|
||||||
|
main trigger button) read the Collapsible Root's
|
||||||
|
``data-[state=open]`` for rotation. The Collapsible is
|
||||||
|
fully controlled via ``isExpanded`` — the useEffect
|
||||||
|
above syncs it to ``isRunning`` so the card auto-opens
|
||||||
|
while a tool streams in and auto-collapses once it
|
||||||
|
finishes. We deliberately DON'T pass ``disabled`` so
|
||||||
|
both triggers stay clickable; ``onOpenChange`` is wired
|
||||||
|
to a setter that no-ops while ``isRunning`` (see
|
||||||
|
``handleOpenChange`` below) which keeps the card pinned
|
||||||
|
open mid-stream without losing keyboard / pointer
|
||||||
|
affordance the moment streaming ends.
|
||||||
|
*/}
|
||||||
|
<Collapsible
|
||||||
|
className="group"
|
||||||
|
open={isExpanded}
|
||||||
|
onOpenChange={(next) => {
|
||||||
|
// Block manual collapse while the tool is still
|
||||||
|
// streaming — otherwise a stray click on either
|
||||||
|
// trigger would close the card and hide the live
|
||||||
|
// ``argsText`` panel mid-run. After streaming the
|
||||||
|
// user has full control again.
|
||||||
|
if (isRunning) return;
|
||||||
|
setIsExpanded(next);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{/*
|
||||||
|
Header row: main trigger on the left (icon + title
|
||||||
|
col), Revert + chevron-trigger on the right as
|
||||||
|
siblings of the main trigger. The chevron is wrapped
|
||||||
|
in its OWN ``CollapsibleTrigger`` (Radix supports
|
||||||
|
multiple triggers per Root) so clicking the chevron
|
||||||
|
toggles the same state as clicking the title row.
|
||||||
|
The Revert button stays a separate AlertDialog
|
||||||
|
trigger and stops propagation in its onClick so it
|
||||||
|
doesn't toggle the collapsible while opening the
|
||||||
|
confirm dialog. Keeping these as flat siblings —
|
||||||
|
rather than nesting Revert / chevron inside the
|
||||||
|
title trigger — avoids invalid HTML
|
||||||
|
(button-in-button) and lets the Revert button
|
||||||
|
render in BOTH the collapsed and expanded states.
|
||||||
|
*/}
|
||||||
|
<div className="flex items-stretch transition-colors hover:bg-muted/50">
|
||||||
|
<CollapsibleTrigger asChild>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
onClick={() => setIsExpanded((prev) => !prev)}
|
className={cn(
|
||||||
className="flex w-full items-center gap-3 px-5 py-4 text-left transition-colors hover:bg-muted/50 focus:outline-none focus-visible:outline-none"
|
"flex flex-1 min-w-0 items-center gap-3 py-4 pl-5 pr-2 text-left",
|
||||||
|
// Inset ring — Card's ``overflow-hidden`` would
|
||||||
|
// clip an ``offset-2`` ring; ``ring-inset``
|
||||||
|
// paints inside the button box.
|
||||||
|
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
|
||||||
|
"disabled:cursor-default"
|
||||||
|
)}
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
|
|
@ -77,78 +375,129 @@ const DefaultToolFallbackInner: ToolCallMessagePartComponent = ({
|
||||||
) : isCancelled ? (
|
) : isCancelled ? (
|
||||||
<XCircleIcon className="size-4 text-muted-foreground" />
|
<XCircleIcon className="size-4 text-muted-foreground" />
|
||||||
) : isRunning ? (
|
) : isRunning ? (
|
||||||
<Icon className="size-4 text-primary animate-pulse" />
|
<Spinner size="sm" className="text-primary" />
|
||||||
) : (
|
) : (
|
||||||
<CheckIcon className="size-4 text-primary" />
|
<CheckIcon className="size-4 text-primary" />
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="flex-1 min-w-0">
|
<div className="flex flex-1 min-w-0 flex-col gap-0.5">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
<p
|
<p
|
||||||
className={cn(
|
className={cn(
|
||||||
"text-sm font-semibold",
|
"text-sm font-semibold truncate",
|
||||||
isError
|
isCancelled && "text-muted-foreground line-through",
|
||||||
? "text-destructive"
|
isError && "text-destructive"
|
||||||
: isCancelled
|
|
||||||
? "text-muted-foreground line-through"
|
|
||||||
: "text-foreground"
|
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{isRunning
|
{displayName}
|
||||||
? displayName
|
|
||||||
: isCancelled
|
|
||||||
? `Cancelled: ${displayName}`
|
|
||||||
: isError
|
|
||||||
? `Failed: ${displayName}`
|
|
||||||
: displayName}
|
|
||||||
</p>
|
</p>
|
||||||
{isRunning && <p className="text-xs text-muted-foreground mt-0.5">Running...</p>}
|
{isRunning && <Badge variant="secondary">Running</Badge>}
|
||||||
{cancelledReason && (
|
{isError && <Badge variant="destructive">Failed</Badge>}
|
||||||
<p className="text-xs text-muted-foreground mt-0.5 truncate">{cancelledReason}</p>
|
{isCancelled && <Badge variant="outline">Cancelled</Badge>}
|
||||||
|
</div>
|
||||||
|
{subtitle && (
|
||||||
|
<p
|
||||||
|
className={cn(
|
||||||
|
"text-xs truncate",
|
||||||
|
isError ? "text-destructive/80" : "text-muted-foreground"
|
||||||
)}
|
)}
|
||||||
{errorReason && (
|
>
|
||||||
<p className="text-xs text-destructive/80 mt-0.5 truncate">{errorReason}</p>
|
{subtitle}
|
||||||
|
</p>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{!isRunning && (
|
|
||||||
<div className="shrink-0 text-muted-foreground">
|
|
||||||
{isExpanded ? (
|
|
||||||
<ChevronDownIcon className="size-4" />
|
|
||||||
) : (
|
|
||||||
<ChevronUpIcon className="size-4" />
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</button>
|
</button>
|
||||||
|
</CollapsibleTrigger>
|
||||||
|
|
||||||
{isExpanded && !isRunning && (
|
{/*
|
||||||
<>
|
Right-side controls. The Revert button is
|
||||||
<div className="mx-5 h-px bg-border/50" />
|
visible whenever the matching action is
|
||||||
<div className="px-5 py-3 space-y-3">
|
reversible — including the collapsed state —
|
||||||
{argsText && (
|
but ``ToolCardRevertButton`` itself returns
|
||||||
<div>
|
``null`` while a tool is still running because
|
||||||
<p className="text-xs font-medium text-muted-foreground mb-1">Arguments</p>
|
no action-log row exists yet, so it doesn't
|
||||||
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
|
need an explicit ``isRunning`` gate here.
|
||||||
|
*/}
|
||||||
|
<div className="flex shrink-0 items-center gap-2 pl-2 pr-5">
|
||||||
|
<ToolCardRevertButton
|
||||||
|
toolCallId={toolCallId}
|
||||||
|
toolName={toolName}
|
||||||
|
langchainToolCallId={langchainToolCallId}
|
||||||
|
/>
|
||||||
|
<CollapsibleTrigger asChild>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
aria-label={isExpanded ? "Collapse details" : "Expand details"}
|
||||||
|
className={cn(
|
||||||
|
"flex size-7 shrink-0 items-center justify-center rounded-md",
|
||||||
|
"text-muted-foreground hover:bg-muted hover:text-foreground",
|
||||||
|
"focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-inset",
|
||||||
|
"disabled:cursor-default"
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<ChevronDownIcon
|
||||||
|
className={cn(
|
||||||
|
"size-4 transition-transform duration-200",
|
||||||
|
"group-data-[state=open]:rotate-180"
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
</CollapsibleTrigger>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/*
|
||||||
|
CollapsibleContent body — auto-open while streaming
|
||||||
|
(see ``open`` prop above) so the live ``argsText``
|
||||||
|
streams into the Inputs panel directly, no need for
|
||||||
|
a separate "Live input" panel. Native
|
||||||
|
``overflow-auto`` instead of ``ScrollArea`` because
|
||||||
|
Radix's Viewport can let content bleed past
|
||||||
|
``max-h-*`` in dynamic flex layouts. ``min-w-0`` on
|
||||||
|
the column wrappers guarantees ``break-all`` wraps
|
||||||
|
correctly within the bounded ``max-w-lg`` Card.
|
||||||
|
*/}
|
||||||
|
<CollapsibleContent>
|
||||||
|
<Separator />
|
||||||
|
<div className="flex flex-col gap-3 px-5 py-3">
|
||||||
|
{(argsText || isRunning) && (
|
||||||
|
<div className="flex flex-col gap-1 min-w-0">
|
||||||
|
<p className="text-xs font-medium text-muted-foreground">Inputs</p>
|
||||||
|
<NestedScroll className="max-h-48 overflow-auto rounded-md bg-muted/40">
|
||||||
|
{argsText ? (
|
||||||
|
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||||
{argsText}
|
{argsText}
|
||||||
</pre>
|
</pre>
|
||||||
|
) : (
|
||||||
|
// Bridges the brief gap between
|
||||||
|
// ``tool-input-start`` (creates the
|
||||||
|
// card, ``argsText`` undefined) and
|
||||||
|
// the first ``tool-input-delta``.
|
||||||
|
<p className="px-3 py-2 text-xs italic text-muted-foreground">
|
||||||
|
Waiting for input…
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</NestedScroll>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{!isCancelled && result !== undefined && (
|
{!isCancelled && result !== undefined && (
|
||||||
<>
|
<>
|
||||||
<div className="h-px bg-border/30" />
|
<Separator />
|
||||||
<div>
|
<div className="flex flex-col gap-1 min-w-0">
|
||||||
<p className="text-xs font-medium text-muted-foreground mb-1">Result</p>
|
<p className="text-xs font-medium text-muted-foreground">Result</p>
|
||||||
<pre className="text-xs text-foreground/80 whitespace-pre-wrap break-all">
|
<NestedScroll className="max-h-64 overflow-auto rounded-md bg-muted/40">
|
||||||
|
<pre className="px-3 py-2 text-xs text-foreground/80 whitespace-pre-wrap break-all font-mono">
|
||||||
{typeof result === "string" ? result : serializedResult}
|
{typeof result === "string" ? result : serializedResult}
|
||||||
</pre>
|
</pre>
|
||||||
|
</NestedScroll>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</>
|
</CollapsibleContent>
|
||||||
)}
|
</Collapsible>
|
||||||
</div>
|
</Card>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,10 @@
|
||||||
import { ActionBarPrimitive, AuiIf, MessagePrimitive, useAuiState } from "@assistant-ui/react";
|
import {
|
||||||
|
ActionBarPrimitive,
|
||||||
|
AuiIf,
|
||||||
|
MessagePrimitive,
|
||||||
|
useAuiState,
|
||||||
|
useMessagePartText,
|
||||||
|
} from "@assistant-ui/react";
|
||||||
import { useAtomValue } from "jotai";
|
import { useAtomValue } from "jotai";
|
||||||
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
|
import { CheckIcon, CopyIcon, Pencil } from "lucide-react";
|
||||||
import Image from "next/image";
|
import Image from "next/image";
|
||||||
|
|
@ -7,6 +13,8 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
||||||
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
|
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||||
|
import { getMentionDocKey } from "@/lib/chat/mention-doc-key";
|
||||||
|
import { parseMentionSegments } from "@/lib/chat/parse-mention-segments";
|
||||||
|
|
||||||
interface AuthorMetadata {
|
interface AuthorMetadata {
|
||||||
displayName: string | null;
|
displayName: string | null;
|
||||||
|
|
@ -47,23 +55,40 @@ const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export const UserMessage: FC = () => {
|
const UserTextPart: FC = () => {
|
||||||
const messageId = useAuiState(({ message }) => message?.id);
|
const messageId = useAuiState(({ message }) => message?.id);
|
||||||
const messageText = useAuiState(({ message }) =>
|
const part = useMessagePartText();
|
||||||
(message?.content ?? [])
|
const text = (part as { text?: string }).text ?? "";
|
||||||
.map((part) =>
|
|
||||||
typeof part === "object" &&
|
|
||||||
part !== null &&
|
|
||||||
"type" in part &&
|
|
||||||
(part as { type?: string }).type === "text" &&
|
|
||||||
"text" in part
|
|
||||||
? String((part as { text?: string }).text ?? "")
|
|
||||||
: ""
|
|
||||||
)
|
|
||||||
.join("")
|
|
||||||
);
|
|
||||||
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
||||||
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
|
const mentionedDocs = (messageId ? messageDocumentsMap[messageId] : undefined) ?? [];
|
||||||
|
|
||||||
|
const segments = parseMentionSegments(text, mentionedDocs);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<p style={{ whiteSpace: "pre-line" }} className="break-words">
|
||||||
|
{segments.map((segment) =>
|
||||||
|
segment.type === "text" ? (
|
||||||
|
<span key={`txt-${segment.start}`}>{segment.value}</span>
|
||||||
|
) : (
|
||||||
|
<span
|
||||||
|
key={`mention-${getMentionDocKey(segment.doc)}-${segment.start}`}
|
||||||
|
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-middle leading-none"
|
||||||
|
title={segment.doc.title}
|
||||||
|
>
|
||||||
|
<span className="flex items-center text-muted-foreground">
|
||||||
|
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
|
||||||
|
</span>
|
||||||
|
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
)}
|
||||||
|
</p>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const userMessageParts = { Text: UserTextPart };
|
||||||
|
|
||||||
|
export const UserMessage: FC = () => {
|
||||||
const metadata = useAuiState(({ message }) => message?.metadata);
|
const metadata = useAuiState(({ message }) => message?.metadata);
|
||||||
const author = metadata?.custom?.author as AuthorMetadata | undefined;
|
const author = metadata?.custom?.author as AuthorMetadata | undefined;
|
||||||
const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE";
|
const isSharedChat = useAtomValue(currentThreadAtom).visibility === "SEARCH_SPACE";
|
||||||
|
|
@ -77,12 +102,8 @@ export const UserMessage: FC = () => {
|
||||||
<div className="col-start-2 min-w-0">
|
<div className="col-start-2 min-w-0">
|
||||||
<div className="aui-user-message-content-wrapper flex items-end gap-2">
|
<div className="aui-user-message-content-wrapper flex items-end gap-2">
|
||||||
<div className="relative flex-1 min-w-0">
|
<div className="relative flex-1 min-w-0">
|
||||||
<div className="aui-user-message-content wrap-break-word rounded-xl bg-muted px-4 py-2.5 text-foreground">
|
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
||||||
{mentionedDocs && mentionedDocs.length > 0 ? (
|
<MessagePrimitive.Parts components={userMessageParts} />
|
||||||
<UserMessageWithMentionChips text={messageText} mentionedDocs={mentionedDocs} />
|
|
||||||
) : (
|
|
||||||
<MessagePrimitive.Parts />
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
|
<div className="absolute right-0 top-full mt-1 z-10 opacity-100 pointer-events-auto md:opacity-0 md:pointer-events-none md:transition-opacity md:duration-200 md:delay-300 md:group-hover/user-msg:opacity-100 md:group-hover/user-msg:delay-0 md:group-hover/user-msg:pointer-events-auto">
|
||||||
<UserActionBar />
|
<UserActionBar />
|
||||||
|
|
@ -99,64 +120,6 @@ export const UserMessage: FC = () => {
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const UserMessageWithMentionChips: FC<{
|
|
||||||
text: string;
|
|
||||||
mentionedDocs: { id: number; title: string; document_type: string }[];
|
|
||||||
}> = ({ text, mentionedDocs }) => {
|
|
||||||
type Segment =
|
|
||||||
| { type: "text"; value: string; start: number }
|
|
||||||
| { type: "mention"; doc: { id: number; title: string; document_type: string }; start: number };
|
|
||||||
|
|
||||||
const tokens = mentionedDocs
|
|
||||||
.map((doc) => ({ doc, token: `@${doc.title}` }))
|
|
||||||
.sort((a, b) => b.token.length - a.token.length);
|
|
||||||
|
|
||||||
const segments: Segment[] = [];
|
|
||||||
let i = 0;
|
|
||||||
let buffer = "";
|
|
||||||
let bufferStart = 0;
|
|
||||||
while (i < text.length) {
|
|
||||||
const tokenMatch = tokens.find(({ token }) => text.startsWith(token, i));
|
|
||||||
if (tokenMatch) {
|
|
||||||
if (buffer) {
|
|
||||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
|
||||||
buffer = "";
|
|
||||||
}
|
|
||||||
segments.push({ type: "mention", doc: tokenMatch.doc, start: i });
|
|
||||||
i += tokenMatch.token.length;
|
|
||||||
bufferStart = i;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!buffer) bufferStart = i;
|
|
||||||
buffer += text[i];
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
if (buffer) {
|
|
||||||
segments.push({ type: "text", value: buffer, start: bufferStart });
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<span className="whitespace-pre-wrap break-words">
|
|
||||||
{segments.map((segment) =>
|
|
||||||
segment.type === "text" ? (
|
|
||||||
<span key={`txt-${segment.start}`}>{segment.value}</span>
|
|
||||||
) : (
|
|
||||||
<span
|
|
||||||
key={`mention-${segment.doc.document_type}:${segment.doc.id}-${segment.start}`}
|
|
||||||
className="inline-flex items-center gap-1 mx-0.5 px-1 py-0.5 rounded bg-primary/10 text-xs font-bold text-primary/60 select-none align-baseline"
|
|
||||||
title={segment.doc.title}
|
|
||||||
>
|
|
||||||
<span className="flex items-center text-muted-foreground">
|
|
||||||
{getConnectorIcon(segment.doc.document_type ?? "UNKNOWN", "h-3 w-3")}
|
|
||||||
</span>
|
|
||||||
<span className="max-w-[120px] truncate">{segment.doc.title}</span>
|
|
||||||
</span>
|
|
||||||
)
|
|
||||||
)}
|
|
||||||
</span>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const UserActionBar: FC = () => {
|
const UserActionBar: FC = () => {
|
||||||
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
|
const isThreadRunning = useAuiState(({ thread }) => thread.isRunning);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -169,7 +169,7 @@ export const CitationPanelContent: FC<CitationPanelContentProps> = ({ chunkId, o
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className="text-sm">
|
<div className="text-sm">
|
||||||
<MarkdownViewer content={chunk.content} />
|
<MarkdownViewer content={chunk.content} enableCitations />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|
|
||||||
77
surfsense_web/components/citations/citation-renderer.tsx
Normal file
77
surfsense_web/components/citations/citation-renderer.tsx
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import type { ReactNode } from "react";
|
||||||
|
import { InlineCitation, UrlCitation } from "@/components/assistant-ui/inline-citation";
|
||||||
|
import {
|
||||||
|
type CitationToken,
|
||||||
|
type CitationUrlMap,
|
||||||
|
parseTextWithCitations,
|
||||||
|
} from "@/lib/citations/citation-parser";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Render a single parsed citation token as JSX.
|
||||||
|
*
|
||||||
|
* `ordinalKey` should be a stable per-render counter so duplicate identical
|
||||||
|
* citations within the same parent don't collide on `key`. The previous
|
||||||
|
* implementation in `markdown-text.tsx` used the source string itself as
|
||||||
|
* the key, which produced React warnings when two segments rendered the
|
||||||
|
* same `[citation:N]` text.
|
||||||
|
*/
|
||||||
|
export function renderCitationToken(token: CitationToken, ordinalKey: number): ReactNode {
|
||||||
|
if (token.kind === "url") {
|
||||||
|
return <UrlCitation key={`citation-url-${ordinalKey}`} url={token.url} />;
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<InlineCitation
|
||||||
|
key={`citation-${token.isDocsChunk ? "doc-" : ""}${token.chunkId}-${ordinalKey}`}
|
||||||
|
chunkId={token.chunkId}
|
||||||
|
isDocsChunk={token.isDocsChunk}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Walk a `ReactNode` (string, array, or arbitrary node) and replace any
|
||||||
|
* `[citation:...]` tokens inside string children with citation badges.
|
||||||
|
*
|
||||||
|
* Designed for use inside `Streamdown`/`react-markdown` `components`
|
||||||
|
* overrides where the renderer hands you `children`. Non-string children
|
||||||
|
* are returned untouched so block/phrasing structure is preserved.
|
||||||
|
*/
|
||||||
|
export function processChildrenWithCitations(
|
||||||
|
children: ReactNode,
|
||||||
|
urlMap: CitationUrlMap
|
||||||
|
): ReactNode {
|
||||||
|
if (typeof children === "string") {
|
||||||
|
const segments = parseTextWithCitations(children, urlMap);
|
||||||
|
if (segments.length === 1 && typeof segments[0] === "string") {
|
||||||
|
return children;
|
||||||
|
}
|
||||||
|
let ordinal = 0;
|
||||||
|
return segments.map((segment) =>
|
||||||
|
typeof segment === "string" ? segment : renderCitationToken(segment, ordinal++)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Array.isArray(children)) {
|
||||||
|
let ordinal = 0;
|
||||||
|
return children.map((child, childIndex) => {
|
||||||
|
if (typeof child === "string") {
|
||||||
|
const segments = parseTextWithCitations(child, urlMap);
|
||||||
|
if (segments.length === 1 && typeof segments[0] === "string") {
|
||||||
|
return child;
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<span key={`citation-seg-${childIndex}`}>
|
||||||
|
{segments.map((segment) =>
|
||||||
|
typeof segment === "string" ? segment : renderCitationToken(segment, ordinal++)
|
||||||
|
)}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return child;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return children;
|
||||||
|
}
|
||||||
|
|
@ -32,7 +32,7 @@ export function DocumentViewer({ title, content, trigger }: DocumentViewerProps)
|
||||||
<DialogTitle>{title}</DialogTitle>
|
<DialogTitle>{title}</DialogTitle>
|
||||||
</DialogHeader>
|
</DialogHeader>
|
||||||
<div className="mt-4">
|
<div className="mt-4">
|
||||||
<MarkdownViewer content={content} />
|
<MarkdownViewer content={content} enableCitations />
|
||||||
</div>
|
</div>
|
||||||
</DialogContent>
|
</DialogContent>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue