mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
feat: updated agent harness
This commit is contained in:
parent
9ec9b64348
commit
31a372bb84
139 changed files with 12583 additions and 1111 deletions
|
|
@ -247,3 +247,42 @@ LANGSMITH_TRACING=true
|
|||
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
||||
LANGSMITH_API_KEY=lsv2_pt_.....
|
||||
LANGSMITH_PROJECT=surfsense
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OPTIONAL: New-chat agent feature flags (OpenCode-port)
|
||||
# =============================================================================
|
||||
# Master kill-switch — when true, every flag below is forced OFF.
|
||||
# SURFSENSE_DISABLE_NEW_AGENT_STACK=false
|
||||
|
||||
# Tier 1 — Agent quality
|
||||
# SURFSENSE_ENABLE_CONTEXT_EDITING=false
|
||||
# SURFSENSE_ENABLE_COMPACTION_V2=false
|
||||
# SURFSENSE_ENABLE_RETRY_AFTER=false
|
||||
# SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||
# SURFSENSE_ENABLE_MODEL_CALL_LIMIT=false
|
||||
# SURFSENSE_ENABLE_TOOL_CALL_LIMIT=false
|
||||
# SURFSENSE_ENABLE_TOOL_CALL_REPAIR=false
|
||||
# SURFSENSE_ENABLE_DOOM_LOOP=false # leave OFF until UI handles permission='doom_loop'
|
||||
|
||||
# Tier 2 — Safety
|
||||
# SURFSENSE_ENABLE_PERMISSION=false
|
||||
# SURFSENSE_ENABLE_BUSY_MUTEX=false
|
||||
# SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||
|
||||
# Tier 3b — Observability (also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
||||
# SURFSENSE_ENABLE_OTEL=false
|
||||
|
||||
# Tier 4 — Skills + subagents
|
||||
# SURFSENSE_ENABLE_SKILLS=false
|
||||
# SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS=false
|
||||
# SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE=false
|
||||
|
||||
# Tier 5 — Snapshot / revert
|
||||
# SURFSENSE_ENABLE_ACTION_LOG=false
|
||||
# SURFSENSE_ENABLE_REVERT_ROUTE=false # Backend-only; flip when UI ships
|
||||
|
||||
# Tier 6 — Plugins
|
||||
# SURFSENSE_ENABLE_PLUGIN_LOADER=false
|
||||
# Comma-separated allowlist of plugin entry-point names
|
||||
# SURFSENSE_ALLOWED_PLUGINS=year_substituter
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
"""130_add_agent_action_log
|
||||
|
||||
Revision ID: 130
|
||||
Revises: 129
|
||||
Create Date: 2026-04-28
|
||||
|
||||
Tier 5.2 in the OpenCode-port plan. Adds the append-only ``agent_action_log``
|
||||
table that :class:`ActionLogMiddleware` writes to after every tool call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "130"
|
||||
down_revision: str | None = "129"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"agent_action_log",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column(
|
||||
"thread_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"search_space_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column("turn_id", sa.String(length=64), nullable=True, index=True),
|
||||
sa.Column("message_id", sa.String(length=128), nullable=True, index=True),
|
||||
sa.Column("tool_name", sa.String(length=255), nullable=False, index=True),
|
||||
sa.Column("args", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("result_id", sa.String(length=255), nullable=True),
|
||||
sa.Column(
|
||||
"reversible",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column(
|
||||
"reverse_descriptor",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("error", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column(
|
||||
"reverse_of",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("(now() AT TIME ZONE 'utc')"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agent_action_log_thread_created",
|
||||
"agent_action_log",
|
||||
["thread_id", "created_at"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_agent_action_log_thread_created", table_name="agent_action_log"
|
||||
)
|
||||
op.drop_table("agent_action_log")
|
||||
119
surfsense_backend/alembic/versions/131_add_document_revisions.py
Normal file
119
surfsense_backend/alembic/versions/131_add_document_revisions.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
"""131_add_document_revisions
|
||||
|
||||
Revision ID: 131
|
||||
Revises: 130
|
||||
Create Date: 2026-04-28
|
||||
|
||||
Tier 5.1 in the OpenCode-port plan. Adds two snapshot tables:
|
||||
|
||||
* ``document_revisions``: pre-mutation snapshot of NOTE/FILE/EXTENSION docs.
|
||||
* ``folder_revisions``: pre-mutation snapshot of folder mkdir/move/delete.
|
||||
|
||||
Both are written by :class:`KnowledgeBasePersistenceMiddleware` ahead of
|
||||
state-changing tool calls and consumed by ``revert_service.revert_action``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "131"
|
||||
down_revision: str | None = "130"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"document_revisions",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column(
|
||||
"document_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("documents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"search_space_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column("content_before", sa.Text(), nullable=True),
|
||||
sa.Column("title_before", sa.String(), nullable=True),
|
||||
sa.Column("folder_id_before", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"chunks_before", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.Column(
|
||||
"metadata_before", postgresql.JSONB(astext_type=sa.Text()), nullable=True
|
||||
),
|
||||
sa.Column(
|
||||
"created_by_turn_id", sa.String(length=64), nullable=True, index=True
|
||||
),
|
||||
sa.Column(
|
||||
"agent_action_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("(now() AT TIME ZONE 'utc')"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"folder_revisions",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column(
|
||||
"folder_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("folders.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"search_space_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column("name_before", sa.String(length=255), nullable=True),
|
||||
sa.Column("parent_id_before", sa.Integer(), nullable=True),
|
||||
sa.Column("position_before", sa.String(length=50), nullable=True),
|
||||
sa.Column(
|
||||
"created_by_turn_id", sa.String(length=64), nullable=True, index=True
|
||||
),
|
||||
sa.Column(
|
||||
"agent_action_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("agent_action_log.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("(now() AT TIME ZONE 'utc')"),
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("folder_revisions")
|
||||
op.drop_table("document_revisions")
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
"""132_add_agent_permission_rules
|
||||
|
||||
Revision ID: 132
|
||||
Revises: 131
|
||||
Create Date: 2026-04-28
|
||||
|
||||
Tier 2.1 in the OpenCode-port plan. Adds the persistent ``agent_permission_rules``
|
||||
table consumed by :class:`PermissionMiddleware` at agent build time. Rules
|
||||
can be scoped at search-space (``user_id`` / ``thread_id`` NULL),
|
||||
user-wide (``user_id`` set, ``thread_id`` NULL), or per-thread
|
||||
(``thread_id`` set).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "132"
|
||||
down_revision: str | None = "131"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"agent_permission_rules",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column(
|
||||
"search_space_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
),
|
||||
sa.Column(
|
||||
"thread_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
),
|
||||
sa.Column("permission", sa.String(length=255), nullable=False),
|
||||
sa.Column(
|
||||
"pattern",
|
||||
sa.String(length=255),
|
||||
nullable=False,
|
||||
server_default="*",
|
||||
),
|
||||
sa.Column("action", sa.String(length=16), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("(now() AT TIME ZONE 'utc')"),
|
||||
index=True,
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"search_space_id",
|
||||
"user_id",
|
||||
"thread_id",
|
||||
"permission",
|
||||
"pattern",
|
||||
"action",
|
||||
name="uq_agent_permission_rules_scope",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("agent_permission_rules")
|
||||
|
|
@ -23,9 +23,16 @@ from deepagents import SubAgent, SubAgentMiddleware, __version__ as deepagents_v
|
|||
from deepagents.backends import StateBackend
|
||||
from deepagents.graph import BASE_AGENT_PROMPT
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware import (
|
||||
LLMToolSelectorMiddleware,
|
||||
ModelCallLimitMiddleware,
|
||||
ModelFallbackMiddleware,
|
||||
TodoListMiddleware,
|
||||
ToolCallLimitMiddleware,
|
||||
)
|
||||
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
|
@ -33,27 +40,51 @@ from langgraph.types import Checkpointer
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||
from app.agents.new_chat.filesystem_backends import build_backend_resolver
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.new_chat.llm_config import AgentConfig
|
||||
from app.agents.new_chat.middleware import (
|
||||
ActionLogMiddleware,
|
||||
AnonymousDocumentMiddleware,
|
||||
BusyMutexMiddleware,
|
||||
ClearToolUsesEdit,
|
||||
DedupHITLToolCallsMiddleware,
|
||||
DoomLoopMiddleware,
|
||||
FileIntentMiddleware,
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
KnowledgePriorityMiddleware,
|
||||
KnowledgeTreeMiddleware,
|
||||
MemoryInjectionMiddleware,
|
||||
NoopInjectionMiddleware,
|
||||
OtelSpanMiddleware,
|
||||
PermissionMiddleware,
|
||||
RetryAfterMiddleware,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
SurfSenseFilesystemMiddleware,
|
||||
ToolCallNameRepairMiddleware,
|
||||
build_skills_backend_factory,
|
||||
create_surfsense_compaction_middleware,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.middleware.safe_summarization import (
|
||||
create_safe_summarization_middleware,
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
from app.agents.new_chat.plugin_loader import (
|
||||
PluginContext,
|
||||
load_allowed_plugin_names_from_env,
|
||||
load_plugin_middlewares,
|
||||
)
|
||||
from app.agents.new_chat.subagents import build_specialized_subagents
|
||||
from app.agents.new_chat.system_prompt import (
|
||||
build_configurable_system_prompt,
|
||||
build_surfsense_system_prompt,
|
||||
)
|
||||
from app.agents.new_chat.tools.invalid_tool import (
|
||||
INVALID_TOOL_NAME,
|
||||
invalid_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.registry import (
|
||||
BUILTIN_TOOLS,
|
||||
build_tools_async,
|
||||
get_connector_gated_tools,
|
||||
)
|
||||
|
|
@ -321,6 +352,17 @@ async def create_surfsense_deep_agent(
|
|||
disabled_tools=modified_disabled_tools,
|
||||
additional_tools=list(additional_tools) if additional_tools else None,
|
||||
)
|
||||
|
||||
# Tier 1.6: register `invalid` tool. It is dispatched only when
|
||||
# ToolCallNameRepairMiddleware rewrites a malformed call. We
|
||||
# intentionally append it AFTER ``build_tools_async`` so it never
|
||||
# appears in the system-prompt tool list (which is built from the
|
||||
# registry, not the bound tool list).
|
||||
_flags: AgentFeatureFlags = get_flags()
|
||||
if _flags.enable_tool_call_repair and INVALID_TOOL_NAME not in {
|
||||
t.name for t in tools
|
||||
}:
|
||||
tools = [*list(tools), invalid_tool]
|
||||
_perf_log.info(
|
||||
"[create_agent] build_tools_async in %.3fs (%d tools)",
|
||||
time.perf_counter() - _t0,
|
||||
|
|
@ -397,6 +439,8 @@ async def create_surfsense_deep_agent(
|
|||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
max_input_tokens=_max_input_tokens,
|
||||
flags=_flags,
|
||||
checkpointer=checkpointer,
|
||||
)
|
||||
_perf_log.info(
|
||||
|
|
@ -411,6 +455,71 @@ async def create_surfsense_deep_agent(
|
|||
return agent
|
||||
|
||||
|
||||
# Tier 1.1: tools whose output is too costly / lossy to discard. Keep
|
||||
# this conservative — anything listed here is *never* pruned by
|
||||
# ContextEditingMiddleware. The list is filtered against actually-bound
|
||||
# tool names so disabled connectors don't show up here.
|
||||
_PRUNE_PROTECTED_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"generate_report",
|
||||
"generate_resume",
|
||||
"generate_podcast",
|
||||
"generate_video_presentation",
|
||||
"generate_image",
|
||||
# Read-heavy connector reads — recomputing them is expensive
|
||||
"read_email",
|
||||
"search_emails",
|
||||
# The fallback for malformed tool calls — keep its replies visible
|
||||
"invalid",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _safe_exclude_tools(tools: Sequence[BaseTool]) -> tuple[str, ...]:
|
||||
"""Return ``exclude_tools`` derived from the actually-bound tool list.
|
||||
|
||||
Filters :data:`_PRUNE_PROTECTED_TOOL_NAMES` against the bound tools
|
||||
so we never list tools that don't exist (would be a silent no-op).
|
||||
"""
|
||||
enabled = {t.name for t in tools}
|
||||
return tuple(name for name in _PRUNE_PROTECTED_TOOL_NAMES if name in enabled)
|
||||
|
||||
|
||||
# Tier 2.1 / cleanup: opencode `Permission.disabled` parity. Replaces the
|
||||
# legacy binary ``_CONNECTOR_TYPE_TO_SEARCHABLE``-based gating with a
|
||||
# declarative pass over :data:`BUILTIN_TOOLS`. Each tool that declares a
|
||||
# ``required_connector`` not present in ``available_connectors`` gets a
|
||||
# deny rule so any execution attempt short-circuits with permission_denied.
|
||||
def _synthesize_connector_deny_rules(
|
||||
*,
|
||||
available_connectors: list[str] | None,
|
||||
enabled_tool_names: set[str],
|
||||
) -> list[Rule]:
|
||||
"""Build deny rules for tools whose required connector is not enabled.
|
||||
|
||||
Source of truth is ``ToolDefinition.required_connector`` in
|
||||
:data:`BUILTIN_TOOLS`. A tool only gets a deny rule when:
|
||||
|
||||
1. It is currently bound (``enabled_tool_names``).
|
||||
2. It declares a ``required_connector``.
|
||||
3. That connector is *not* in ``available_connectors``.
|
||||
|
||||
This expresses the OpenCode ``Permission.disabled`` semantics
|
||||
declaratively, replacing the substring-heuristic binary gating
|
||||
that used to consult the hardcoded ``_CONNECTOR_TYPE_TO_SEARCHABLE``
|
||||
map.
|
||||
"""
|
||||
available = set(available_connectors or [])
|
||||
deny: list[Rule] = []
|
||||
for tool_def in BUILTIN_TOOLS:
|
||||
if tool_def.name not in enabled_tool_names:
|
||||
continue
|
||||
rc = tool_def.required_connector
|
||||
if rc and rc not in available:
|
||||
deny.append(Rule(permission=tool_def.name, pattern="*", action="deny"))
|
||||
return deny
|
||||
|
||||
|
||||
def _build_compiled_agent_blocking(
|
||||
*,
|
||||
llm: BaseChatModel,
|
||||
|
|
@ -426,6 +535,8 @@ def _build_compiled_agent_blocking(
|
|||
available_connectors: list[str] | None,
|
||||
available_document_types: list[str] | None,
|
||||
mentioned_document_ids: list[int] | None,
|
||||
max_input_tokens: int | None,
|
||||
flags: AgentFeatureFlags,
|
||||
checkpointer: Checkpointer,
|
||||
):
|
||||
"""Build the middleware stack and compile the agent graph synchronously.
|
||||
|
|
@ -458,7 +569,7 @@ def _build_compiled_agent_blocking(
|
|||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
create_safe_summarization_middleware(llm, StateBackend),
|
||||
create_surfsense_compaction_middleware(llm, StateBackend),
|
||||
PatchToolCallsMiddleware(),
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||
]
|
||||
|
|
@ -470,13 +581,319 @@ def _build_compiled_agent_blocking(
|
|||
"middleware": gp_middleware,
|
||||
}
|
||||
|
||||
# Tier 4.3: specialized user-facing subagents (explore, report_writer,
|
||||
# connector_negotiator). Registered through SubAgentMiddleware alongside
|
||||
# the general-purpose spec so the parent's `task` tool can address them
|
||||
# by name. Off by default until the flag flips so existing deployments
|
||||
# don't see new agent types in the task tool description.
|
||||
specialized_subagents: list[SubAgent] = []
|
||||
if (
|
||||
flags.enable_specialized_subagents
|
||||
and not flags.disable_new_agent_stack
|
||||
):
|
||||
try:
|
||||
# Specialized subagents share the parent's filesystem +
|
||||
# todo view so their system prompts (which promise
|
||||
# ``read_file``, ``ls``, ``grep``, ``glob``, ``write_todos``)
|
||||
# actually match runtime behavior. Build *fresh* instances
|
||||
# rather than aliasing the parent's GP middleware to avoid
|
||||
# subtle state coupling across compiled graphs.
|
||||
subagent_extra_middleware: list = [
|
||||
TodoListMiddleware(),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
backend=backend_resolver,
|
||||
filesystem_mode=filesystem_mode,
|
||||
search_space_id=search_space_id,
|
||||
created_by_id=user_id,
|
||||
thread_id=thread_id,
|
||||
),
|
||||
]
|
||||
specialized_subagents = build_specialized_subagents(
|
||||
tools=tools,
|
||||
model=llm,
|
||||
extra_middleware=subagent_extra_middleware,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
"Specialized subagent build failed; running without them: %s",
|
||||
exc,
|
||||
)
|
||||
specialized_subagents = []
|
||||
|
||||
subagent_specs: list[SubAgent] = [general_purpose_spec, *specialized_subagents]
|
||||
|
||||
# Main agent middleware
|
||||
# Order: AnonDoc -> Tree -> Priority -> FileIntent -> Filesystem -> Persistence -> ...
|
||||
# before_agent hooks run in declared order; later injections sit closer to
|
||||
# the latest human turn. Tree (large + cacheable) is injected earliest so
|
||||
# provider-side prefix caching has more material to hit; FileIntent (most
|
||||
# actionable per-turn contract) is injected closest to the user message.
|
||||
#
|
||||
# ``wrap_model_call`` ordering: the FIRST middleware in the list is the
|
||||
# OUTERMOST wrapper. To ensure prune executes before summarization,
|
||||
# place ``SpillingContextEditingMiddleware`` before
|
||||
# ``SurfSenseCompactionMiddleware`` (Tier 1.1 + 1.3).
|
||||
# Compaction is the canonical token-budget defense after the
|
||||
# cleanup tier removed ``SafeSummarizationMiddleware``. The Bedrock
|
||||
# buffer-empty defense is folded into ``SurfSenseCompactionMiddleware``.
|
||||
summarization_mw = create_surfsense_compaction_middleware(llm, StateBackend)
|
||||
_ = flags.enable_compaction_v2 # historical flag; retained for telemetry parity
|
||||
|
||||
# Tier 1.1: ContextEditing prune. Trigger at 55% of model_max_input,
|
||||
# earlier than summarization (~85%). When disabled, no edit runs.
|
||||
context_edit_mw = None
|
||||
if (
|
||||
flags.enable_context_editing
|
||||
and not flags.disable_new_agent_stack
|
||||
and max_input_tokens
|
||||
):
|
||||
spill_edit = SpillToBackendEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=_safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
)
|
||||
clear_edit = ClearToolUsesEdit(
|
||||
trigger=int(max_input_tokens * 0.55),
|
||||
clear_at_least=int(max_input_tokens * 0.15),
|
||||
keep=5,
|
||||
exclude_tools=_safe_exclude_tools(tools),
|
||||
clear_tool_inputs=True,
|
||||
placeholder="[cleared - older tool output trimmed for context]",
|
||||
)
|
||||
context_edit_mw = SpillingContextEditingMiddleware(
|
||||
edits=[spill_edit, clear_edit],
|
||||
backend_resolver=backend_resolver,
|
||||
)
|
||||
|
||||
# Tier 1.4 / 1.8 / 1.9 / 1.10: built-in retry/fallback/limits.
|
||||
retry_mw = (
|
||||
RetryAfterMiddleware(max_retries=3)
|
||||
if flags.enable_retry_after and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
# Fallback chain — primary is the agent's own model; we add cheap
|
||||
# alternatives. Off by default; only the first call site that
|
||||
# configures the chain via env should enable it.
|
||||
fallback_mw: ModelFallbackMiddleware | None = None
|
||||
if flags.enable_model_fallback and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
fallback_mw = ModelFallbackMiddleware(
|
||||
"openai:gpt-4o-mini",
|
||||
"anthropic:claude-3-5-haiku-20241022",
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("ModelFallbackMiddleware init failed; skipping.")
|
||||
fallback_mw = None
|
||||
model_call_limit_mw = (
|
||||
ModelCallLimitMiddleware(
|
||||
thread_limit=120,
|
||||
run_limit=80,
|
||||
exit_behavior="end",
|
||||
)
|
||||
if flags.enable_model_call_limit and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
tool_call_limit_mw = (
|
||||
ToolCallLimitMiddleware(thread_limit=300, run_limit=80, exit_behavior="continue")
|
||||
if flags.enable_tool_call_limit and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
# Tier 1.5: provider-compat _noop injection.
|
||||
noop_mw = (
|
||||
NoopInjectionMiddleware()
|
||||
if flags.enable_compaction_v2 and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
# Tier 1.7: tool-call name repair (lowercase + invalid fallback).
|
||||
#
|
||||
# ``registered_tool_names`` MUST cover every tool the model can legitimately
|
||||
# call. That includes the bound ``tools`` list AND every tool provided by
|
||||
# middleware in the stack — ``FilesystemMiddleware`` (read_file, ls, grep,
|
||||
# glob, edit_file, write_file, execute), ``TodoListMiddleware``
|
||||
# (write_todos), ``SubAgentMiddleware`` (task), ``SkillsMiddleware`` (skill
|
||||
# loaders), etc. If we only inspect ``tools`` here, every call to
|
||||
# ``read_file`` / ``ls`` / ``grep`` from the model will be rewritten to
|
||||
# ``invalid`` because the repair middleware doesn't recognize them. The
|
||||
# built-in deepagents middleware aren't in scope yet at this point of the
|
||||
# function but they're added unconditionally below, so we hard-code their
|
||||
# canonical names alongside the dynamic ``tools`` set.
|
||||
repair_mw = None
|
||||
if flags.enable_tool_call_repair and not flags.disable_new_agent_stack:
|
||||
registered_names: set[str] = {t.name for t in tools}
|
||||
# Tools owned by the standard deepagents middleware stack.
|
||||
registered_names |= {
|
||||
"write_todos",
|
||||
"ls",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"glob",
|
||||
"grep",
|
||||
"execute",
|
||||
"task",
|
||||
}
|
||||
repair_mw = ToolCallNameRepairMiddleware(
|
||||
registered_tool_names=registered_names,
|
||||
fuzzy_match_threshold=None, # opencode parity: no fuzzy step
|
||||
)
|
||||
|
||||
# Tier 1.11: doom-loop detector. Off by default until UI handles.
|
||||
doom_loop_mw = (
|
||||
DoomLoopMiddleware(threshold=3)
|
||||
if flags.enable_doom_loop and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
# Tier 2.1: PermissionMiddleware. Layers, earliest -> latest (last
|
||||
# match wins per opencode):
|
||||
#
|
||||
# 1. ``surfsense_defaults`` — single ``allow */*`` rule. SurfSense
|
||||
# already runs per-tool HITL (see ``tools/hitl.py``) for mutating
|
||||
# connector tools, so we only want PermissionMiddleware to *deny*
|
||||
# things the user has gated off; the default fallback in
|
||||
# ``permissions.evaluate`` is ``ask``, which would double-prompt
|
||||
# on every safe read-only call (``ls``, ``read_file``, ``grep``,
|
||||
# ``glob``, ``web_search`` …) and, on resume, replay the previous
|
||||
# reject decision into innocent calls.
|
||||
# 2. ``connector_synthesized`` — deny rules for tools whose required
|
||||
# connector is not connected to this space. Overrides #1.
|
||||
# 3. (future) user-defined rules from ``agent_permission_rules`` table
|
||||
# via the Agent Permissions UI. Loaded last so they override both.
|
||||
permission_mw: PermissionMiddleware | None = None
|
||||
if flags.enable_permission and not flags.disable_new_agent_stack:
|
||||
synthesized = _synthesize_connector_deny_rules(
|
||||
available_connectors=available_connectors,
|
||||
enabled_tool_names={t.name for t in tools},
|
||||
)
|
||||
permission_mw = PermissionMiddleware(
|
||||
rulesets=[
|
||||
Ruleset(
|
||||
rules=[Rule(permission="*", pattern="*", action="allow")],
|
||||
origin="surfsense_defaults",
|
||||
),
|
||||
Ruleset(rules=synthesized, origin="connector_synthesized"),
|
||||
],
|
||||
)
|
||||
|
||||
# Tier 5.2: ActionLogMiddleware. Off by default until the
|
||||
# ``agent_action_log`` table is migrated. When enabled, persists one
|
||||
# row per tool call with optional reverse_descriptor for
|
||||
# /api/threads/{thread_id}/revert/{action_id}. Sits inside permission
|
||||
# so denied calls aren't logged as completions.
|
||||
action_log_mw: ActionLogMiddleware | None = None
|
||||
if (
|
||||
flags.enable_action_log
|
||||
and not flags.disable_new_agent_stack
|
||||
and thread_id is not None
|
||||
):
|
||||
try:
|
||||
tool_defs_by_name = {td.name: td for td in BUILTIN_TOOLS}
|
||||
action_log_mw = ActionLogMiddleware(
|
||||
thread_id=thread_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
tool_definitions=tool_defs_by_name,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
"ActionLogMiddleware init failed; running without it.",
|
||||
exc_info=True,
|
||||
)
|
||||
action_log_mw = None
|
||||
|
||||
# Tier 2.2: per-thread busy mutex.
|
||||
busy_mutex_mw: BusyMutexMiddleware | None = (
|
||||
BusyMutexMiddleware()
|
||||
if flags.enable_busy_mutex and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
# Tier 3b: OpenTelemetry spans (model.call + tool.call). Lives just
|
||||
# inside BusyMutex so it spans every retry/fallback attempt of the
|
||||
# current turn but never wraps a queued/blocked turn.
|
||||
otel_mw: OtelSpanMiddleware | None = (
|
||||
OtelSpanMiddleware()
|
||||
if flags.enable_otel and not flags.disable_new_agent_stack
|
||||
else None
|
||||
)
|
||||
|
||||
# Tier 6: plugin entry-point loader. Off by default; opt-in via the
|
||||
# ``SURFSENSE_ENABLE_PLUGIN_LOADER`` flag. The allowlist is read from
|
||||
# the ``SURFSENSE_ALLOWED_PLUGINS`` env var (comma-separated). A future
|
||||
# PR can wire it through ``global_llm_config.yaml``.
|
||||
plugin_middlewares: list[Any] = []
|
||||
if flags.enable_plugin_loader and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
allowed_names = load_allowed_plugin_names_from_env()
|
||||
if allowed_names:
|
||||
plugin_middlewares = load_plugin_middlewares(
|
||||
PluginContext.build(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_visibility=visibility,
|
||||
llm=llm,
|
||||
),
|
||||
allowed_plugin_names=allowed_names,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logging.warning(
|
||||
"Plugin loader failed; continuing without plugins.",
|
||||
exc_info=True,
|
||||
)
|
||||
plugin_middlewares = []
|
||||
|
||||
# Tier 4.1: SkillsMiddleware. Loads built-in + space-authored skills
|
||||
# via a CompositeBackend. Sources are layered: built-in first, space
|
||||
# last, so a search-space-authored skill of the same name overrides
|
||||
# the bundled one.
|
||||
skills_mw: SkillsMiddleware | None = None
|
||||
if flags.enable_skills and not flags.disable_new_agent_stack:
|
||||
try:
|
||||
skills_factory = build_skills_backend_factory(
|
||||
search_space_id=search_space_id
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
)
|
||||
skills_mw = SkillsMiddleware(
|
||||
backend=skills_factory,
|
||||
sources=default_skills_sources(),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.warning("SkillsMiddleware init failed; skipping: %s", exc)
|
||||
skills_mw = None
|
||||
|
||||
# Tier 2.5: LLM-driven tool selection for >30 tools.
|
||||
selector_mw: LLMToolSelectorMiddleware | None = None
|
||||
if (
|
||||
flags.enable_llm_tool_selector
|
||||
and not flags.disable_new_agent_stack
|
||||
and len(tools) > 30
|
||||
):
|
||||
try:
|
||||
selector_mw = LLMToolSelectorMiddleware(
|
||||
model="openai:gpt-4o-mini",
|
||||
max_tools=12,
|
||||
always_include=[
|
||||
name
|
||||
for name in ("update_memory", "get_connected_accounts", "scrape_webpage")
|
||||
if name in {t.name for t in tools}
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logging.warning("LLMToolSelectorMiddleware init failed; skipping.")
|
||||
selector_mw = None
|
||||
|
||||
deepagent_middleware = [
|
||||
# BusyMutex is OUTERMOST: it must wrap the entire stream so no
|
||||
# other turn can sneak in while this one is mid-flight.
|
||||
busy_mutex_mw,
|
||||
# OTel spans sit just inside BusyMutex so each retry attempt
|
||||
# gets its own model.call / tool.call span.
|
||||
otel_mw,
|
||||
TodoListMiddleware(),
|
||||
_memory_middleware,
|
||||
AnonymousDocumentMiddleware(
|
||||
|
|
@ -514,10 +931,40 @@ def _build_compiled_agent_blocking(
|
|||
)
|
||||
if filesystem_mode == FilesystemMode.CLOUD
|
||||
else None,
|
||||
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
||||
create_safe_summarization_middleware(llm, StateBackend),
|
||||
# Tier 4.1: skill loader. Placed before SubAgentMiddleware so
|
||||
# subagents inherit the same skill metadata (subagent specs reference
|
||||
# the same source paths via `default_skills_sources()`).
|
||||
skills_mw,
|
||||
SubAgentMiddleware(backend=StateBackend, subagents=subagent_specs),
|
||||
# Tier 2.5: tool selection (only when >30 tools and flag on).
|
||||
selector_mw,
|
||||
# Defensive caps, then prune, then summarize.
|
||||
model_call_limit_mw,
|
||||
tool_call_limit_mw,
|
||||
context_edit_mw,
|
||||
summarization_mw,
|
||||
# Provider compatibility + retry chain — placed after prune/compact
|
||||
# so retries happen on the already-trimmed payload.
|
||||
noop_mw,
|
||||
retry_mw,
|
||||
fallback_mw,
|
||||
# Tool-call repair must run after model emits but before
|
||||
# permission / dedup / doom-loop interpret the calls.
|
||||
repair_mw,
|
||||
# Tier 2.1: deny/ask BEFORE the calls are forwarded to tool nodes.
|
||||
permission_mw,
|
||||
doom_loop_mw,
|
||||
# Tier 5.2: action log sits inside permission so denied calls
|
||||
# don't appear as completions, and outside dedup so each unique
|
||||
# tool invocation gets its own row.
|
||||
action_log_mw,
|
||||
PatchToolCallsMiddleware(),
|
||||
DedupHITLToolCallsMiddleware(agent_tools=list(tools)),
|
||||
# Tier 6: plugin slot — sits just before AnthropicCache so plugin-side
|
||||
# transforms see the final tool result and run before any caching
|
||||
# heuristics. Multiple plugins in declared order; loader filtered by
|
||||
# the admin allowlist already.
|
||||
*plugin_middlewares,
|
||||
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||
]
|
||||
deepagent_middleware = [m for m in deepagent_middleware if m is not None]
|
||||
|
|
|
|||
95
surfsense_backend/app/agents/new_chat/errors.py
Normal file
95
surfsense_backend/app/agents/new_chat/errors.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
Typed error taxonomy for the SurfSense agent stack.
|
||||
|
||||
Used by:
|
||||
- :class:`RetryAfterMiddleware` (Tier 1.4) — its ``retry_on`` callable
|
||||
consults the error code to decide whether a retry is appropriate.
|
||||
- :class:`PermissionMiddleware` (Tier 2.1) — emits
|
||||
``code="permission_denied"`` errors when a deny rule trips.
|
||||
- All tools — return :class:`StreamingError` payloads in
|
||||
``ToolMessage.additional_kwargs["error"]`` so the model and the
|
||||
retry/permission layers share a contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
ErrorCode = Literal[
|
||||
"rate_limit",
|
||||
"auth",
|
||||
"tool_validation",
|
||||
"tool_runtime",
|
||||
"context_overflow",
|
||||
"provider",
|
||||
"permission_denied",
|
||||
"doom_loop",
|
||||
"busy",
|
||||
"cancelled",
|
||||
]
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
"""Structured error payload attached to ``ToolMessage.additional_kwargs["error"]``.
|
||||
|
||||
Tools and middleware emit this so retry, permission, and routing
|
||||
layers can decide what to do without parsing free-form strings.
|
||||
"""
|
||||
|
||||
code: ErrorCode
|
||||
retryable: bool = False
|
||||
suggestion: str | None = None
|
||||
correlation_id: str | None = None
|
||||
detail: str | None = Field(
|
||||
default=None,
|
||||
description="Free-form additional context. Not surfaced to the model.",
|
||||
)
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
|
||||
class RejectedError(Exception):
|
||||
"""Raised when the user rejects a permission ask without feedback.
|
||||
|
||||
Caught by :class:`PermissionMiddleware`; the agent stops the current
|
||||
tool fan-out and surfaces a user-facing rejection.
|
||||
"""
|
||||
|
||||
def __init__(self, *, tool: str | None = None, pattern: str | None = None) -> None:
|
||||
super().__init__(f"Permission rejected for tool {tool!r}, pattern {pattern!r}")
|
||||
self.tool = tool
|
||||
self.pattern = pattern
|
||||
|
||||
|
||||
class CorrectedError(Exception):
|
||||
"""Raised when the user rejects a permission ask *with* feedback.
|
||||
|
||||
The :class:`PermissionMiddleware` translates the feedback into a
|
||||
synthetic ``ToolMessage`` so the model sees the user's correction
|
||||
and can retry the request differently.
|
||||
"""
|
||||
|
||||
def __init__(self, feedback: str, *, tool: str | None = None) -> None:
|
||||
super().__init__(feedback)
|
||||
self.feedback = feedback
|
||||
self.tool = tool
|
||||
|
||||
|
||||
class BusyError(Exception):
|
||||
"""Raised when a second prompt arrives while the same thread is mid-stream."""
|
||||
|
||||
def __init__(self, request_id: str | None = None) -> None:
|
||||
super().__init__("Thread is busy with another request")
|
||||
self.request_id = request_id
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BusyError",
|
||||
"CorrectedError",
|
||||
"ErrorCode",
|
||||
"RejectedError",
|
||||
"StreamingError",
|
||||
]
|
||||
188
surfsense_backend/app/agents/new_chat/feature_flags.py
Normal file
188
surfsense_backend/app/agents/new_chat/feature_flags.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
"""
|
||||
Feature flags for the SurfSense new_chat agent stack.
|
||||
|
||||
These flags control rollout of OpenCode-pattern middleware ported into
|
||||
SurfSense. They follow a "default-OFF for risky things, default-ON for
|
||||
safe upgrades, master kill-switch for everything new" model.
|
||||
|
||||
All new middleware checks its flag at agent build time. If the master
|
||||
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
|
||||
middleware is disabled regardless of its individual flag. This gives
|
||||
operators a single switch to revert to pre-port behavior.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Local development (recommended for trying everything except doom-loop / selector):
|
||||
|
||||
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
||||
SURFSENSE_ENABLE_COMPACTION_V2=true
|
||||
SURFSENSE_ENABLE_RETRY_AFTER=true
|
||||
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
||||
SURFSENSE_ENABLE_PERMISSION=false # default off, opt-in per deploy
|
||||
SURFSENSE_ENABLE_DOOM_LOOP=false # default off until UI ships
|
||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false
|
||||
|
||||
Master kill-switch (overrides everything else):
|
||||
|
||||
SURFSENSE_DISABLE_NEW_AGENT_STACK=true
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _env_bool(name: str, default: bool) -> bool:
|
||||
"""Parse a boolean env var. Accepts ``1``/``true``/``yes``/``on`` (case-insensitive)."""
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return default
|
||||
return raw.strip().lower() in ("1", "true", "yes", "on")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentFeatureFlags:
|
||||
"""Resolved feature-flag state for one agent build.
|
||||
|
||||
Constructed via :meth:`from_env`. The dataclass is frozen so it can be
|
||||
safely shared across coroutines.
|
||||
"""
|
||||
|
||||
# Master kill-switch — when true, every flag below resolves to False
|
||||
# regardless of its env value. Used for rapid rollback.
|
||||
disable_new_agent_stack: bool = False
|
||||
|
||||
# Tier 1 — Agent quality
|
||||
enable_context_editing: bool = False
|
||||
enable_compaction_v2: bool = False
|
||||
enable_retry_after: bool = False
|
||||
enable_model_fallback: bool = False
|
||||
enable_model_call_limit: bool = False
|
||||
enable_tool_call_limit: bool = False
|
||||
enable_tool_call_repair: bool = False
|
||||
enable_doom_loop: bool = False # Default OFF until UI handles permission='doom_loop'
|
||||
|
||||
# Tier 2 — Safety
|
||||
enable_permission: bool = False # Default OFF for first deploy
|
||||
enable_busy_mutex: bool = False
|
||||
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
|
||||
|
||||
# Tier 4 — Skills + subagents
|
||||
enable_skills: bool = False
|
||||
enable_specialized_subagents: bool = False
|
||||
enable_kb_planner_runnable: bool = False
|
||||
|
||||
# Tier 5 — Snapshot / revert
|
||||
enable_action_log: bool = False
|
||||
enable_revert_route: bool = False # Backend ships before UI; route returns 503 until this flips
|
||||
|
||||
# Tier 6 — Plugins
|
||||
enable_plugin_loader: bool = False
|
||||
|
||||
# Tier 3b — OTel (orthogonal: also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
||||
enable_otel: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> AgentFeatureFlags:
|
||||
"""Read flags from environment.
|
||||
|
||||
Master kill-switch is evaluated first; when set, all other flags
|
||||
force to False.
|
||||
"""
|
||||
master_off = _env_bool("SURFSENSE_DISABLE_NEW_AGENT_STACK", False)
|
||||
if master_off:
|
||||
logger.info(
|
||||
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
|
||||
"middleware is forced OFF for this build."
|
||||
)
|
||||
return cls(disable_new_agent_stack=True)
|
||||
|
||||
return cls(
|
||||
disable_new_agent_stack=False,
|
||||
# Tier 1
|
||||
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", False),
|
||||
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", False),
|
||||
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", False),
|
||||
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
||||
enable_model_call_limit=_env_bool("SURFSENSE_ENABLE_MODEL_CALL_LIMIT", False),
|
||||
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", False),
|
||||
enable_tool_call_repair=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_REPAIR", False),
|
||||
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", False),
|
||||
# Tier 2
|
||||
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", False),
|
||||
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", False),
|
||||
enable_llm_tool_selector=_env_bool("SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False),
|
||||
# Tier 4
|
||||
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", False),
|
||||
enable_specialized_subagents=_env_bool(
|
||||
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", False
|
||||
),
|
||||
enable_kb_planner_runnable=_env_bool(
|
||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", False
|
||||
),
|
||||
# Tier 5
|
||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", False),
|
||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", False),
|
||||
# Tier 6
|
||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||
# Tier 3b
|
||||
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
|
||||
)
|
||||
|
||||
def any_new_middleware_enabled(self) -> bool:
|
||||
"""Return True if any new middleware flag is on."""
|
||||
if self.disable_new_agent_stack:
|
||||
return False
|
||||
return any(
|
||||
(
|
||||
self.enable_context_editing,
|
||||
self.enable_compaction_v2,
|
||||
self.enable_retry_after,
|
||||
self.enable_model_fallback,
|
||||
self.enable_model_call_limit,
|
||||
self.enable_tool_call_limit,
|
||||
self.enable_tool_call_repair,
|
||||
self.enable_doom_loop,
|
||||
self.enable_permission,
|
||||
self.enable_busy_mutex,
|
||||
self.enable_llm_tool_selector,
|
||||
self.enable_skills,
|
||||
self.enable_specialized_subagents,
|
||||
self.enable_kb_planner_runnable,
|
||||
self.enable_action_log,
|
||||
self.enable_revert_route,
|
||||
self.enable_plugin_loader,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Module-level cache. Read once at import time so the values are consistent
|
||||
# across the process lifetime. Use ``reload_for_tests`` to reset in tests.
|
||||
_FLAGS: AgentFeatureFlags | None = None
|
||||
|
||||
|
||||
def get_flags() -> AgentFeatureFlags:
|
||||
"""Return the resolved feature-flag state, caching on first call."""
|
||||
global _FLAGS
|
||||
if _FLAGS is None:
|
||||
_FLAGS = AgentFeatureFlags.from_env()
|
||||
return _FLAGS
|
||||
|
||||
|
||||
def reload_for_tests() -> AgentFeatureFlags:
|
||||
"""Force a fresh read from env. Tests should call this after monkeypatching env."""
|
||||
global _FLAGS
|
||||
_FLAGS = AgentFeatureFlags.from_env()
|
||||
return _FLAGS
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentFeatureFlags",
|
||||
"get_flags",
|
||||
"reload_for_tests",
|
||||
]
|
||||
|
|
@ -1,11 +1,23 @@
|
|||
"""Middleware components for the SurfSense new chat agent."""
|
||||
|
||||
from app.agents.new_chat.middleware.action_log import ActionLogMiddleware
|
||||
from app.agents.new_chat.middleware.anonymous_document import (
|
||||
AnonymousDocumentMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.busy_mutex import BusyMutexMiddleware
|
||||
from app.agents.new_chat.middleware.compaction import (
|
||||
SurfSenseCompactionMiddleware,
|
||||
create_surfsense_compaction_middleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
SpillingContextEditingMiddleware,
|
||||
SpillToBackendEdit,
|
||||
)
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||
DedupHITLToolCallsMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.doom_loop import DoomLoopMiddleware
|
||||
from app.agents.new_chat.middleware.file_intent import (
|
||||
FileIntentMiddleware,
|
||||
)
|
||||
|
|
@ -26,16 +38,46 @@ from app.agents.new_chat.middleware.knowledge_tree import (
|
|||
from app.agents.new_chat.middleware.memory_injection import (
|
||||
MemoryInjectionMiddleware,
|
||||
)
|
||||
from app.agents.new_chat.middleware.noop_injection import NoopInjectionMiddleware
|
||||
from app.agents.new_chat.middleware.otel_span import OtelSpanMiddleware
|
||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
|
||||
from app.agents.new_chat.middleware.skills_backends import (
|
||||
BuiltinSkillsBackend,
|
||||
SearchSpaceSkillsBackend,
|
||||
build_skills_backend_factory,
|
||||
default_skills_sources,
|
||||
)
|
||||
from app.agents.new_chat.middleware.tool_call_repair import (
|
||||
ToolCallNameRepairMiddleware,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActionLogMiddleware",
|
||||
"AnonymousDocumentMiddleware",
|
||||
"BuiltinSkillsBackend",
|
||||
"BusyMutexMiddleware",
|
||||
"ClearToolUsesEdit",
|
||||
"DedupHITLToolCallsMiddleware",
|
||||
"DoomLoopMiddleware",
|
||||
"FileIntentMiddleware",
|
||||
"KnowledgeBasePersistenceMiddleware",
|
||||
"KnowledgeBaseSearchMiddleware",
|
||||
"KnowledgePriorityMiddleware",
|
||||
"KnowledgeTreeMiddleware",
|
||||
"MemoryInjectionMiddleware",
|
||||
"NoopInjectionMiddleware",
|
||||
"OtelSpanMiddleware",
|
||||
"PermissionMiddleware",
|
||||
"RetryAfterMiddleware",
|
||||
"SearchSpaceSkillsBackend",
|
||||
"SpillToBackendEdit",
|
||||
"SpillingContextEditingMiddleware",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
"SurfSenseFilesystemMiddleware",
|
||||
"ToolCallNameRepairMiddleware",
|
||||
"build_skills_backend_factory",
|
||||
"commit_staged_filesystem_state",
|
||||
"create_surfsense_compaction_middleware",
|
||||
"default_skills_sources",
|
||||
]
|
||||
|
|
|
|||
294
surfsense_backend/app/agents/new_chat/middleware/action_log.py
Normal file
294
surfsense_backend/app/agents/new_chat/middleware/action_log.py
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
"""Append-only action-log middleware for the SurfSense agent.
|
||||
|
||||
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
|
||||
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
|
||||
into reversibility by declaring a ``reverse`` callable on their
|
||||
:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered
|
||||
descriptor is persisted in ``reverse_descriptor`` for use by
|
||||
``/api/threads/{thread_id}/revert/{action_id}``.
|
||||
|
||||
Design points:
|
||||
|
||||
* **Defensive.** Logging never blocks the agent. We catch every exception
|
||||
on the DB write path and emit a warning; the tool's ``ToolMessage``
|
||||
result is always returned untouched.
|
||||
* **Lightweight payload.** Only the tool ``name`` + ``args`` (capped) +
|
||||
``result_id`` + ``reverse_descriptor`` are stored. Tool output text
|
||||
remains in the LangGraph checkpoint / spilled tool-output files.
|
||||
* **Best-effort reversibility.** We invoke ``reverse(args, result_obj)``
|
||||
with the parsed JSON result when the tool's content is a JSON object;
|
||||
otherwise the raw text is passed. Exceptions in the reverse callable
|
||||
are swallowed and logged — a failed descriptor render simply means the
|
||||
action is NOT marked reversible.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.agents.new_chat.tools.registry import ToolDefinition
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Cap for the persisted ``args`` JSON to avoid bloating the action log with
|
||||
# accidentally-huge inputs. Values are truncated and a flag is set in the
|
||||
# stored payload so consumers can detect truncation.
|
||||
_MAX_ARGS_PERSIST_BYTES = 32 * 1024 # 32KB
|
||||
|
||||
|
||||
class ActionLogMiddleware(AgentMiddleware):
|
||||
"""Persist a row in :class:`AgentActionLog` after every tool call.
|
||||
|
||||
Should be placed near the OUTERMOST end of the tool-call wrapping stack
|
||||
so that it sees the *final* :class:`ToolMessage` after all retries,
|
||||
permission checks, and dedup logic have run. In practice that means
|
||||
placing it just inside :class:`PermissionMiddleware` and outside
|
||||
:class:`DedupHITLToolCallsMiddleware`.
|
||||
|
||||
The middleware is fully a no-op when:
|
||||
|
||||
* the master kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set
|
||||
(checked via :func:`get_flags`),
|
||||
* the per-feature flag ``enable_action_log`` is off, or
|
||||
* persistence raises (defensive: tool-call dispatch always succeeds).
|
||||
|
||||
Args:
|
||||
thread_id: The current chat thread's primary-key id. Required to
|
||||
persist a row; if ``None`` the middleware silently no-ops.
|
||||
search_space_id: Search-space id for cascade-on-delete safety.
|
||||
user_id: UUID string of the user driving this turn (nullable in
|
||||
anonymous mode).
|
||||
tool_definitions: Optional mapping of tool name -> :class:`ToolDefinition`
|
||||
so the middleware can look up the tool's ``reverse`` callable.
|
||||
When omitted, no actions are marked reversible.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
thread_id: int | None,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
tool_definitions: dict[str, ToolDefinition] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._thread_id = thread_id
|
||||
self._search_space_id = search_space_id
|
||||
self._user_id = user_id
|
||||
self._tool_definitions = dict(tool_definitions or {})
|
||||
|
||||
def _enabled(self) -> bool:
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack:
|
||||
return False
|
||||
return bool(flags.enable_action_log) and self._thread_id is not None
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[
|
||||
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
|
||||
],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not self._enabled():
|
||||
return await handler(request)
|
||||
|
||||
result: ToolMessage | Command[Any]
|
||||
error_payload: dict[str, Any] | None = None
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception as exc:
|
||||
# Persist the failure too so revert/audit can see it, then
|
||||
# re-raise so downstream middleware (RetryAfter, etc.) handles it.
|
||||
error_payload = {"type": type(exc).__name__, "message": str(exc)}
|
||||
await self._record(
|
||||
request=request,
|
||||
result=None,
|
||||
error_payload=error_payload,
|
||||
)
|
||||
raise
|
||||
|
||||
await self._record(request=request, result=result, error_payload=None)
|
||||
return result
|
||||
|
||||
async def _record(
|
||||
self,
|
||||
*,
|
||||
request: ToolCallRequest,
|
||||
result: ToolMessage | Command[Any] | None,
|
||||
error_payload: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Persist one ``agent_action_log`` row. Defensive: never raises."""
|
||||
try:
|
||||
from app.db import AgentActionLog, shielded_async_session
|
||||
|
||||
tool_name = _resolve_tool_name(request)
|
||||
args_payload = _resolve_args_payload(request)
|
||||
result_id = _resolve_result_id(result)
|
||||
reverse_descriptor, reversible = self._render_reverse(
|
||||
tool_name=tool_name,
|
||||
args=_resolve_args_dict(request),
|
||||
result=result,
|
||||
)
|
||||
|
||||
row = AgentActionLog(
|
||||
thread_id=self._thread_id,
|
||||
user_id=self._user_id,
|
||||
search_space_id=self._search_space_id,
|
||||
turn_id=_resolve_turn_id(request),
|
||||
message_id=_resolve_message_id(request),
|
||||
tool_name=tool_name,
|
||||
args=args_payload,
|
||||
result_id=result_id,
|
||||
reversible=reversible,
|
||||
reverse_descriptor=reverse_descriptor,
|
||||
error=error_payload,
|
||||
)
|
||||
async with shielded_async_session() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"ActionLogMiddleware failed to persist action log row",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _render_reverse(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any] | None,
|
||||
result: ToolMessage | Command[Any] | None,
|
||||
) -> tuple[dict[str, Any] | None, bool]:
|
||||
"""Run the tool's ``reverse`` callable and return its descriptor.
|
||||
|
||||
Returns a tuple of ``(descriptor_or_None, reversible_bool)``. When
|
||||
the tool has no ``reverse`` callable, or when the callable raises,
|
||||
the action is marked non-reversible.
|
||||
"""
|
||||
if not result or not isinstance(result, ToolMessage):
|
||||
return None, False
|
||||
if args is None:
|
||||
return None, False
|
||||
tool_def = self._tool_definitions.get(tool_name)
|
||||
if tool_def is None or tool_def.reverse is None:
|
||||
return None, False
|
||||
try:
|
||||
parsed_result = _parse_tool_result_content(result)
|
||||
descriptor = tool_def.reverse(args, parsed_result)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Reverse descriptor render failed for tool %s",
|
||||
tool_name,
|
||||
exc_info=True,
|
||||
)
|
||||
return None, False
|
||||
if not isinstance(descriptor, dict):
|
||||
return None, False
|
||||
return descriptor, True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolution helpers — defensive against tool_call request shape variation.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_tool_name(request: Any) -> str:
|
||||
try:
|
||||
tool = getattr(request, "tool", None)
|
||||
if tool is not None:
|
||||
name = getattr(tool, "name", None)
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
if isinstance(call, dict):
|
||||
name = call.get("name")
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
except Exception: # pragma: no cover - defensive
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _resolve_args_dict(request: Any) -> dict[str, Any] | None:
|
||||
try:
|
||||
call = getattr(request, "tool_call", None)
|
||||
if not isinstance(call, dict):
|
||||
return None
|
||||
args = call.get("args")
|
||||
if isinstance(args, dict):
|
||||
return args
|
||||
return None
|
||||
except Exception: # pragma: no cover - defensive
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_args_payload(request: Any) -> dict[str, Any] | None:
|
||||
"""Return a JSON-serializable args dict, truncated if too big."""
|
||||
args = _resolve_args_dict(request)
|
||||
if args is None:
|
||||
return None
|
||||
try:
|
||||
encoded = json.dumps(args, default=str)
|
||||
except Exception:
|
||||
return {"_repr": repr(args)[:_MAX_ARGS_PERSIST_BYTES]}
|
||||
if len(encoded) <= _MAX_ARGS_PERSIST_BYTES:
|
||||
return args
|
||||
return {
|
||||
"_truncated": True,
|
||||
"_size": len(encoded),
|
||||
"_preview": encoded[:_MAX_ARGS_PERSIST_BYTES],
|
||||
}
|
||||
|
||||
|
||||
def _resolve_turn_id(request: Any) -> str | None:
|
||||
try:
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
if isinstance(call, dict):
|
||||
tid = call.get("id")
|
||||
if isinstance(tid, str):
|
||||
return tid
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_message_id(request: Any) -> str | None:
|
||||
"""Tool-call IDs serve as best-available message correlator at this layer."""
|
||||
return _resolve_turn_id(request)
|
||||
|
||||
|
||||
def _resolve_result_id(result: Any) -> str | None:
|
||||
if isinstance(result, ToolMessage):
|
||||
msg_id = getattr(result, "id", None)
|
||||
if isinstance(msg_id, str):
|
||||
return msg_id
|
||||
return None
|
||||
|
||||
|
||||
def _parse_tool_result_content(result: ToolMessage) -> Any:
|
||||
content = result.content
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
return json.loads(content)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return content
|
||||
return content
|
||||
|
||||
|
||||
__all__ = ["ActionLogMiddleware"]
|
||||
231
surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
Normal file
231
surfsense_backend/app/agents/new_chat/middleware/busy_mutex.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
"""
|
||||
BusyMutexMiddleware — per-thread asyncio lock + cancel token.
|
||||
|
||||
Tier 2.2 in the OpenCode-port plan. Mirrors opencode's
|
||||
``Stream.scoped(AbortController)`` pattern (single-process, in-memory
|
||||
lock + cooperative cancellation). For multi-worker deployments a
|
||||
distributed lock backend (Redis or PostgreSQL advisory locks) is a
|
||||
phase-2 follow-up.
|
||||
|
||||
What this provides:
|
||||
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
|
||||
acquiring the lock during ``before_agent`` blocks any concurrent
|
||||
prompt on the same thread until release.
|
||||
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
|
||||
tools can poll to abort cooperatively. The event is reset between
|
||||
turns. Tools should check ``runtime.context.cancel_event.is_set()``
|
||||
in tight inner loops.
|
||||
- A typed :class:`~app.agents.new_chat.errors.BusyError` raised when a
|
||||
second turn arrives while the lock is held.
|
||||
|
||||
Note: SurfSense's ``stream_new_chat`` is the call site that should
|
||||
acquire/release. Wiring this as middleware means the contract is
|
||||
explicit and the lock manager is shared with subagents that compile
|
||||
their own ``create_agent`` runnables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.new_chat.errors import BusyError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _ThreadLockManager:
|
||||
"""Process-local registry of per-thread asyncio locks + cancel events."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
|
||||
weakref.WeakValueDictionary()
|
||||
)
|
||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._locks[thread_id] = lock
|
||||
return lock
|
||||
|
||||
def cancel_event(self, thread_id: str) -> asyncio.Event:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is None:
|
||||
event = asyncio.Event()
|
||||
self._cancel_events[thread_id] = event
|
||||
return event
|
||||
|
||||
def request_cancel(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is None:
|
||||
return False
|
||||
event.set()
|
||||
return True
|
||||
|
||||
def reset(self, thread_id: str) -> None:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is not None:
|
||||
event.clear()
|
||||
|
||||
|
||||
# Module-level singleton — process-local but reused across all agent
|
||||
# instances built in this process. Subagents created in nested
|
||||
# ``create_agent`` calls also get this so locks are coherent.
|
||||
manager = _ThreadLockManager()
|
||||
|
||||
|
||||
def get_cancel_event(thread_id: str) -> asyncio.Event:
|
||||
"""Public accessor used by long-running tools to poll cancellation."""
|
||||
return manager.cancel_event(thread_id)
|
||||
|
||||
|
||||
def request_cancel(thread_id: str) -> bool:
|
||||
"""Trip the cancel event for ``thread_id``. Returns True if found."""
|
||||
return manager.request_cancel(thread_id)
|
||||
|
||||
|
||||
def reset_cancel(thread_id: str) -> None:
|
||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||
manager.reset(thread_id)
|
||||
|
||||
|
||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Block concurrent prompts on the same thread.
|
||||
|
||||
Acquires the thread's lock in ``abefore_agent`` and releases in
|
||||
``aafter_agent``. If the lock is held, raises :class:`BusyError`
|
||||
so the caller can emit a ``surfsense.busy`` SSE event with the
|
||||
in-flight request id.
|
||||
|
||||
Args:
|
||||
require_thread_id: When True, raise :class:`BusyError` if no
|
||||
``thread_id`` can be resolved from the active
|
||||
``RunnableConfig``. Default is False — we treat a missing
|
||||
thread_id as "this turn has nothing to lock against" and
|
||||
no-op the mutex. Set True only when you trust the call
|
||||
site to always provide ``configurable.thread_id`` (e.g.
|
||||
in production where ``stream_new_chat`` always does).
|
||||
"""
|
||||
|
||||
def __init__(self, *, require_thread_id: bool = False) -> None:
|
||||
super().__init__()
|
||||
self._require_thread_id = require_thread_id
|
||||
self.tools = []
|
||||
# Per-call locks owned by this middleware. We track them as
|
||||
# an instance attribute so ``aafter_agent`` knows which lock
|
||||
# to release.
|
||||
self._held_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||
"""Extract ``thread_id`` from the active LangGraph ``RunnableConfig``.
|
||||
|
||||
``langgraph.runtime.Runtime`` deliberately does NOT expose ``config``.
|
||||
The runnable config (where ``configurable.thread_id`` lives) must be
|
||||
fetched via :func:`langgraph.config.get_config` from inside a node /
|
||||
middleware. We fall back to ``getattr(runtime, "config", None)`` for
|
||||
unit tests / legacy runtimes that synthesize a config-bearing stub.
|
||||
"""
|
||||
|
||||
def _from_dict(cfg: Any) -> str | None:
|
||||
if not isinstance(cfg, dict):
|
||||
return None
|
||||
tid = (cfg.get("configurable") or {}).get("thread_id")
|
||||
return str(tid) if tid is not None else None
|
||||
|
||||
# Preferred path: real LangGraph runtime context.
|
||||
try:
|
||||
tid = _from_dict(get_config())
|
||||
except Exception:
|
||||
tid = None
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
# Fallback for tests and any runtime that surfaces a config dict
|
||||
# directly on the runtime instance.
|
||||
return _from_dict(getattr(runtime, "config", None))
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[Any],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
del state
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
if self._require_thread_id:
|
||||
raise BusyError("no thread_id configured")
|
||||
logger.debug(
|
||||
"BusyMutexMiddleware: no thread_id resolved from RunnableConfig; "
|
||||
"skipping per-thread lock for this turn."
|
||||
)
|
||||
return None
|
||||
|
||||
lock = manager.lock_for(thread_id)
|
||||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
await lock.acquire()
|
||||
self._held_locks[thread_id] = lock
|
||||
# Reset the cancel event so this turn starts fresh
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
||||
async def aafter_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[Any],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
del state
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
return None
|
||||
lock = self._held_locks.pop(thread_id, None)
|
||||
if lock is not None and lock.locked():
|
||||
lock.release()
|
||||
# Always clear cancel event between turns so a stale signal
|
||||
# doesn't leak into the next request.
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
||||
# Provide sync no-ops because the middleware base class allows them
|
||||
def before_agent( # type: ignore[override]
|
||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
# Sync path: no asyncio.Lock to acquire. Best we can do is reject
|
||||
# if anyone else is in flight.
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
if self._require_thread_id:
|
||||
raise BusyError("no thread_id configured")
|
||||
return None
|
||||
lock = manager.lock_for(thread_id)
|
||||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
return None
|
||||
|
||||
def after_agent( # type: ignore[override]
|
||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"get_cancel_event",
|
||||
"manager",
|
||||
"request_cancel",
|
||||
"reset_cancel",
|
||||
]
|
||||
253
surfsense_backend/app/agents/new_chat/middleware/compaction.py
Normal file
253
surfsense_backend/app/agents/new_chat/middleware/compaction.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""
|
||||
SurfSense compaction middleware.
|
||||
|
||||
Subclasses :class:`deepagents.middleware.summarization.SummarizationMiddleware`
|
||||
to add SurfSense-specific behavior:
|
||||
|
||||
1. **Structured summary template** (OpenCode-style ``## Goal / Constraints /
|
||||
Progress / Key Decisions / Next Steps / Critical Context / Relevant Files``).
|
||||
2. **Protect SurfSense-specific SystemMessages** so injected hints
|
||||
(``<priority_documents>``, ``<workspace_tree>``, ``<file_operation_contract>``,
|
||||
``<user_memory>``, ``<team_memory>``, ``<user_name>``, ``<memory_warning>``)
|
||||
are *not* summarized away and are kept verbatim in the post-summary
|
||||
message list.
|
||||
3. **Sanitize ``content=None``** when feeding messages into ``get_buffer_string``
|
||||
(Azure OpenAI / LiteLLM defense — when a provider streams an AIMessage
|
||||
containing only tool_calls and no text, ``content`` can be ``None`` and
|
||||
``get_buffer_string`` crashes iterating over ``None``). This used to live in
|
||||
``safe_summarization.py``; folded in here.
|
||||
|
||||
This replaces ``app.agents.new_chat.middleware.safe_summarization``.
|
||||
|
||||
Tier 1.3 in the OpenCode-port plan.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from deepagents.middleware.summarization import (
|
||||
SummarizationMiddleware,
|
||||
compute_summarization_defaults,
|
||||
)
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from app.observability import otel as ot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BACKEND_TYPES
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenCode-faithful structured summary template. Mirrors
|
||||
# ``opencode/packages/opencode/src/session/compaction.ts:40-75``. Kept as a
|
||||
# module-level constant so unit tests can assert on its sections.
|
||||
SURFSENSE_SUMMARY_PROMPT = """<role>
|
||||
SurfSense Conversation Compaction Assistant
|
||||
</role>
|
||||
|
||||
<primary_objective>
|
||||
Extract the most important context from the conversation history below into a structured summary that will replace the older messages.
|
||||
</primary_objective>
|
||||
|
||||
<instructions>
|
||||
You are running because the conversation has grown beyond the model's input window. The conversation history below will be summarized and replaced with your output. Use the structured template that follows; keep each section concise but comprehensive enough that the agent can resume work without losing context. Each section is a checklist — populate it with relevant content or write "None" if there is nothing to report.
|
||||
|
||||
## Goal
|
||||
What is the user's primary goal or request? State it in one or two sentences.
|
||||
|
||||
## Constraints
|
||||
What boundaries must the agent respect (citations rules, visibility scope, allowed tools, user-imposed style, deadlines, deny-listed topics)?
|
||||
|
||||
## Progress
|
||||
What has the agent already accomplished? List each completed step succinctly. Do not reproduce tool output; just record the conclusion.
|
||||
|
||||
## Key Decisions
|
||||
What choices were made and why? Include rejected alternatives and the reasoning behind selecting the current path.
|
||||
|
||||
## Next Steps
|
||||
What specific tasks remain to achieve the goal? Order them by dependency.
|
||||
|
||||
## Critical Context
|
||||
What facts, IDs, document titles, query keywords, error messages, or partial answers must persist into the next turn? Include verbatim quotes only when the exact wording matters (e.g. a precise filter clause or a literal name).
|
||||
|
||||
## Relevant Files
|
||||
What documents or paths in the SurfSense knowledge base are in play? Use ``/documents/...`` paths exactly as they appeared in the workspace tree.
|
||||
</instructions>
|
||||
|
||||
<messages>
|
||||
Messages to summarize:
|
||||
{messages}
|
||||
</messages>
|
||||
|
||||
Respond ONLY with the structured summary. Do not include any text before or after.
|
||||
"""
|
||||
|
||||
# SystemMessage prefixes that must NOT be summarized away. They are
|
||||
# re-injected on every turn by the corresponding middleware, but the
|
||||
# compaction step happens *before* re-injection in some paths, so we
|
||||
# must preserve them verbatim across the cutoff.
|
||||
PROTECTED_SYSTEM_PREFIXES: tuple[str, ...] = (
|
||||
"<priority_documents>", # KnowledgePriorityMiddleware
|
||||
"<workspace_tree>", # KnowledgeTreeMiddleware
|
||||
"<file_operation_contract>", # FileIntentMiddleware
|
||||
"<user_memory>", # MemoryInjectionMiddleware
|
||||
"<team_memory>", # MemoryInjectionMiddleware
|
||||
"<user_name>", # MemoryInjectionMiddleware
|
||||
"<memory_warning>", # MemoryInjectionMiddleware
|
||||
)
|
||||
|
||||
|
||||
def _is_protected_system_message(msg: AnyMessage) -> bool:
|
||||
"""Return True if ``msg`` is a SystemMessage we must not summarize."""
|
||||
if not isinstance(msg, SystemMessage):
|
||||
return False
|
||||
content = msg.content
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
stripped = content.lstrip()
|
||||
return any(stripped.startswith(prefix) for prefix in PROTECTED_SYSTEM_PREFIXES)
|
||||
|
||||
|
||||
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
||||
"""Return ``msg`` with ``content=None`` coerced to ``""``.
|
||||
|
||||
Folds in the historical defense from ``safe_summarization.py`` —
|
||||
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``,
|
||||
so a ``None`` content (Azure OpenAI / LiteLLM streaming a tool-only
|
||||
AIMessage) explodes. We return a copy with empty string content so
|
||||
downstream consumers see an empty body without mutating the original.
|
||||
"""
|
||||
if getattr(msg, "content", "not-missing") is not None:
|
||||
return msg
|
||||
try:
|
||||
return msg.model_copy(update={"content": ""})
|
||||
except AttributeError:
|
||||
import copy
|
||||
|
||||
new_msg = copy.copy(msg)
|
||||
try:
|
||||
new_msg.content = ""
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not sanitize content=None on message of type %s",
|
||||
type(msg).__name__,
|
||||
)
|
||||
return msg
|
||||
return new_msg
|
||||
|
||||
|
||||
class SurfSenseCompactionMiddleware(SummarizationMiddleware):
|
||||
"""SummarizationMiddleware tuned for SurfSense.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Overrides :meth:`_partition_messages` so protected SystemMessages
|
||||
survive into the ``preserved_messages`` half regardless of cutoff.
|
||||
- Overrides :meth:`_filter_summary_messages` so the buffer-string path
|
||||
never iterates ``None`` content.
|
||||
- Inherits everything else (auto-trigger, backend offload,
|
||||
``_summarization_event`` plumbing, ``ContextOverflowError`` fallback).
|
||||
"""
|
||||
|
||||
def _partition_messages( # type: ignore[override]
|
||||
self,
|
||||
conversation_messages: list[AnyMessage],
|
||||
cutoff_index: int,
|
||||
) -> tuple[list[AnyMessage], list[AnyMessage]]:
|
||||
"""Split messages but always preserve SurfSense protected SystemMessages.
|
||||
|
||||
Mirrors OpenCode's ``PRUNE_PROTECTED_TOOLS`` philosophy
|
||||
(``opencode/packages/opencode/src/session/compaction.ts``): some
|
||||
message types are always kept verbatim because they are part of the
|
||||
agent's working contract, not transient output.
|
||||
|
||||
Also opens a ``compaction.run`` OTel span (no-op when OTel is off)
|
||||
so dashboards can count compaction events and message-volume
|
||||
without having to instrument upstream callers.
|
||||
"""
|
||||
# Opening a span here is appropriate because partitioning is the
|
||||
# first call SummarizationMiddleware makes when it has decided to
|
||||
# summarize; we record the volume and then close as a normal span.
|
||||
with ot.compaction_span(
|
||||
reason="auto",
|
||||
messages_in=len(conversation_messages),
|
||||
extra={"compaction.cutoff_index": int(cutoff_index)},
|
||||
):
|
||||
messages_to_summarize, preserved_messages = (
|
||||
super()._partition_messages(conversation_messages, cutoff_index)
|
||||
)
|
||||
|
||||
protected: list[AnyMessage] = []
|
||||
kept_for_summary: list[AnyMessage] = []
|
||||
for msg in messages_to_summarize:
|
||||
if _is_protected_system_message(msg):
|
||||
protected.append(msg)
|
||||
else:
|
||||
kept_for_summary.append(msg)
|
||||
|
||||
# Place protected blocks at the *front* of preserved_messages so
|
||||
# they keep their original ordering relative to the summary
|
||||
# HumanMessage that precedes the rest of the preserved tail.
|
||||
return kept_for_summary, [*protected, *preserved_messages]
|
||||
|
||||
def _filter_summary_messages( # type: ignore[override]
|
||||
self, messages: list[AnyMessage]
|
||||
) -> list[AnyMessage]:
|
||||
"""Filter previous summaries AND sanitize ``content=None``.
|
||||
|
||||
Folds the ``safe_summarization.py`` defense in: when the buffer
|
||||
builder iterates ``m.text`` over ``None`` it explodes; sanitizing
|
||||
here covers both the sync and async offload paths.
|
||||
"""
|
||||
filtered = super()._filter_summary_messages(messages)
|
||||
return [_sanitize_message_content(m) for m in filtered]
|
||||
|
||||
|
||||
def create_surfsense_compaction_middleware(
|
||||
model: BaseChatModel,
|
||||
backend: BACKEND_TYPES,
|
||||
*,
|
||||
summary_prompt: str | None = None,
|
||||
history_path_prefix: str = "/conversation_history",
|
||||
**overrides: Any,
|
||||
) -> SurfSenseCompactionMiddleware:
|
||||
"""Build a :class:`SurfSenseCompactionMiddleware` with sensible defaults.
|
||||
|
||||
Pulls profile-aware ``trigger`` / ``keep`` / ``truncate_args_settings``
|
||||
via :func:`deepagents.middleware.summarization.compute_summarization_defaults`
|
||||
so callers get the same behavior as ``create_summarization_middleware``
|
||||
plus our overrides.
|
||||
|
||||
Args:
|
||||
model: Chat model to call for summary generation.
|
||||
backend: Backend instance or factory for offloading conversation history.
|
||||
summary_prompt: Optional override; defaults to :data:`SURFSENSE_SUMMARY_PROMPT`.
|
||||
history_path_prefix: Path prefix for offloaded conversation history.
|
||||
**overrides: Forwarded to :class:`SurfSenseCompactionMiddleware`.
|
||||
"""
|
||||
defaults = compute_summarization_defaults(model)
|
||||
return SurfSenseCompactionMiddleware(
|
||||
model=model,
|
||||
backend=backend,
|
||||
trigger=overrides.pop("trigger", defaults["trigger"]),
|
||||
keep=overrides.pop("keep", defaults["keep"]),
|
||||
trim_tokens_to_summarize=overrides.pop("trim_tokens_to_summarize", None),
|
||||
truncate_args_settings=overrides.pop(
|
||||
"truncate_args_settings", defaults["truncate_args_settings"]
|
||||
),
|
||||
summary_prompt=summary_prompt or SURFSENSE_SUMMARY_PROMPT,
|
||||
history_path_prefix=history_path_prefix,
|
||||
**overrides,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PROTECTED_SYSTEM_PREFIXES",
|
||||
"SURFSENSE_SUMMARY_PROMPT",
|
||||
"SurfSenseCompactionMiddleware",
|
||||
"create_surfsense_compaction_middleware",
|
||||
]
|
||||
|
|
@ -0,0 +1,349 @@
|
|||
"""
|
||||
SpillToBackendEdit + SpillingContextEditingMiddleware.
|
||||
|
||||
Mirrors OpenCode's spill-to-disk behavior in
|
||||
``opencode/packages/opencode/src/tool/truncate.ts``. Before
|
||||
``ClearToolUsesEdit`` rewrites old ``ToolMessage.content`` to a placeholder,
|
||||
we capture the full original content and write it to the runtime backend
|
||||
under ``/tool_outputs/{thread_id}/{message_id}.txt``. The placeholder is
|
||||
upgraded to ``"[cleared — full output at /tool_outputs/.../{id}.txt; ask the
|
||||
explore subagent to read it]"`` so the agent can recover it on demand.
|
||||
|
||||
Tier 1.2 in the OpenCode-port plan.
|
||||
|
||||
Why this is a middleware subclass instead of a plain ``ContextEdit``:
|
||||
``ContextEdit.apply`` is sync, but writing to the backend is async. We
|
||||
capture the spill payloads inside ``apply`` and flush them via
|
||||
``await backend.aupload_files(...)`` from ``awrap_model_call`` *before*
|
||||
delegating to the handler, so the explore subagent can always read what
|
||||
the placeholder advertises.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware.context_editing import (
|
||||
ClearToolUsesEdit,
|
||||
ContextEdit,
|
||||
ContextEditingMiddleware,
|
||||
TokenCounter,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.utils import count_tokens_approximately
|
||||
from langgraph.config import get_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BackendProtocol
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SPILL_PREFIX = "/tool_outputs"
|
||||
|
||||
|
||||
def _build_spill_placeholder(spill_path: str) -> str:
|
||||
"""Build the user-facing placeholder text shown to the model."""
|
||||
return (
|
||||
f"[cleared — full output at {spill_path}; "
|
||||
f"ask the explore subagent to read it]"
|
||||
)
|
||||
|
||||
|
||||
def _get_thread_id_or_session() -> str:
|
||||
"""Best-effort thread_id discovery for the spill path.
|
||||
|
||||
Falls back to a process-stable string if no LangGraph config is
|
||||
available (e.g. unit tests). The exact value doesn't matter as long
|
||||
as it's stable within one stream so the placeholder paths line up
|
||||
with the actual upload path.
|
||||
"""
|
||||
try:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
if thread_id is not None:
|
||||
return str(thread_id)
|
||||
except RuntimeError:
|
||||
pass
|
||||
return "no_thread"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SpillToBackendEdit(ContextEdit):
|
||||
"""Capture-and-replace context edit that spills full tool output to the backend.
|
||||
|
||||
Behaves like :class:`ClearToolUsesEdit` (same trigger / keep / exclude
|
||||
semantics) **and** records the original ``ToolMessage.content`` in
|
||||
:attr:`pending_spills` so the wrapping middleware can flush them
|
||||
before the model call.
|
||||
|
||||
Args:
|
||||
trigger: Token threshold above which the edit fires.
|
||||
clear_at_least: Minimum number of tokens to reclaim (best effort).
|
||||
keep: Number of most-recent ``ToolMessage`` instances to leave
|
||||
untouched.
|
||||
exclude_tools: Names of tools whose output is NOT spilled.
|
||||
clear_tool_inputs: Also clear the originating ``AIMessage.tool_calls``
|
||||
args when their pair is cleared.
|
||||
path_prefix: Path under the backend where spills are written.
|
||||
Default ``"/tool_outputs"``.
|
||||
"""
|
||||
|
||||
trigger: int = 100_000
|
||||
clear_at_least: int = 0
|
||||
keep: int = 3
|
||||
clear_tool_inputs: bool = False
|
||||
exclude_tools: Sequence[str] = ()
|
||||
path_prefix: str = DEFAULT_SPILL_PREFIX
|
||||
|
||||
pending_spills: list[tuple[str, bytes]] = field(default_factory=list)
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
|
||||
def drain_pending(self) -> list[tuple[str, bytes]]:
|
||||
"""Return and clear the pending-spill list atomically."""
|
||||
with self._lock:
|
||||
out = list(self.pending_spills)
|
||||
self.pending_spills.clear()
|
||||
return out
|
||||
|
||||
def apply(
|
||||
self,
|
||||
messages: list[AnyMessage],
|
||||
*,
|
||||
count_tokens: TokenCounter,
|
||||
) -> None:
|
||||
"""Mirror ``ClearToolUsesEdit.apply`` but capture originals first."""
|
||||
tokens = count_tokens(messages)
|
||||
if tokens <= self.trigger:
|
||||
return
|
||||
|
||||
candidates = [
|
||||
(idx, msg) for idx, msg in enumerate(messages) if isinstance(msg, ToolMessage)
|
||||
]
|
||||
if self.keep >= len(candidates):
|
||||
return
|
||||
if self.keep:
|
||||
candidates = candidates[: -self.keep]
|
||||
|
||||
thread_id = _get_thread_id_or_session()
|
||||
excluded_tools = set(self.exclude_tools)
|
||||
|
||||
for idx, tool_message in candidates:
|
||||
if tool_message.response_metadata.get("context_editing", {}).get("cleared"):
|
||||
continue
|
||||
|
||||
ai_message = next(
|
||||
(m for m in reversed(messages[:idx]) if isinstance(m, AIMessage)),
|
||||
None,
|
||||
)
|
||||
if ai_message is None:
|
||||
continue
|
||||
|
||||
tool_call = next(
|
||||
(
|
||||
call
|
||||
for call in ai_message.tool_calls
|
||||
if call.get("id") == tool_message.tool_call_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if tool_call is None:
|
||||
continue
|
||||
|
||||
tool_name = tool_message.name or tool_call["name"]
|
||||
if tool_name in excluded_tools:
|
||||
continue
|
||||
|
||||
message_id = tool_message.id or tool_message.tool_call_id or "unknown"
|
||||
spill_path = f"{self.path_prefix}/{thread_id}/{message_id}.txt"
|
||||
|
||||
original = tool_message.content
|
||||
payload = self._encode_payload(original)
|
||||
with self._lock:
|
||||
self.pending_spills.append((spill_path, payload))
|
||||
|
||||
messages[idx] = tool_message.model_copy(
|
||||
update={
|
||||
"artifact": None,
|
||||
"content": _build_spill_placeholder(spill_path),
|
||||
"response_metadata": {
|
||||
**tool_message.response_metadata,
|
||||
"context_editing": {
|
||||
"cleared": True,
|
||||
"strategy": "spill_to_backend",
|
||||
"spill_path": spill_path,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if self.clear_tool_inputs:
|
||||
ai_idx = messages.index(ai_message)
|
||||
messages[ai_idx] = self._clear_input_args(
|
||||
ai_message, tool_message.tool_call_id or ""
|
||||
)
|
||||
|
||||
if self.clear_at_least > 0:
|
||||
new_token_count = count_tokens(messages)
|
||||
cleared_tokens = max(0, tokens - new_token_count)
|
||||
if cleared_tokens >= self.clear_at_least:
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _encode_payload(content: Any) -> bytes:
|
||||
"""Serialize ``ToolMessage.content`` to bytes for upload."""
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
if isinstance(content, str):
|
||||
return content.encode("utf-8")
|
||||
try:
|
||||
import json
|
||||
|
||||
return json.dumps(content, default=str).encode("utf-8")
|
||||
except Exception:
|
||||
return str(content).encode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def _clear_input_args(message: AIMessage, tool_call_id: str) -> AIMessage:
|
||||
updated_tool_calls: list[dict[str, Any]] = []
|
||||
cleared_any = False
|
||||
for tool_call in message.tool_calls:
|
||||
updated = dict(tool_call)
|
||||
if updated.get("id") == tool_call_id:
|
||||
updated["args"] = {}
|
||||
cleared_any = True
|
||||
updated_tool_calls.append(updated)
|
||||
|
||||
metadata = dict(getattr(message, "response_metadata", {}))
|
||||
if cleared_any:
|
||||
ctx = dict(metadata.get("context_editing", {}))
|
||||
ids = set(ctx.get("cleared_tool_inputs", []))
|
||||
ids.add(tool_call_id)
|
||||
ctx["cleared_tool_inputs"] = sorted(ids)
|
||||
metadata["context_editing"] = ctx
|
||||
return message.model_copy(
|
||||
update={
|
||||
"tool_calls": updated_tool_calls,
|
||||
"response_metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
BackendResolver = "Callable[[Any], BackendProtocol] | BackendProtocol"
|
||||
|
||||
|
||||
class SpillingContextEditingMiddleware(ContextEditingMiddleware):
|
||||
""":class:`ContextEditingMiddleware` that flushes :class:`SpillToBackendEdit` writes.
|
||||
|
||||
Runs the configured edits as the parent does, then flushes any
|
||||
pending spills via the supplied backend resolver before delegating
|
||||
to the model handler. Spill failures are logged but never abort the
|
||||
model call — the placeholder text is already in the message, so the
|
||||
worst case is the agent gets a placeholder it cannot follow up on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
edits: Sequence[ContextEdit],
|
||||
backend_resolver: BackendResolver | None = None,
|
||||
token_count_method: str = "approximate",
|
||||
) -> None:
|
||||
super().__init__(edits=list(edits), token_count_method=token_count_method) # type: ignore[arg-type]
|
||||
self._backend_resolver = backend_resolver
|
||||
|
||||
def _resolve_backend(self, request: ModelRequest) -> BackendProtocol | None:
|
||||
if self._backend_resolver is None:
|
||||
return None
|
||||
if callable(self._backend_resolver):
|
||||
try:
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
tool_runtime = ToolRuntime(
|
||||
state=getattr(request, "state", {}),
|
||||
context=getattr(request.runtime, "context", None),
|
||||
stream_writer=getattr(request.runtime, "stream_writer", None),
|
||||
store=getattr(request.runtime, "store", None),
|
||||
config=getattr(request.runtime, "config", None) or {},
|
||||
tool_call_id=None,
|
||||
)
|
||||
return self._backend_resolver(tool_runtime)
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve spill backend")
|
||||
return None
|
||||
return self._backend_resolver # type: ignore[return-value]
|
||||
|
||||
def _collect_pending(self) -> list[tuple[str, bytes]]:
|
||||
out: list[tuple[str, bytes]] = []
|
||||
for edit in self.edits:
|
||||
if isinstance(edit, SpillToBackendEdit):
|
||||
out.extend(edit.drain_pending())
|
||||
return out
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> Any:
|
||||
if not request.messages:
|
||||
return await handler(request)
|
||||
|
||||
if self.token_count_method == "approximate":
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return count_tokens_approximately(messages)
|
||||
|
||||
else:
|
||||
system_msg = [request.system_message] if request.system_message else []
|
||||
|
||||
def count_tokens(messages: Sequence[BaseMessage]) -> int:
|
||||
return request.model.get_num_tokens_from_messages(
|
||||
system_msg + list(messages), request.tools
|
||||
)
|
||||
|
||||
edited_messages = deepcopy(list(request.messages))
|
||||
for edit in self.edits:
|
||||
edit.apply(edited_messages, count_tokens=count_tokens)
|
||||
|
||||
pending = self._collect_pending()
|
||||
if pending:
|
||||
backend = self._resolve_backend(request)
|
||||
if backend is not None:
|
||||
try:
|
||||
await backend.aupload_files(pending)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Spill-to-backend upload failed (%d files); placeholders "
|
||||
"remain in messages but content is unrecoverable",
|
||||
len(pending),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"SpillToBackendEdit produced %d pending spills but no backend "
|
||||
"resolver was configured; content is unrecoverable",
|
||||
len(pending),
|
||||
)
|
||||
|
||||
return await handler(request.override(messages=edited_messages))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_SPILL_PREFIX",
|
||||
"ClearToolUsesEdit",
|
||||
"SpillToBackendEdit",
|
||||
"SpillingContextEditingMiddleware",
|
||||
"_build_spill_placeholder",
|
||||
]
|
||||
|
|
@ -2,17 +2,28 @@
|
|||
|
||||
When the LLM emits multiple calls to the same HITL tool with the same
|
||||
primary argument (e.g. two ``delete_calendar_event("Doctor Appointment")``),
|
||||
only the first call is kept. Non-HITL tools are never touched.
|
||||
only the first call is kept. Non-HITL tools are never touched.
|
||||
|
||||
This runs in the ``after_model`` hook — **before** any tool executes — so
|
||||
the duplicate call is stripped from the AIMessage that gets checkpointed.
|
||||
That means it is also safe across LangGraph ``interrupt()`` boundaries:
|
||||
the removed call will never appear on graph resume.
|
||||
|
||||
Dedup-key resolution order (Tier 2.3 / cleanup in the OpenCode-port plan):
|
||||
|
||||
1. :class:`ToolDefinition.dedup_key` — callable provided by the registry
|
||||
entry. This is the canonical mechanism after the cleanup-tier removal
|
||||
of the legacy ``PRIMARY_ARG`` map.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg name;
|
||||
used by MCP / Composio tools whose schemas the registry doesn't see.
|
||||
|
||||
A tool with no resolver from either path simply opts out of dedup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
|
|
@ -20,81 +31,84 @@ from langgraph.runtime import Runtime
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NATIVE_HITL_TOOL_DEDUP_KEYS: dict[str, str] = {
|
||||
# Gmail
|
||||
"send_gmail_email": "subject",
|
||||
"create_gmail_draft": "subject",
|
||||
"update_gmail_draft": "draft_subject_or_id",
|
||||
"trash_gmail_email": "email_subject_or_id",
|
||||
# Google Calendar
|
||||
"create_calendar_event": "title",
|
||||
"update_calendar_event": "event_title_or_id",
|
||||
"delete_calendar_event": "event_title_or_id",
|
||||
# Google Drive
|
||||
"create_google_drive_file": "file_name",
|
||||
"delete_google_drive_file": "file_name",
|
||||
# OneDrive
|
||||
"create_onedrive_file": "file_name",
|
||||
"delete_onedrive_file": "file_name",
|
||||
# Dropbox
|
||||
"create_dropbox_file": "file_name",
|
||||
"delete_dropbox_file": "file_name",
|
||||
# Notion
|
||||
"create_notion_page": "title",
|
||||
"update_notion_page": "page_title",
|
||||
"delete_notion_page": "page_title",
|
||||
# Linear
|
||||
"create_linear_issue": "title",
|
||||
"update_linear_issue": "issue_ref",
|
||||
"delete_linear_issue": "issue_ref",
|
||||
# Jira
|
||||
"create_jira_issue": "summary",
|
||||
"update_jira_issue": "issue_title_or_key",
|
||||
"delete_jira_issue": "issue_title_or_key",
|
||||
# Confluence
|
||||
"create_confluence_page": "title",
|
||||
"update_confluence_page": "page_title_or_id",
|
||||
"delete_confluence_page": "page_title_or_id",
|
||||
}
|
||||
# Resolver type — given the tool ``args`` dict returns a stable
|
||||
# string used to dedupe consecutive calls. ``None`` means no dedup.
|
||||
DedupResolver = Callable[[dict[str, Any]], str]
|
||||
|
||||
|
||||
def wrap_dedup_key_by_arg_name(arg_name: str) -> DedupResolver:
|
||||
"""Adapt a string-arg name into a :data:`DedupResolver`.
|
||||
|
||||
Convenience helper used by registry entries that just want to dedupe
|
||||
on a single arg's lowercased value (the most common case for native
|
||||
HITL tools like ``send_gmail_email`` keyed on ``subject``).
|
||||
|
||||
Example::
|
||||
|
||||
ToolDefinition(
|
||||
name="send_gmail_email",
|
||||
...,
|
||||
dedup_key=wrap_dedup_key_by_arg_name("subject"),
|
||||
)
|
||||
"""
|
||||
|
||||
def _resolver(args: dict[str, Any]) -> str:
|
||||
return str(args.get(arg_name, "")).lower()
|
||||
|
||||
return _resolver
|
||||
|
||||
|
||||
# Backwards-compatible alias for code that imported the original
|
||||
# private name. New callers should use :func:`wrap_dedup_key_by_arg_name`.
|
||||
_wrap_string_key = wrap_dedup_key_by_arg_name
|
||||
|
||||
|
||||
class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Remove duplicate HITL tool calls from a single LLM response.
|
||||
|
||||
Only the **first** occurrence of each (tool-name, primary-arg-value)
|
||||
Only the **first** occurrence of each ``(tool-name, dedup_key)``
|
||||
pair is kept; subsequent duplicates are silently dropped.
|
||||
|
||||
The dedup map is built from two sources:
|
||||
The dedup-resolver map is built from two sources, in priority order:
|
||||
|
||||
1. A comprehensive list of native HITL tools (hardcoded above).
|
||||
2. Any ``StructuredTool`` instances passed via *agent_tools* whose
|
||||
``metadata`` contains ``{"hitl": True, "hitl_dedup_key": "..."}``.
|
||||
This is how MCP tools automatically get dedup support.
|
||||
1. ``tool.metadata["dedup_key"]`` — callable provided by the registry's
|
||||
``ToolDefinition.dedup_key`` (Tier 2.3). Receives the args dict
|
||||
and returns a string signature. This is the canonical mechanism
|
||||
after the cleanup-tier removal of the legacy ``PRIMARY_ARG`` map.
|
||||
2. ``tool.metadata["hitl_dedup_key"]`` — string with a primary arg
|
||||
name; primarily used by MCP / Composio tools.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, *, agent_tools: list[Any] | None = None) -> None:
|
||||
self._dedup_keys: dict[str, str] = dict(_NATIVE_HITL_TOOL_DEDUP_KEYS)
|
||||
self._resolvers: dict[str, DedupResolver] = {}
|
||||
|
||||
for t in agent_tools or []:
|
||||
meta = getattr(t, "metadata", None) or {}
|
||||
callable_key = meta.get("dedup_key")
|
||||
if callable(callable_key):
|
||||
self._resolvers[t.name] = callable_key
|
||||
continue
|
||||
if meta.get("hitl") and meta.get("hitl_dedup_key"):
|
||||
self._dedup_keys[t.name] = meta["hitl_dedup_key"]
|
||||
self._resolvers[t.name] = wrap_dedup_key_by_arg_name(
|
||||
meta["hitl_dedup_key"]
|
||||
)
|
||||
|
||||
def after_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state, self._dedup_keys)
|
||||
return self._dedup(state, self._resolvers)
|
||||
|
||||
async def aafter_model(
|
||||
self, state: AgentState, runtime: Runtime[Any]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._dedup(state, self._dedup_keys)
|
||||
return self._dedup(state, self._resolvers)
|
||||
|
||||
@staticmethod
|
||||
def _dedup(
|
||||
state: AgentState,
|
||||
dedup_keys: dict[str, str], # type: ignore[type-arg]
|
||||
resolvers: dict[str, DedupResolver],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages")
|
||||
if not messages:
|
||||
|
|
@ -110,9 +124,16 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
|
||||
for tc in tool_calls:
|
||||
name = tc.get("name", "")
|
||||
dedup_key_arg = dedup_keys.get(name)
|
||||
if dedup_key_arg is not None:
|
||||
arg_val = str(tc.get("args", {}).get(dedup_key_arg, "")).lower()
|
||||
resolver = resolvers.get(name)
|
||||
if resolver is not None:
|
||||
try:
|
||||
arg_val = resolver(tc.get("args", {}) or {})
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Dedup resolver for tool %s raised; keeping call", name
|
||||
)
|
||||
deduped.append(tc)
|
||||
continue
|
||||
key = (name, arg_val)
|
||||
if key in seen:
|
||||
logger.info(
|
||||
|
|
|
|||
228
surfsense_backend/app/agents/new_chat/middleware/doom_loop.py
Normal file
228
surfsense_backend/app/agents/new_chat/middleware/doom_loop.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""
|
||||
DoomLoopMiddleware — pattern-based detector for repeated identical tool calls.
|
||||
|
||||
Mirrors ``opencode/packages/opencode/src/session/processor.ts`` doom-loop
|
||||
behavior. When the same tool with the same arguments is called N times
|
||||
in a row, the agent has likely entered an infinite loop. We surface this
|
||||
to the user as an interrupt with ``permission="doom_loop"`` so the UI
|
||||
can render an "Are you stuck? Continue / cancel?" affordance.
|
||||
|
||||
Tier 1.11 in the OpenCode-port plan.
|
||||
|
||||
This ships **OFF by default** until the frontend explicitly handles
|
||||
``context.permission == "doom_loop"`` interrupts (the plan flips
|
||||
``SURFSENSE_ENABLE_DOOM_LOOP=true`` once the UI is ready).
|
||||
|
||||
Wire format: uses SurfSense's existing ``interrupt()`` payload shape
|
||||
(see ``app/agents/new_chat/tools/hitl.py``):
|
||||
|
||||
{
|
||||
"type": "permission_ask",
|
||||
"action": {"tool": <name>, "params": <args>},
|
||||
"context": {"permission": "doom_loop", "recent_signatures": [...]},
|
||||
}
|
||||
|
||||
so the frontend that already handles HITL prompts can render this with
|
||||
no changes beyond a string check.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.observability import otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _signature(name: str, args: Any) -> str:
|
||||
"""Hash a tool call ``(name, args)`` to a short signature."""
|
||||
try:
|
||||
canonical = json.dumps(args, sort_keys=True, default=str)
|
||||
except (TypeError, ValueError):
|
||||
canonical = repr(args)
|
||||
digest = hashlib.sha1(f"{name}::{canonical}".encode()).hexdigest()
|
||||
return digest[:16]
|
||||
|
||||
|
||||
class DoomLoopMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Detect repeated identical tool calls and prompt the user.
|
||||
|
||||
Tracks a sliding window of the most-recent ``threshold`` tool-call
|
||||
signatures across the live request. When all entries match, raise
|
||||
a SurfSense-style HITL interrupt with ``permission="doom_loop"``.
|
||||
|
||||
Args:
|
||||
threshold: How many consecutive identical signatures count as a
|
||||
doom loop. Default 3 (opencode parity).
|
||||
"""
|
||||
|
||||
def __init__(self, *, threshold: int = 3) -> None:
|
||||
super().__init__()
|
||||
if threshold < 2:
|
||||
raise ValueError("DoomLoopMiddleware threshold must be >= 2")
|
||||
self._threshold = threshold
|
||||
self.tools = []
|
||||
# Per-thread sliding windows. We can't put this in graph state
|
||||
# without state-schema gymnastics; for one process-lifetime it's
|
||||
# fine to keep an in-memory map keyed by thread_id.
|
||||
self._windows: dict[str, deque[str]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id_from_runtime(runtime: Runtime[ContextT]) -> str:
|
||||
"""Resolve the thread id for sliding-window keying.
|
||||
|
||||
Prefer LangGraph's ``get_config()`` (the only way to read
|
||||
``RunnableConfig`` inside a node — :class:`Runtime` does NOT carry
|
||||
a ``config`` attribute). Fall back to ``runtime.config`` for unit
|
||||
tests that synthesize a config-bearing stub. Default
|
||||
``"no_thread"`` is intentionally only used when both lookups fail
|
||||
— it would collapse all threads into one window so we keep the
|
||||
debug log loud.
|
||||
"""
|
||||
|
||||
def _from_dict(cfg: Any) -> str | None:
|
||||
if not isinstance(cfg, dict):
|
||||
return None
|
||||
tid = (cfg.get("configurable") or {}).get("thread_id")
|
||||
return str(tid) if tid is not None else None
|
||||
|
||||
try:
|
||||
tid = _from_dict(get_config())
|
||||
except Exception:
|
||||
tid = None
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
tid = _from_dict(getattr(runtime, "config", None))
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
logger.debug(
|
||||
"DoomLoopMiddleware: no thread_id resolved from RunnableConfig; "
|
||||
"falling back to shared 'no_thread' window."
|
||||
)
|
||||
return "no_thread"
|
||||
|
||||
def _window(self, thread_id: str) -> deque[str]:
|
||||
win = self._windows.get(thread_id)
|
||||
if win is None:
|
||||
win = deque(maxlen=self._threshold)
|
||||
self._windows[thread_id] = win
|
||||
return win
|
||||
|
||||
def _detect(
|
||||
self, message: AIMessage, runtime: Runtime[ContextT]
|
||||
) -> tuple[bool, list[str], dict[str, Any] | None]:
|
||||
if not message.tool_calls:
|
||||
return False, [], None
|
||||
|
||||
thread_id = self._thread_id_from_runtime(runtime)
|
||||
window = self._window(thread_id)
|
||||
|
||||
triggered_call: dict[str, Any] | None = None
|
||||
for call in message.tool_calls:
|
||||
name = call.get("name") if isinstance(call, dict) else getattr(call, "name", None)
|
||||
args = call.get("args") if isinstance(call, dict) else getattr(call, "args", {})
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
sig = _signature(name, args)
|
||||
window.append(sig)
|
||||
if (
|
||||
len(window) >= self._threshold
|
||||
and len(set(window)) == 1
|
||||
):
|
||||
triggered_call = {"name": name, "params": args or {}}
|
||||
break
|
||||
|
||||
if triggered_call is None:
|
||||
return False, list(window), None
|
||||
return True, list(window), triggered_call
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
triggered, signatures, action = self._detect(last, runtime)
|
||||
if not triggered:
|
||||
return None
|
||||
|
||||
logger.warning(
|
||||
"Doom loop detected: tool %s called %d times in a row (sig=%s)",
|
||||
action["name"] if action else "<unknown>",
|
||||
self._threshold,
|
||||
signatures[-1] if signatures else "<empty>",
|
||||
)
|
||||
|
||||
# Tier 3b: interrupt.raised span with permission=doom_loop attribute
|
||||
# so dashboards can break out doom-loop interrupts from regular
|
||||
# permission asks via the ``interrupt.permission`` attribute.
|
||||
with ot.interrupt_span(
|
||||
interrupt_type="permission_ask",
|
||||
extra={
|
||||
"interrupt.permission": "doom_loop",
|
||||
"interrupt.threshold": self._threshold,
|
||||
"interrupt.tool": (action or {}).get("tool", "<unknown>"),
|
||||
},
|
||||
):
|
||||
decision = interrupt(
|
||||
{
|
||||
"type": "permission_ask",
|
||||
"action": action or {"tool": "<unknown>", "params": {}},
|
||||
"context": {
|
||||
"permission": "doom_loop",
|
||||
"recent_signatures": signatures,
|
||||
"threshold": self._threshold,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Reset window so the next decision (continue/cancel) starts fresh.
|
||||
thread_id = self._thread_id_from_runtime(runtime)
|
||||
self._windows.pop(thread_id, None)
|
||||
|
||||
# Decision shape mirrors ``tools/hitl.py``: {"decision_type": "..."}
|
||||
# If the user cancelled, jump to end. Otherwise return ``None`` so the
|
||||
# tool call proceeds. The frontend's exact reply names may differ —
|
||||
# we tolerate any shape that contains a string with "reject"/"cancel".
|
||||
if isinstance(decision, dict):
|
||||
kind = str(decision.get("decision_type") or decision.get("type") or "").lower()
|
||||
if "reject" in kind or "cancel" in kind:
|
||||
return {"jump_to": "end"}
|
||||
return None
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DoomLoopMiddleware",
|
||||
"_signature",
|
||||
]
|
||||
|
|
@ -31,14 +31,17 @@ from collections.abc import Sequence
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langgraph.runtime import Runtime
|
||||
from litellm import token_counter
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.agents.new_chat.filesystem_selection import FilesystemMode
|
||||
from app.agents.new_chat.filesystem_state import SurfSenseFilesystemState
|
||||
from app.agents.new_chat.path_resolver import (
|
||||
|
|
@ -589,6 +592,53 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
self.available_document_types = available_document_types
|
||||
self.top_k = top_k
|
||||
self.mentioned_document_ids = mentioned_document_ids or []
|
||||
# Tier 4.2: build the kb-planner private Runnable ONCE here so we
|
||||
# don't pay the create_agent compile cost (50–200ms) on every turn.
|
||||
# Disabled by default behind ``enable_kb_planner_runnable``; when off
|
||||
# the planner falls back to the legacy ``self.llm.ainvoke`` path.
|
||||
self._planner: Runnable | None = None
|
||||
self._planner_compile_failed = False
|
||||
|
||||
def _build_kb_planner_runnable(self) -> Runnable | None:
|
||||
"""Compile the kb-planner private :class:`Runnable` once.
|
||||
|
||||
Returns ``None`` when the feature flag is disabled, when the LLM is
|
||||
unavailable, or when ``create_agent`` raises (we fall back to the
|
||||
legacy ``self.llm.ainvoke`` path in that case). Compilation happens
|
||||
lazily on first call, then memoized via ``self._planner``.
|
||||
|
||||
The compiled agent is constructed without tools — the planner's
|
||||
contract is "answer with structured JSON" — but with ``RetryAfter``
|
||||
+ the OpenCode-port retry/limit middleware so it shares the parent
|
||||
agent's resilience guarantees.
|
||||
"""
|
||||
if self._planner is not None or self._planner_compile_failed:
|
||||
return self._planner
|
||||
if self.llm is None:
|
||||
return None
|
||||
flags = get_flags()
|
||||
if (
|
||||
not flags.enable_kb_planner_runnable
|
||||
or flags.disable_new_agent_stack
|
||||
):
|
||||
return None
|
||||
|
||||
from app.agents.new_chat.middleware.retry_after import RetryAfterMiddleware
|
||||
|
||||
try:
|
||||
self._planner = create_agent(
|
||||
self.llm,
|
||||
tools=[],
|
||||
middleware=[RetryAfterMiddleware(max_retries=2)],
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning(
|
||||
"kb-planner Runnable compile failed; falling back to llm.ainvoke: %s",
|
||||
exc,
|
||||
)
|
||||
self._planner_compile_failed = True
|
||||
self._planner = None
|
||||
return self._planner
|
||||
|
||||
async def _plan_search_inputs(
|
||||
self,
|
||||
|
|
@ -611,11 +661,32 @@ class KnowledgePriorityMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
loop = asyncio.get_running_loop()
|
||||
t0 = loop.time()
|
||||
|
||||
# Tier 4.2: prefer the compiled-once planner Runnable when enabled;
|
||||
# otherwise fall back to ``self.llm.ainvoke``. The ``surfsense:internal``
|
||||
# tag is preserved on both paths so ``_stream_agent_events`` still
|
||||
# suppresses the planner's intermediate events from the UI.
|
||||
planner = self._build_kb_planner_runnable()
|
||||
try:
|
||||
response = await self.llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
if planner is not None:
|
||||
planner_state = await planner.ainvoke(
|
||||
{"messages": [HumanMessage(content=prompt)]},
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
response_messages = (
|
||||
planner_state.get("messages", [])
|
||||
if isinstance(planner_state, dict)
|
||||
else []
|
||||
)
|
||||
response = (
|
||||
response_messages[-1]
|
||||
if response_messages
|
||||
else AIMessage(content="")
|
||||
)
|
||||
else:
|
||||
response = await self.llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
plan = _parse_kb_search_plan_response(_extract_text_from_message(response))
|
||||
optimized_query = (
|
||||
re.sub(r"\s+", " ", plan.optimized_query).strip() or user_text
|
||||
|
|
|
|||
|
|
@ -0,0 +1,133 @@
|
|||
"""
|
||||
``_noop`` provider-compatibility tool + injection middleware.
|
||||
|
||||
OpenCode injects a ``_noop`` tool for LiteLLM/Bedrock/Copilot when the
|
||||
model call has empty tools but message history includes prior
|
||||
``tool_calls`` — some providers 400 in that shape (see
|
||||
``opencode/packages/opencode/src/session/llm.ts:209-228``). SurfSense uses
|
||||
LiteLLM, and the compaction summarize call (no tools, history full of
|
||||
tool calls) hits this. Tier 1.5 in the OpenCode-port plan.
|
||||
|
||||
Operation: a :class:`NoopInjectionMiddleware` ``wrap_model_call`` checks
|
||||
if the request has zero tools but the last AI message in history includes
|
||||
``tool_calls``. If yes, it injects the ``_noop`` tool only — never globally,
|
||||
mirroring opencode's gating exactly. The :func:`noop_tool` returns empty
|
||||
content when called (which it should never be in practice).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOOP_TOOL_NAME = "_noop"
|
||||
NOOP_TOOL_DESCRIPTION = (
|
||||
"Do not call this tool. It exists only for API compatibility."
|
||||
)
|
||||
|
||||
|
||||
@tool(name_or_callable=NOOP_TOOL_NAME, description=NOOP_TOOL_DESCRIPTION)
|
||||
def noop_tool() -> str:
|
||||
"""Return empty content. Never expected to be called."""
|
||||
return ""
|
||||
|
||||
|
||||
# Provider markers that benefit from ``_noop`` injection. These match
|
||||
# opencode's gating list. We also accept any string containing one of
|
||||
# these substrings (so e.g. ``litellm`` matches ``ChatLiteLLM``).
|
||||
_NOOP_NEEDED_PROVIDERS: tuple[str, ...] = (
|
||||
"litellm",
|
||||
"bedrock",
|
||||
"copilot",
|
||||
)
|
||||
|
||||
|
||||
def _provider_needs_noop(model: Any) -> bool:
|
||||
"""Heuristic: does this model's provider need the _noop injection?"""
|
||||
try:
|
||||
ls_params = model._get_ls_params()
|
||||
provider = str(ls_params.get("ls_provider", "")).lower()
|
||||
except Exception:
|
||||
provider = ""
|
||||
|
||||
if not provider:
|
||||
cls_name = type(model).__name__.lower()
|
||||
provider = cls_name
|
||||
|
||||
return any(needle in provider for needle in _NOOP_NEEDED_PROVIDERS)
|
||||
|
||||
|
||||
def _last_ai_has_tool_calls(messages: list[Any]) -> bool:
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
return bool(msg.tool_calls)
|
||||
return False
|
||||
|
||||
|
||||
class NoopInjectionMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Inject the ``_noop`` tool only when the provider would otherwise 400.
|
||||
|
||||
The check fires per model call, not at agent build time, because the
|
||||
summarization path generates a no-tool subcall at runtime. The
|
||||
extra tool is appended to ``request.tools`` as an instance — the
|
||||
actual ``langchain_core.tools.BaseTool`` is bound on every call site
|
||||
that creates the agent.
|
||||
"""
|
||||
|
||||
def __init__(self, *, noop_tool_instance: Any | None = None) -> None:
|
||||
super().__init__()
|
||||
self._noop_tool = noop_tool_instance or noop_tool
|
||||
self.tools = []
|
||||
|
||||
def _should_inject(self, request: ModelRequest[ContextT]) -> bool:
|
||||
if request.tools:
|
||||
return False
|
||||
if not _last_ai_has_tool_calls(request.messages):
|
||||
return False
|
||||
return _provider_needs_noop(request.model)
|
||||
|
||||
def _augmented(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
return request.override(tools=[self._noop_tool])
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> Any:
|
||||
if self._should_inject(request):
|
||||
logger.debug("Injecting _noop tool for provider compatibility")
|
||||
return handler(self._augmented(request))
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> Any:
|
||||
if self._should_inject(request):
|
||||
logger.debug("Injecting _noop tool for provider compatibility")
|
||||
return await handler(self._augmented(request))
|
||||
return await handler(request)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NOOP_TOOL_DESCRIPTION",
|
||||
"NOOP_TOOL_NAME",
|
||||
"NoopInjectionMiddleware",
|
||||
"_provider_needs_noop",
|
||||
"noop_tool",
|
||||
]
|
||||
202
surfsense_backend/app/agents/new_chat/middleware/otel_span.py
Normal file
202
surfsense_backend/app/agents/new_chat/middleware/otel_span.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
OpenTelemetry span middleware for the SurfSense ``new_chat`` agent.
|
||||
|
||||
Wraps both ``model.call`` (LLM invocations) and ``tool.call`` (tool
|
||||
executions) with OTel spans, attaching low-cardinality span names and
|
||||
high-cardinality identifiers as attributes (per the Tier 3b plan).
|
||||
|
||||
This middleware is intentionally a thin adapter over
|
||||
:mod:`app.observability.otel`; when OTel is not configured all spans
|
||||
collapse to no-ops and the wrapper adds <1µs overhead per call. When
|
||||
OTel **is** configured (``OTEL_EXPORTER_OTLP_ENDPOINT`` set), every
|
||||
model and tool call gets a span with the standard attributes the
|
||||
plan's dashboards expect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.observability import otel as ot
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover — type-only
|
||||
from langchain.agents.middleware.types import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ToolCallRequest,
|
||||
)
|
||||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OtelSpanMiddleware(AgentMiddleware):
|
||||
"""Emit ``model.call`` and ``tool.call`` OTel spans for every invocation.
|
||||
|
||||
Should be placed near the **outer** end of the middleware list so
|
||||
that the spans encompass retry/fallback wrapper effects (i.e. ``N``
|
||||
model.call spans for ``N`` retry attempts) but inside any concurrency/
|
||||
auth gate. Empirically this means **between** ``BusyMutex`` and
|
||||
``RetryAfter``.
|
||||
"""
|
||||
|
||||
def __init__(self, *, instrumentation_name: str = "surfsense.new_chat") -> None:
|
||||
super().__init__()
|
||||
self._instrumentation_name = instrumentation_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Model call spans
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[
|
||||
[ModelRequest], Awaitable[ModelResponse | AIMessage | Any]
|
||||
],
|
||||
) -> ModelResponse | AIMessage | Any:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
|
||||
model_id, provider = _resolve_model_attrs(request)
|
||||
with ot.model_call_span(model_id=model_id, provider=provider) as sp:
|
||||
try:
|
||||
result = await handler(request)
|
||||
except Exception:
|
||||
# span context manager records + re-raises
|
||||
raise
|
||||
else:
|
||||
_annotate_model_response(sp, result)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool call spans
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[
|
||||
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
|
||||
],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
if not ot.is_enabled():
|
||||
return await handler(request)
|
||||
|
||||
tool_name = _resolve_tool_name(request)
|
||||
input_size = _resolve_input_size(request)
|
||||
|
||||
with ot.tool_call_span(tool_name, input_size=input_size) as sp:
|
||||
result = await handler(request)
|
||||
_annotate_tool_result(sp, result)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attribute helpers (kept defensive; we never want OTel bookkeeping to break
|
||||
# a real model/tool call).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_model_attrs(request: Any) -> tuple[str | None, str | None]:
|
||||
"""Extract ``model.id`` and ``model.provider`` from a ``ModelRequest``."""
|
||||
model_id: str | None = None
|
||||
provider: str | None = None
|
||||
try:
|
||||
model = getattr(request, "model", None)
|
||||
if model is None:
|
||||
return None, None
|
||||
# langchain BaseChatModel exposes a few different identifiers
|
||||
for attr in ("model_name", "model", "model_id"):
|
||||
value = getattr(model, attr, None)
|
||||
if value:
|
||||
model_id = str(value)
|
||||
break
|
||||
# provider sometimes lives on ``_llm_type`` (legacy) or ``provider``
|
||||
for attr in ("provider", "_llm_type"):
|
||||
value = getattr(model, attr, None)
|
||||
if value:
|
||||
provider = str(value)
|
||||
break
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return model_id, provider
|
||||
|
||||
|
||||
def _resolve_tool_name(request: Any) -> str:
|
||||
try:
|
||||
tool = getattr(request, "tool", None)
|
||||
if tool is not None:
|
||||
name = getattr(tool, "name", None)
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
# Fall back to the tool_call dict
|
||||
call = getattr(request, "tool_call", None) or {}
|
||||
name = call.get("name") if isinstance(call, dict) else None
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _resolve_input_size(request: Any) -> int | None:
|
||||
try:
|
||||
call = getattr(request, "tool_call", None)
|
||||
if not isinstance(call, dict) or not call:
|
||||
return None
|
||||
args = call.get("args")
|
||||
if args is None:
|
||||
return None
|
||||
return len(repr(args))
|
||||
except Exception: # pragma: no cover — defensive
|
||||
return None
|
||||
|
||||
|
||||
def _annotate_model_response(span: Any, result: Any) -> None:
|
||||
"""Best-effort: attach prompt/completion token counts when available."""
|
||||
try:
|
||||
# ModelResponse may be a dataclass with .result containing AIMessage
|
||||
msg: Any
|
||||
if isinstance(result, AIMessage):
|
||||
msg = result
|
||||
else:
|
||||
inner = getattr(result, "result", None)
|
||||
msg = inner[-1] if isinstance(inner, list) and inner else inner
|
||||
if msg is None:
|
||||
return
|
||||
usage = getattr(msg, "usage_metadata", None) or {}
|
||||
if isinstance(usage, dict):
|
||||
if (n := usage.get("input_tokens")) is not None:
|
||||
span.set_attribute("tokens.prompt", int(n))
|
||||
if (n := usage.get("output_tokens")) is not None:
|
||||
span.set_attribute("tokens.completion", int(n))
|
||||
if (n := usage.get("total_tokens")) is not None:
|
||||
span.set_attribute("tokens.total", int(n))
|
||||
tool_calls = getattr(msg, "tool_calls", None) or []
|
||||
span.set_attribute("model.tool_calls", len(tool_calls))
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
|
||||
|
||||
def _annotate_tool_result(span: Any, result: Any) -> None:
|
||||
try:
|
||||
if isinstance(result, ToolMessage):
|
||||
content = result.content if isinstance(result.content, str) else repr(result.content)
|
||||
span.set_attribute("tool.output.size", len(content))
|
||||
status = getattr(result, "status", None)
|
||||
if isinstance(status, str):
|
||||
span.set_attribute("tool.status", status)
|
||||
kwargs = getattr(result, "additional_kwargs", None) or {}
|
||||
if isinstance(kwargs, dict) and kwargs.get("error"):
|
||||
span.set_attribute("tool.error", True)
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["OtelSpanMiddleware"]
|
||||
344
surfsense_backend/app/agents/new_chat/middleware/permission.py
Normal file
344
surfsense_backend/app/agents/new_chat/middleware/permission.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
"""
|
||||
PermissionMiddleware — pattern-based allow/deny/ask with HITL fallback.
|
||||
|
||||
Mirrors ``opencode/packages/opencode/src/permission/index.ts`` but uses
|
||||
SurfSense's existing ``interrupt({type, action, context})`` payload shape
|
||||
(see ``app/agents/new_chat/tools/hitl.py``) so the frontend keeps
|
||||
working unchanged. Tier 2.1 in the OpenCode-port plan.
|
||||
|
||||
Operation:
|
||||
1. ``aafter_model`` inspects the latest ``AIMessage.tool_calls``.
|
||||
2. For each call, the middleware builds a list of ``patterns`` (the
|
||||
tool name plus any tool-specific patterns from the resolver). It
|
||||
evaluates each pattern against the layered rulesets and aggregates
|
||||
the results: ``deny`` > ``ask`` > ``allow``.
|
||||
3. On ``deny``: replaces the call with a synthetic ``ToolMessage``
|
||||
containing a :class:`StreamingError`.
|
||||
4. On ``ask``: raises a SurfSense-style ``interrupt(...)``. The reply
|
||||
shape is ``{"decision_type": "once|always|reject", "feedback"?: str}``.
|
||||
- ``once``: proceed.
|
||||
- ``always``: also persist allow rules for ``request.always`` patterns.
|
||||
- ``reject`` w/o feedback: raise :class:`RejectedError`.
|
||||
- ``reject`` w/ feedback: raise :class:`CorrectedError`.
|
||||
5. On ``allow``: proceed unchanged.
|
||||
|
||||
The middleware also performs a *pre-model* tool-filter step (the
|
||||
``before_model`` hook) so globally denied tools are stripped from the
|
||||
exposed tool list before the model gets to see them. This is
|
||||
opencode's ``Permission.disabled`` equivalent and dramatically reduces
|
||||
the chance the model emits a deny-only call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import interrupt
|
||||
|
||||
from app.agents.new_chat.errors import (
|
||||
CorrectedError,
|
||||
RejectedError,
|
||||
StreamingError,
|
||||
)
|
||||
from app.agents.new_chat.permissions import (
|
||||
Rule,
|
||||
Ruleset,
|
||||
aggregate_action,
|
||||
evaluate_many,
|
||||
)
|
||||
from app.observability import otel as ot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Mapping ``tool_name -> resolver`` that converts ``args`` to a list of
|
||||
# patterns to evaluate. The first pattern is conventionally the bare
|
||||
# tool name; later entries narrow down to specific resources.
|
||||
PatternResolver = Callable[[dict[str, Any]], list[str]]
|
||||
|
||||
|
||||
def _default_pattern_resolver(name: str) -> PatternResolver:
|
||||
def _resolve(args: dict[str, Any]) -> list[str]:
|
||||
# Bare name covers the default catch-all; primary-arg fallbacks
|
||||
# are best added per-tool by callers.
|
||||
del args
|
||||
return [name]
|
||||
|
||||
return _resolve
|
||||
|
||||
|
||||
class PermissionMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
"""Allow/deny/ask layer over the agent's tool calls.
|
||||
|
||||
Args:
|
||||
rulesets: Layered rulesets to evaluate. Earlier entries are
|
||||
overridden by later ones (last-match-wins). Typical layering:
|
||||
``defaults < global < space < thread < runtime_approved``.
|
||||
pattern_resolvers: Optional per-tool callables that return a list
|
||||
of patterns to evaluate. When a tool isn't listed, the bare
|
||||
tool name is used as the only pattern.
|
||||
runtime_ruleset: Mutable :class:`Ruleset` that the middleware
|
||||
extends in-place when the user replies ``"always"`` to an
|
||||
ask interrupt. Reused across all calls in the same agent
|
||||
instance so newly-allowed rules apply to subsequent calls.
|
||||
always_emit_interrupt_payload: If True, every ask uses the
|
||||
SurfSense interrupt wire format (default). Set False to
|
||||
disable interrupts and treat ``ask`` as ``deny`` for
|
||||
non-interactive deployments.
|
||||
"""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
rulesets: list[Ruleset] | None = None,
|
||||
pattern_resolvers: dict[str, PatternResolver] | None = None,
|
||||
runtime_ruleset: Ruleset | None = None,
|
||||
always_emit_interrupt_payload: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._static_rulesets: list[Ruleset] = list(rulesets or [])
|
||||
self._pattern_resolvers: dict[str, PatternResolver] = dict(
|
||||
pattern_resolvers or {}
|
||||
)
|
||||
self._runtime_ruleset: Ruleset = runtime_ruleset or Ruleset(
|
||||
origin="runtime_approved"
|
||||
)
|
||||
self._emit_interrupt = always_emit_interrupt_payload
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-filter step (opencode `Permission.disabled` equivalent)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _globally_denied(self, tool_name: str) -> bool:
|
||||
"""Return True if a deny rule with no narrowing pattern matches."""
|
||||
rules = evaluate_many(tool_name, ["*"], *self._all_rulesets())
|
||||
return aggregate_action(rules) == "deny"
|
||||
|
||||
def _all_rulesets(self) -> list[Ruleset]:
|
||||
return [*self._static_rulesets, self._runtime_ruleset]
|
||||
|
||||
# NOTE: ``before_model`` filtering of the tools list is left to the
|
||||
# agent factory. This middleware only blocks at execution time — and
|
||||
# only via the rule-evaluator path, not by mutating ``request.tools``.
|
||||
# Mutating ``request.tools`` per-call would invalidate provider
|
||||
# prompt-cache prefixes (see Operational risks: prompt-cache regression).
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-call evaluation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_patterns(self, tool_name: str, args: dict[str, Any]) -> list[str]:
|
||||
resolver = self._pattern_resolvers.get(
|
||||
tool_name, _default_pattern_resolver(tool_name)
|
||||
)
|
||||
try:
|
||||
patterns = resolver(args or {})
|
||||
except Exception:
|
||||
logger.exception("Pattern resolver for %s raised; using bare name", tool_name)
|
||||
patterns = [tool_name]
|
||||
if not patterns:
|
||||
patterns = [tool_name]
|
||||
return patterns
|
||||
|
||||
def _evaluate(
|
||||
self, tool_name: str, args: dict[str, Any]
|
||||
) -> tuple[str, list[str], list[Rule]]:
|
||||
patterns = self._resolve_patterns(tool_name, args)
|
||||
rules = evaluate_many(tool_name, patterns, *self._all_rulesets())
|
||||
action = aggregate_action(rules)
|
||||
return action, patterns, rules
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HITL ask flow — SurfSense wire format
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _raise_interrupt(
|
||||
self,
|
||||
*,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
patterns: list[str],
|
||||
rules: list[Rule],
|
||||
) -> dict[str, Any]:
|
||||
"""Block on user approval via SurfSense's ``interrupt`` shape."""
|
||||
if not self._emit_interrupt:
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
# ``params`` (NOT ``args``) is what SurfSense's streaming
|
||||
# normalizer forwards. Other fields move into ``context``.
|
||||
payload = {
|
||||
"type": "permission_ask",
|
||||
"action": {"tool": tool_name, "params": args or {}},
|
||||
"context": {
|
||||
"patterns": patterns,
|
||||
"rules": [
|
||||
{
|
||||
"permission": r.permission,
|
||||
"pattern": r.pattern,
|
||||
"action": r.action,
|
||||
}
|
||||
for r in rules
|
||||
],
|
||||
# Rules of thumb for the frontend: surface the patterns
|
||||
# the user can promote to "always" with a single reply.
|
||||
"always": patterns,
|
||||
},
|
||||
}
|
||||
# Tier 3b: permission.asked + interrupt.raised spans (no-op when
|
||||
# OTel is disabled). Both fire here so dashboards can correlate
|
||||
# "we asked X" with "interrupt was actually delivered".
|
||||
with ot.permission_asked_span(
|
||||
permission=tool_name,
|
||||
pattern=patterns[0] if patterns else None,
|
||||
extra={"permission.patterns": list(patterns)},
|
||||
), ot.interrupt_span(interrupt_type="permission_ask"):
|
||||
decision = interrupt(payload)
|
||||
if isinstance(decision, dict):
|
||||
return decision
|
||||
# Tolerate a plain string reply ("once", "always", "reject")
|
||||
if isinstance(decision, str):
|
||||
return {"decision_type": decision}
|
||||
return {"decision_type": "reject"}
|
||||
|
||||
def _persist_always(
|
||||
self, tool_name: str, patterns: list[str]
|
||||
) -> None:
|
||||
"""Promote ``always`` reply into runtime allow rules.
|
||||
|
||||
Persistence to ``agent_permission_rules`` is done by the
|
||||
streaming layer (``stream_new_chat``) once it observes the
|
||||
``always`` reply — the middleware just keeps an in-memory
|
||||
copy so subsequent calls in the same stream see the rule.
|
||||
"""
|
||||
for pattern in patterns:
|
||||
self._runtime_ruleset.rules.append(
|
||||
Rule(permission=tool_name, pattern=pattern, action="allow")
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Synthesizing deny -> ToolMessage
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _deny_message(
|
||||
tool_call: dict[str, Any],
|
||||
rule: Rule,
|
||||
) -> ToolMessage:
|
||||
err = StreamingError(
|
||||
code="permission_denied",
|
||||
retryable=False,
|
||||
suggestion=(
|
||||
f"rule permission={rule.permission!r} pattern={rule.pattern!r} "
|
||||
f"blocked this call"
|
||||
),
|
||||
)
|
||||
return ToolMessage(
|
||||
content=(
|
||||
f"Permission denied: rule {rule.permission}/{rule.pattern} "
|
||||
f"blocked tool {tool_call.get('name')!r}."
|
||||
),
|
||||
tool_call_id=tool_call.get("id") or "",
|
||||
name=tool_call.get("name"),
|
||||
status="error",
|
||||
additional_kwargs={"error": err.model_dump()},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# The hook: aafter_model
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _process(
|
||||
self,
|
||||
state: AgentState,
|
||||
runtime: Runtime[Any],
|
||||
) -> dict[str, Any] | None:
|
||||
del runtime # unused
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage) or not last.tool_calls:
|
||||
return None
|
||||
|
||||
deny_messages: list[ToolMessage] = []
|
||||
kept_calls: list[dict[str, Any]] = []
|
||||
any_change = False
|
||||
|
||||
for raw in last.tool_calls:
|
||||
call = dict(raw) if isinstance(raw, dict) else {
|
||||
"name": getattr(raw, "name", None),
|
||||
"args": getattr(raw, "args", {}),
|
||||
"id": getattr(raw, "id", None),
|
||||
"type": "tool_call",
|
||||
}
|
||||
name = call.get("name") or ""
|
||||
args = call.get("args") or {}
|
||||
action, patterns, rules = self._evaluate(name, args)
|
||||
|
||||
if action == "deny":
|
||||
# Find the deny rule for the suggestion text
|
||||
deny_rule = next((r for r in rules if r.action == "deny"), rules[0])
|
||||
deny_messages.append(self._deny_message(call, deny_rule))
|
||||
any_change = True
|
||||
continue
|
||||
|
||||
if action == "ask":
|
||||
decision = self._raise_interrupt(
|
||||
tool_name=name, args=args, patterns=patterns, rules=rules
|
||||
)
|
||||
kind = str(decision.get("decision_type") or "reject").lower()
|
||||
if kind == "once":
|
||||
kept_calls.append(call)
|
||||
elif kind == "always":
|
||||
self._persist_always(name, patterns)
|
||||
kept_calls.append(call)
|
||||
elif kind == "reject":
|
||||
feedback = decision.get("feedback")
|
||||
if isinstance(feedback, str) and feedback.strip():
|
||||
raise CorrectedError(feedback, tool=name)
|
||||
raise RejectedError(tool=name, pattern=patterns[0] if patterns else None)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unknown permission decision %r; treating as reject", kind
|
||||
)
|
||||
raise RejectedError(tool=name)
|
||||
continue
|
||||
|
||||
# allow
|
||||
kept_calls.append(call)
|
||||
|
||||
if not any_change and len(kept_calls) == len(last.tool_calls):
|
||||
return None
|
||||
|
||||
updated = last.model_copy(update={"tool_calls": kept_calls})
|
||||
result_messages: list[Any] = [updated]
|
||||
if deny_messages:
|
||||
result_messages.extend(deny_messages)
|
||||
return {"messages": result_messages}
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self, state: AgentState, runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return self._process(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PatternResolver",
|
||||
"PermissionMiddleware",
|
||||
]
|
||||
245
surfsense_backend/app/agents/new_chat/middleware/retry_after.py
Normal file
245
surfsense_backend/app/agents/new_chat/middleware/retry_after.py
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
"""
|
||||
RetryAfterMiddleware — Header-aware retry with custom backoff and SSE eventing.
|
||||
|
||||
Why standalone instead of subclassing ``ModelRetryMiddleware``: the upstream
|
||||
class calls module-level ``calculate_delay`` inline (no overridable
|
||||
``_calculate_delay`` hook), so a subclass cannot inject Retry-After header
|
||||
delays without rewriting the loop. Tier 1.4 in the OpenCode-port plan.
|
||||
|
||||
Behaviour:
|
||||
- Extracts ``Retry-After`` / ``retry-after-ms`` from
|
||||
``litellm.exceptions.RateLimitError.response.headers`` (or any exception
|
||||
exposing a similar shape).
|
||||
- Sleeps ``max(exponential_backoff, header_delay)`` between retries.
|
||||
- Returns ``False`` from ``retry_on`` for ``ContextWindowExceededError`` /
|
||||
``ContextOverflowError`` so :class:`SurfSenseCompactionMiddleware` (or
|
||||
the LangChain summarization fallback path) handles those instead.
|
||||
- Emits ``surfsense.retrying`` via ``adispatch_custom_event`` on each retry
|
||||
so ``stream_new_chat`` can forward it to clients as an SSE event.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.callbacks import adispatch_custom_event, dispatch_custom_event
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Names of exception classes for which a retry would not help — context
|
||||
# overflow needs compaction, auth needs human intervention, etc. Detected
|
||||
# by class-name substring so we don't have to import LiteLLM/Anthropic
|
||||
# here (which would tie this module to optional deps).
|
||||
_NON_RETRYABLE_NAME_HINTS: tuple[str, ...] = (
|
||||
"ContextWindowExceeded",
|
||||
"ContextOverflow",
|
||||
"AuthenticationError",
|
||||
"InvalidRequestError",
|
||||
"PermissionDenied",
|
||||
"InvalidApiKey",
|
||||
"ContextLimit",
|
||||
)
|
||||
|
||||
|
||||
def _is_non_retryable(exc: BaseException) -> bool:
|
||||
name = type(exc).__name__
|
||||
return any(hint in name for hint in _NON_RETRYABLE_NAME_HINTS)
|
||||
|
||||
|
||||
def _extract_retry_after_seconds(exc: BaseException) -> float | None:
|
||||
"""Return seconds-to-wait suggested by the provider, if any.
|
||||
|
||||
Looks at ``exc.response.headers`` or ``exc.headers`` for the standard
|
||||
HTTP ``Retry-After`` header (in seconds) or its millisecond cousin
|
||||
``retry-after-ms`` (sometimes used by Anthropic / OpenAI). Falls back
|
||||
to a regex on the exception message for shapes like
|
||||
``"Please retry after 30s"``.
|
||||
"""
|
||||
headers: dict[str, Any] | None = None
|
||||
response = getattr(exc, "response", None)
|
||||
if response is not None:
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers is None:
|
||||
headers = getattr(exc, "headers", None)
|
||||
|
||||
if isinstance(headers, dict):
|
||||
# Normalize keys to lowercase for case-insensitive matching
|
||||
norm = {str(k).lower(): v for k, v in headers.items()}
|
||||
ms = norm.get("retry-after-ms")
|
||||
if ms is not None:
|
||||
try:
|
||||
return float(ms) / 1000.0
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
seconds = norm.get("retry-after")
|
||||
if seconds is not None:
|
||||
try:
|
||||
return float(seconds)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# Last resort: scan the message for "retry after Xs" or "X seconds"
|
||||
msg = str(exc)
|
||||
match = re.search(r"retry\s+after\s+([0-9]+(?:\.[0-9]+)?)", msg, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
return float(match.group(1))
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _exponential_delay(
|
||||
attempt: int,
|
||||
*,
|
||||
initial_delay: float,
|
||||
backoff_factor: float,
|
||||
max_delay: float,
|
||||
jitter: bool,
|
||||
) -> float:
|
||||
"""Compute an exponential-backoff delay with optional ±25% jitter."""
|
||||
delay = initial_delay * (backoff_factor**attempt) if backoff_factor else initial_delay
|
||||
delay = min(delay, max_delay)
|
||||
if jitter and delay > 0:
|
||||
delay *= 1 + random.uniform(-0.25, 0.25)
|
||||
return max(delay, 0.0)
|
||||
|
||||
|
||||
class RetryAfterMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Retry middleware that honors provider-issued Retry-After hints.
|
||||
|
||||
Drop-in replacement for :class:`langchain.agents.middleware.ModelRetryMiddleware`
|
||||
when working with LiteLLM/Anthropic/OpenAI providers that surface
|
||||
rate-limit hints in headers. Always emits ``surfsense.retrying`` SSE
|
||||
events so the UI can show a friendly "rate limited, retrying in Xs"
|
||||
indicator.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum retries after the initial attempt (default 3).
|
||||
initial_delay: Initial backoff delay in seconds.
|
||||
backoff_factor: Exponential growth factor for backoff.
|
||||
max_delay: Cap on per-attempt delay in seconds.
|
||||
jitter: Whether to add ±25% jitter.
|
||||
retry_on: Optional callable that returns True for retryable
|
||||
exceptions. The default retries everything except known
|
||||
non-retryable classes (context overflow, auth, etc.).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_retries: int = 3,
|
||||
initial_delay: float = 1.0,
|
||||
backoff_factor: float = 2.0,
|
||||
max_delay: float = 60.0,
|
||||
jitter: bool = True,
|
||||
retry_on: Callable[[BaseException], bool] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_retries = max_retries
|
||||
self.initial_delay = initial_delay
|
||||
self.backoff_factor = backoff_factor
|
||||
self.max_delay = max_delay
|
||||
self.jitter = jitter
|
||||
self._retry_on: Callable[[BaseException], bool] = retry_on or (
|
||||
lambda exc: not _is_non_retryable(exc)
|
||||
)
|
||||
|
||||
def _should_retry(self, exc: BaseException) -> bool:
|
||||
try:
|
||||
return bool(self._retry_on(exc))
|
||||
except Exception:
|
||||
logger.exception("retry_on callable raised; defaulting to False")
|
||||
return False
|
||||
|
||||
def _delay_for_attempt(self, attempt: int, exc: BaseException) -> float:
|
||||
backoff = _exponential_delay(
|
||||
attempt,
|
||||
initial_delay=self.initial_delay,
|
||||
backoff_factor=self.backoff_factor,
|
||||
max_delay=self.max_delay,
|
||||
jitter=self.jitter,
|
||||
)
|
||||
header = _extract_retry_after_seconds(exc) or 0.0
|
||||
return max(backoff, header)
|
||||
|
||||
def wrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], ModelResponse[ResponseT]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as exc:
|
||||
if not self._should_retry(exc) or attempt >= self.max_retries:
|
||||
raise
|
||||
delay = self._delay_for_attempt(attempt, exc)
|
||||
try:
|
||||
dispatch_custom_event(
|
||||
"surfsense.retrying",
|
||||
{
|
||||
"attempt": attempt + 1,
|
||||
"max_retries": self.max_retries,
|
||||
"delay_ms": int(delay * 1000),
|
||||
"reason": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("dispatch_custom_event failed; suppressed", exc_info=True)
|
||||
if delay > 0:
|
||||
time.sleep(delay)
|
||||
# Unreachable
|
||||
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
|
||||
|
||||
async def awrap_model_call( # type: ignore[override]
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> ModelResponse[ResponseT] | AIMessage:
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await handler(request)
|
||||
except Exception as exc:
|
||||
if not self._should_retry(exc) or attempt >= self.max_retries:
|
||||
raise
|
||||
delay = self._delay_for_attempt(attempt, exc)
|
||||
try:
|
||||
await adispatch_custom_event(
|
||||
"surfsense.retrying",
|
||||
{
|
||||
"attempt": attempt + 1,
|
||||
"max_retries": self.max_retries,
|
||||
"delay_ms": int(delay * 1000),
|
||||
"reason": type(exc).__name__,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"adispatch_custom_event failed; suppressed", exc_info=True
|
||||
)
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
raise RuntimeError("RetryAfterMiddleware: retry loop exited without resolution")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RetryAfterMiddleware",
|
||||
"_extract_retry_after_seconds",
|
||||
"_is_non_retryable",
|
||||
]
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
"""Safe wrapper around deepagents' SummarizationMiddleware.
|
||||
|
||||
Upstream issue
|
||||
--------------
|
||||
`deepagents.middleware.summarization.SummarizationMiddleware._aoffload_to_backend`
|
||||
(and its sync counterpart) call
|
||||
``get_buffer_string(filtered_messages)`` before writing the evicted history
|
||||
to the backend file. In recent ``langchain-core`` versions, ``get_buffer_string``
|
||||
accesses ``m.text`` which iterates ``self.content`` — this raises
|
||||
``TypeError: 'NoneType' object is not iterable`` whenever an ``AIMessage``
|
||||
has ``content=None`` (common when a model returns *only* tool_calls, seen
|
||||
frequently with Azure OpenAI ``gpt-5.x`` responses streamed through
|
||||
LiteLLM).
|
||||
|
||||
The exception aborts the whole agent turn, so the user just sees "Error during
|
||||
chat" with no assistant response.
|
||||
|
||||
Fix
|
||||
---
|
||||
We subclass ``SummarizationMiddleware`` and override
|
||||
``_filter_summary_messages`` — the only call site that feeds messages into
|
||||
``get_buffer_string`` — to return *copies* of messages whose ``content`` is
|
||||
``None`` with ``content=""``. The originals flowing through the rest of the
|
||||
agent state are untouched.
|
||||
|
||||
We also expose a drop-in ``create_safe_summarization_middleware`` factory
|
||||
that mirrors ``deepagents.middleware.summarization.create_summarization_middleware``
|
||||
but instantiates our safe subclass.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deepagents.middleware.summarization import (
|
||||
SummarizationMiddleware,
|
||||
compute_summarization_defaults,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents.backends.protocol import BACKEND_TYPES
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _sanitize_message_content(msg: AnyMessage) -> AnyMessage:
|
||||
"""Return ``msg`` with ``content`` coerced to a non-``None`` value.
|
||||
|
||||
``get_buffer_string`` reads ``m.text`` which iterates ``self.content``;
|
||||
when a provider streams back an ``AIMessage`` with only tool_calls and
|
||||
no text, ``content`` can be ``None`` and the iteration explodes. We
|
||||
replace ``None`` with an empty string so downstream consumers that only
|
||||
care about text see an empty body.
|
||||
|
||||
The original message is left untouched — we return a copy via
|
||||
pydantic's ``model_copy`` when available, otherwise we fall back to
|
||||
re-setting the attribute on a shallow copy.
|
||||
"""
|
||||
|
||||
if getattr(msg, "content", "not-missing") is not None:
|
||||
return msg
|
||||
|
||||
try:
|
||||
return msg.model_copy(update={"content": ""})
|
||||
except AttributeError:
|
||||
import copy
|
||||
|
||||
new_msg = copy.copy(msg)
|
||||
try:
|
||||
new_msg.content = ""
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.debug(
|
||||
"Could not sanitize content=None on message of type %s",
|
||||
type(msg).__name__,
|
||||
)
|
||||
return msg
|
||||
return new_msg
|
||||
|
||||
|
||||
class SafeSummarizationMiddleware(SummarizationMiddleware):
|
||||
"""`SummarizationMiddleware` that tolerates messages with ``content=None``.
|
||||
|
||||
Only ``_filter_summary_messages`` is overridden — this is the single
|
||||
helper invoked by both the sync and async offload paths immediately
|
||||
before ``get_buffer_string``. Normalising here means we get coverage
|
||||
for both without having to copy the (long, rapidly-changing) offload
|
||||
implementations from upstream.
|
||||
"""
|
||||
|
||||
def _filter_summary_messages(self, messages: list[AnyMessage]) -> list[AnyMessage]:
|
||||
filtered = super()._filter_summary_messages(messages)
|
||||
return [_sanitize_message_content(m) for m in filtered]
|
||||
|
||||
|
||||
def create_safe_summarization_middleware(
|
||||
model: BaseChatModel,
|
||||
backend: BACKEND_TYPES,
|
||||
) -> SafeSummarizationMiddleware:
|
||||
"""Drop-in replacement for ``create_summarization_middleware``.
|
||||
|
||||
Mirrors the defaults computed by ``deepagents`` but returns our
|
||||
``SafeSummarizationMiddleware`` subclass so the
|
||||
``content=None`` crash in ``get_buffer_string`` is avoided.
|
||||
"""
|
||||
|
||||
defaults = compute_summarization_defaults(model)
|
||||
return SafeSummarizationMiddleware(
|
||||
model=model,
|
||||
backend=backend,
|
||||
trigger=defaults["trigger"],
|
||||
keep=defaults["keep"],
|
||||
trim_tokens_to_summarize=None,
|
||||
truncate_args_settings=defaults["truncate_args_settings"],
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SafeSummarizationMiddleware",
|
||||
"create_safe_summarization_middleware",
|
||||
]
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
"""Skills backends for SurfSense.
|
||||
|
||||
Implements two minimal :class:`deepagents.backends.protocol.BackendProtocol`
|
||||
subclasses tailored for use with :class:`deepagents.middleware.skills.SkillsMiddleware`.
|
||||
|
||||
The middleware only needs four methods to load skills from a backend:
|
||||
|
||||
* ``ls_info`` / ``als_info`` — list directories under a source path.
|
||||
* ``download_files`` / ``adownload_files`` — fetch ``SKILL.md`` bytes.
|
||||
|
||||
Other ``BackendProtocol`` methods (``read``/``write``/``edit``/``grep_raw`` …)
|
||||
default to ``NotImplementedError`` from the base class. They are never reached
|
||||
by the skills middleware because skill content is rendered into the system
|
||||
prompt at agent build time, not edited at runtime.
|
||||
|
||||
Two backends are provided:
|
||||
|
||||
* :class:`BuiltinSkillsBackend` — disk-backed read of bundled skills from
|
||||
``app/agents/new_chat/skills/builtin/``.
|
||||
* :class:`SearchSpaceSkillsBackend` — a thin read-only wrapper over
|
||||
:class:`KBPostgresBackend` that filters notes under the privileged folder
|
||||
``/documents/_skills/``.
|
||||
|
||||
Both backends are intentionally read-only: skill authoring happens out of band
|
||||
(via filesystem or a search-space-admin route), so we never expose
|
||||
``write`` / ``edit`` / ``upload_files``. The base class' ``NotImplementedError``
|
||||
gives a clean failure mode if anything tries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import replace
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deepagents.backends.composite import CompositeBackend
|
||||
from deepagents.backends.protocol import (
|
||||
BackendProtocol,
|
||||
FileDownloadResponse,
|
||||
FileInfo,
|
||||
)
|
||||
from deepagents.backends.state import StateBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.new_chat.middleware.kb_postgres_backend import KBPostgresBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Limit per Agent Skills spec; matches deepagents.middleware.skills.MAX_SKILL_FILE_SIZE.
|
||||
_MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
|
||||
def _default_builtin_root() -> Path:
|
||||
"""Return the absolute path to the bundled builtin skills directory.
|
||||
|
||||
Located at ``app/agents/new_chat/skills/builtin/`` relative to this module.
|
||||
"""
|
||||
return (Path(__file__).resolve().parent.parent / "skills" / "builtin").resolve()
|
||||
|
||||
|
||||
class BuiltinSkillsBackend(BackendProtocol):
|
||||
"""Read-only disk-backed skills source.
|
||||
|
||||
Maps a virtual ``/skills/builtin/`` namespace onto a directory on local disk,
|
||||
where each skill is its own subdirectory containing a ``SKILL.md`` file::
|
||||
|
||||
<root>/<skill-name>/SKILL.md
|
||||
|
||||
The middleware calls :meth:`als_info` with the source path and expects a
|
||||
``list[FileInfo]`` whose ``is_dir=True`` entries are descended into. Then it
|
||||
calls :meth:`adownload_files` with the synthesized ``SKILL.md`` paths and
|
||||
parses YAML frontmatter from the returned ``content`` bytes.
|
||||
|
||||
Mounting under :class:`~deepagents.backends.composite.CompositeBackend` at
|
||||
prefix ``/skills/builtin/`` means the middleware can issue paths like
|
||||
``/skills/builtin/kb-research/SKILL.md`` which the composite strips down to
|
||||
``/kb-research/SKILL.md`` before forwarding here. We treat any leading
|
||||
slash as anchoring at :attr:`root`.
|
||||
"""
|
||||
|
||||
def __init__(self, root: Path | str | None = None) -> None:
|
||||
self.root: Path = Path(root).resolve() if root else _default_builtin_root()
|
||||
if not self.root.exists():
|
||||
logger.info(
|
||||
"BuiltinSkillsBackend root %s does not exist; skills will be empty.",
|
||||
self.root,
|
||||
)
|
||||
|
||||
def _resolve(self, path: str) -> Path:
|
||||
"""Resolve a virtual posix path under :attr:`root`, refusing escapes."""
|
||||
bare = path.lstrip("/")
|
||||
candidate = (self.root / bare).resolve() if bare else self.root
|
||||
# Refuse symlink/.. traversal that escapes the root.
|
||||
try:
|
||||
candidate.relative_to(self.root)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"path {path!r} escapes builtin skills root") from exc
|
||||
return candidate
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:
|
||||
try:
|
||||
target = self._resolve(path)
|
||||
except ValueError as exc:
|
||||
logger.warning("BuiltinSkillsBackend.ls_info refused: %s", exc)
|
||||
return []
|
||||
if not target.exists() or not target.is_dir():
|
||||
return []
|
||||
|
||||
infos: list[FileInfo] = []
|
||||
# Build virtual paths anchored at "/" because CompositeBackend already
|
||||
# stripped the route prefix before calling us.
|
||||
target_virtual = "/" if target == self.root else (
|
||||
"/" + str(target.relative_to(self.root)).replace("\\", "/")
|
||||
)
|
||||
for child in sorted(target.iterdir()):
|
||||
child_virtual = (
|
||||
target_virtual.rstrip("/") + "/" + child.name
|
||||
if target_virtual != "/"
|
||||
else "/" + child.name
|
||||
)
|
||||
info: FileInfo = {
|
||||
"path": child_virtual,
|
||||
"is_dir": child.is_dir(),
|
||||
}
|
||||
if child.is_file():
|
||||
try:
|
||||
info["size"] = child.stat().st_size
|
||||
except OSError: # pragma: no cover - defensive
|
||||
pass
|
||||
infos.append(info)
|
||||
return infos
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
responses: list[FileDownloadResponse] = []
|
||||
for p in paths:
|
||||
try:
|
||||
target = self._resolve(p)
|
||||
except ValueError:
|
||||
responses.append(FileDownloadResponse(path=p, error="invalid_path"))
|
||||
continue
|
||||
if not target.exists():
|
||||
responses.append(FileDownloadResponse(path=p, error="file_not_found"))
|
||||
continue
|
||||
if target.is_dir():
|
||||
responses.append(FileDownloadResponse(path=p, error="is_directory"))
|
||||
continue
|
||||
try:
|
||||
# Hard cap to avoid loading rogue mega-files into memory.
|
||||
size = target.stat().st_size
|
||||
if size > _MAX_SKILL_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Builtin skill file %s exceeds %d bytes; truncating.",
|
||||
target,
|
||||
_MAX_SKILL_FILE_SIZE,
|
||||
)
|
||||
with target.open("rb") as fh:
|
||||
content = fh.read(_MAX_SKILL_FILE_SIZE)
|
||||
else:
|
||||
content = target.read_bytes()
|
||||
except PermissionError:
|
||||
responses.append(FileDownloadResponse(path=p, error="permission_denied"))
|
||||
continue
|
||||
except OSError as exc: # pragma: no cover - defensive
|
||||
logger.warning("Builtin skill read failed %s: %s", target, exc)
|
||||
responses.append(FileDownloadResponse(path=p, error="file_not_found"))
|
||||
continue
|
||||
responses.append(FileDownloadResponse(path=p, content=content, error=None))
|
||||
return responses
|
||||
|
||||
|
||||
class SearchSpaceSkillsBackend(BackendProtocol):
|
||||
"""Read-only view of search-space-authored skills.
|
||||
|
||||
Wraps a :class:`KBPostgresBackend` and only ever reads under the privileged
|
||||
folder ``/documents/_skills/`` (configurable). The folder is intended to be
|
||||
writable only by search-space admins; this backend never writes.
|
||||
|
||||
The skills middleware expects a layout like::
|
||||
|
||||
/<source_root>/<skill-name>/SKILL.md
|
||||
|
||||
But the KB stores documents like ``/documents/_skills/<name>/SKILL.md``.
|
||||
We expose the inner namespace by remapping each path. When mounted under
|
||||
:class:`CompositeBackend` at prefix ``/skills/space/`` the paths the
|
||||
middleware sees become ``/skills/space/<name>/SKILL.md``; the composite
|
||||
strips ``/skills/space/`` and hands us ``/<name>/SKILL.md``, which we
|
||||
rewrite to ``/documents/_skills/<name>/SKILL.md`` before forwarding to the
|
||||
KB.
|
||||
|
||||
No new database table is needed: the privileged folder convention is
|
||||
enforced server-side outside of this class. We intentionally swallow any
|
||||
write/edit attempts (the base class raises ``NotImplementedError``).
|
||||
"""
|
||||
|
||||
DEFAULT_KB_ROOT: str = "/documents/_skills"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kb_backend: KBPostgresBackend,
|
||||
*,
|
||||
kb_root: str = DEFAULT_KB_ROOT,
|
||||
) -> None:
|
||||
self._kb = kb_backend
|
||||
# Normalize trailing slash off so we can join cleanly.
|
||||
self._kb_root = kb_root.rstrip("/") or "/"
|
||||
|
||||
def _to_kb(self, path: str) -> str:
|
||||
"""Rewrite a virtual path into the underlying KB namespace."""
|
||||
bare = path.lstrip("/")
|
||||
if not bare:
|
||||
return self._kb_root
|
||||
return f"{self._kb_root}/{bare}"
|
||||
|
||||
def _from_kb(self, kb_path: str) -> str:
|
||||
"""Rewrite a KB path back into our virtual namespace."""
|
||||
if not kb_path.startswith(self._kb_root):
|
||||
return kb_path # pragma: no cover - defensive
|
||||
rel = kb_path[len(self._kb_root) :]
|
||||
return rel if rel.startswith("/") else "/" + rel
|
||||
|
||||
def ls_info(self, path: str) -> list[FileInfo]:
|
||||
# KBPostgresBackend exposes only the async API meaningfully; the sync
|
||||
# path falls back to ``asyncio.to_thread(...)`` in the base class. We
|
||||
# keep this stub to satisfy abstract resolution; the middleware calls
|
||||
# ``als_info``.
|
||||
raise NotImplementedError("SearchSpaceSkillsBackend is async-only")
|
||||
|
||||
async def als_info(self, path: str) -> list[FileInfo]:
|
||||
kb_path = self._to_kb(path)
|
||||
try:
|
||||
infos = await self._kb.als_info(kb_path)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.warning("SearchSpaceSkillsBackend.als_info failed: %s", exc)
|
||||
return []
|
||||
remapped: list[FileInfo] = []
|
||||
for info in infos:
|
||||
kb_p = info.get("path", "")
|
||||
if not kb_p.startswith(self._kb_root):
|
||||
continue
|
||||
remapped.append({**info, "path": self._from_kb(kb_p)})
|
||||
return remapped
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
raise NotImplementedError("SearchSpaceSkillsBackend is async-only")
|
||||
|
||||
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
kb_paths = [self._to_kb(p) for p in paths]
|
||||
responses = await self._kb.adownload_files(kb_paths)
|
||||
# Re-map response paths back to the virtual namespace so the middleware
|
||||
# correlates them to the input list correctly.
|
||||
remapped: list[FileDownloadResponse] = []
|
||||
for original, resp in zip(paths, responses, strict=True):
|
||||
remapped.append(replace(resp, path=original))
|
||||
return remapped
|
||||
|
||||
|
||||
SKILLS_BUILTIN_PREFIX = "/skills/builtin/"
|
||||
SKILLS_SPACE_PREFIX = "/skills/space/"
|
||||
|
||||
|
||||
def build_skills_backend_factory(
|
||||
*,
|
||||
builtin_root: Path | str | None = None,
|
||||
search_space_id: int | None = None,
|
||||
) -> Callable[[ToolRuntime], BackendProtocol]:
|
||||
"""Return a runtime-aware factory for the skills :class:`CompositeBackend`.
|
||||
|
||||
When ``search_space_id`` is provided the composite includes a
|
||||
:class:`SearchSpaceSkillsBackend` route at ``/skills/space/`` over a fresh
|
||||
per-runtime :class:`KBPostgresBackend`, mirroring how
|
||||
:func:`build_backend_resolver` constructs the main filesystem backend.
|
||||
|
||||
When ``search_space_id`` is ``None`` (e.g., desktop-local mode or unit
|
||||
tests) only the bundled :class:`BuiltinSkillsBackend` is exposed.
|
||||
|
||||
Returning a factory rather than a fixed instance is intentional: the
|
||||
underlying KB backend depends on per-call ``ToolRuntime`` state
|
||||
(``staged_dirs``, ``files`` cache, runtime config), so a single shared
|
||||
instance cannot serve multiple concurrent agent runs.
|
||||
"""
|
||||
builtin = BuiltinSkillsBackend(builtin_root)
|
||||
|
||||
if search_space_id is None:
|
||||
def _factory_builtin_only(runtime: ToolRuntime) -> BackendProtocol:
|
||||
# Default StateBackend is intentionally inert: any path outside the
|
||||
# ``/skills/builtin/`` route resolves to an empty per-runtime state
|
||||
# so the SkillsMiddleware can iterate sources without raising.
|
||||
return CompositeBackend(
|
||||
default=StateBackend(runtime),
|
||||
routes={SKILLS_BUILTIN_PREFIX: builtin},
|
||||
)
|
||||
return _factory_builtin_only
|
||||
|
||||
def _factory_with_space(runtime: ToolRuntime) -> BackendProtocol:
|
||||
# Imported lazily to avoid a hard dependency at module import time:
|
||||
# ``KBPostgresBackend`` pulls in DB models, which are unnecessary for
|
||||
# the unit-tested builtin path.
|
||||
from app.agents.new_chat.middleware.kb_postgres_backend import (
|
||||
KBPostgresBackend,
|
||||
)
|
||||
|
||||
kb = KBPostgresBackend(search_space_id, runtime)
|
||||
space = SearchSpaceSkillsBackend(kb)
|
||||
return CompositeBackend(
|
||||
default=StateBackend(runtime),
|
||||
routes={
|
||||
SKILLS_BUILTIN_PREFIX: builtin,
|
||||
SKILLS_SPACE_PREFIX: space,
|
||||
},
|
||||
)
|
||||
|
||||
return _factory_with_space
|
||||
|
||||
|
||||
def default_skills_sources() -> list[str]:
|
||||
"""Return the canonical source list for SkillsMiddleware (built-in then space)."""
|
||||
return [SKILLS_BUILTIN_PREFIX, SKILLS_SPACE_PREFIX]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"SKILLS_BUILTIN_PREFIX",
|
||||
"SKILLS_SPACE_PREFIX",
|
||||
"BuiltinSkillsBackend",
|
||||
"SearchSpaceSkillsBackend",
|
||||
"build_skills_backend_factory",
|
||||
"default_skills_sources",
|
||||
]
|
||||
|
|
@ -0,0 +1,190 @@
|
|||
"""
|
||||
ToolCallNameRepairMiddleware — two-stage tool-name repair.
|
||||
|
||||
Mirrors ``opencode/packages/opencode/src/session/llm.ts:339-358`` plus
|
||||
``opencode/packages/opencode/src/tool/invalid.ts``. Tier 1.7 in the
|
||||
OpenCode-port plan.
|
||||
|
||||
Operation:
|
||||
1. **Stage 1 — lowercase repair:** if a tool call's ``name`` is not in
|
||||
the registry but ``name.lower()`` is, rewrite in place. Catches
|
||||
models that emit ``Search`` instead of ``search``.
|
||||
2. **Stage 2 — invalid fallback:** if still unmatched, rewrite the call
|
||||
to ``invalid`` with ``args={"tool": original_name, "error": <error>}``
|
||||
so the registered :func:`invalid_tool` returns the error to the model
|
||||
for self-correction.
|
||||
|
||||
Distinct from :class:`deepagents.middleware.PatchToolCallsMiddleware`,
|
||||
which patches *dangling* tool calls (no matching ToolMessage) — that
|
||||
class does not handle the wrong-name case at all.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _coerce_existing_tool_call(call: Any) -> dict[str, Any]:
|
||||
"""Normalize a tool call entry to a mutable dict."""
|
||||
if isinstance(call, dict):
|
||||
return dict(call)
|
||||
return {
|
||||
"name": getattr(call, "name", None),
|
||||
"args": getattr(call, "args", {}),
|
||||
"id": getattr(call, "id", None),
|
||||
"type": "tool_call",
|
||||
}
|
||||
|
||||
|
||||
class ToolCallNameRepairMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Two-stage tool-name repair on the most recent ``AIMessage``.
|
||||
|
||||
Args:
|
||||
registered_tool_names: Set of canonically-registered tool names.
|
||||
``invalid`` should be in this set so the fallback dispatches.
|
||||
fuzzy_match_threshold: Optional ``difflib`` ratio (0–1) for the
|
||||
fuzzy-match step that runs *between* lowercase and invalid.
|
||||
Set to ``None`` to disable fuzzy matching (opencode parity).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
registered_tool_names: set[str],
|
||||
fuzzy_match_threshold: float | None = 0.85,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._registered = set(registered_tool_names)
|
||||
self._registered_lower = {name.lower(): name for name in self._registered}
|
||||
self._fuzzy_threshold = fuzzy_match_threshold
|
||||
self.tools = []
|
||||
|
||||
def _registered_for_runtime(self, runtime: Runtime[ContextT]) -> set[str]:
|
||||
"""Allow runtime overrides to expand the set (e.g. dynamic MCP tools)."""
|
||||
ctx_tools = getattr(runtime.context, "registered_tool_names", None)
|
||||
if isinstance(ctx_tools, (set, frozenset)):
|
||||
return self._registered | set(ctx_tools)
|
||||
if isinstance(ctx_tools, (list, tuple)):
|
||||
return self._registered | set(ctx_tools)
|
||||
return self._registered
|
||||
|
||||
def _repair_one(
|
||||
self,
|
||||
call: dict[str, Any],
|
||||
registered: set[str],
|
||||
) -> dict[str, Any]:
|
||||
name = call.get("name")
|
||||
if not isinstance(name, str):
|
||||
return call
|
||||
|
||||
if name in registered:
|
||||
return call
|
||||
|
||||
# Stage 1 — lowercase
|
||||
lowered = name.lower()
|
||||
if lowered in registered:
|
||||
call["name"] = lowered
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", "lowercase")
|
||||
call["response_metadata"] = metadata
|
||||
return call
|
||||
|
||||
# Optional fuzzy step (off by default for opencode parity)
|
||||
if self._fuzzy_threshold is not None:
|
||||
close = difflib.get_close_matches(
|
||||
name, registered, n=1, cutoff=self._fuzzy_threshold
|
||||
)
|
||||
if close:
|
||||
call["name"] = close[0]
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", f"fuzzy:{name}->{close[0]}")
|
||||
call["response_metadata"] = metadata
|
||||
return call
|
||||
|
||||
# Stage 2 — invalid fallback
|
||||
if INVALID_TOOL_NAME in registered:
|
||||
original_args = call.get("args") or {}
|
||||
error_msg = (
|
||||
f"Tool name '{name}' is not registered. "
|
||||
f"Original arguments were: {original_args!r}."
|
||||
)
|
||||
call["name"] = INVALID_TOOL_NAME
|
||||
call["args"] = {"tool": name, "error": error_msg}
|
||||
metadata = dict(call.get("response_metadata") or {})
|
||||
metadata.setdefault("repair", f"invalid_fallback:{name}")
|
||||
call["response_metadata"] = metadata
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not repair unknown tool call %r; 'invalid' tool not registered",
|
||||
name,
|
||||
)
|
||||
return call
|
||||
|
||||
def _maybe_repair(
|
||||
self,
|
||||
message: AIMessage,
|
||||
registered: set[str],
|
||||
) -> AIMessage | None:
|
||||
if not message.tool_calls:
|
||||
return None
|
||||
|
||||
new_calls: list[dict[str, Any]] = []
|
||||
any_changed = False
|
||||
for raw in message.tool_calls:
|
||||
call = _coerce_existing_tool_call(raw)
|
||||
before = (call.get("name"), call.get("args"))
|
||||
repaired = self._repair_one(call, registered)
|
||||
after = (repaired.get("name"), repaired.get("args"))
|
||||
if before != after:
|
||||
any_changed = True
|
||||
new_calls.append(repaired)
|
||||
|
||||
if not any_changed:
|
||||
return None
|
||||
|
||||
return message.model_copy(update={"tool_calls": new_calls})
|
||||
|
||||
def after_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
messages = state.get("messages") or []
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
registered = self._registered_for_runtime(runtime)
|
||||
repaired = self._maybe_repair(last, registered)
|
||||
if repaired is None:
|
||||
return None
|
||||
return {"messages": [repaired]}
|
||||
|
||||
async def aafter_model( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[ResponseT],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolCallNameRepairMiddleware",
|
||||
]
|
||||
204
surfsense_backend/app/agents/new_chat/permissions.py
Normal file
204
surfsense_backend/app/agents/new_chat/permissions.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
"""
|
||||
Wildcard pattern matching + rule evaluation for the SurfSense permission system.
|
||||
|
||||
Mirrors ``opencode/packages/opencode/src/permission/evaluate.ts`` and
|
||||
``opencode/packages/opencode/src/util/wildcard.ts`` precisely:
|
||||
|
||||
- ``Wildcard.match`` matches both the ``permission`` and the ``pattern``
|
||||
fields of a rule against the requested ``(permission, pattern)`` pair.
|
||||
``*`` matches any segment, ``**`` matches across separators.
|
||||
- The evaluator runs ``findLast`` over the **flattened** list of rules
|
||||
from all rulesets — last matching rule wins.
|
||||
- The default fallback is ``ask`` (NOT deny), matching opencode.
|
||||
- Multi-pattern requests AND together: if ANY pattern resolves to
|
||||
``deny``, the whole request is denied; if ANY needs ``ask``, an
|
||||
interrupt is raised; only when all patterns ``allow`` does the
|
||||
request proceed.
|
||||
|
||||
Tier 2.1 in the OpenCode-port plan.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
RuleAction = Literal["allow", "deny", "ask"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Rule:
|
||||
"""A single permission rule.
|
||||
|
||||
Attributes:
|
||||
permission: A wildcard-matched permission identifier
|
||||
(e.g. ``"edit"``, ``"linear_*"``, ``"mcp:*"``,
|
||||
``"doom_loop"``). Anchored at start AND end of the input.
|
||||
pattern: A wildcard-matched pattern over the request payload
|
||||
(e.g. ``"/documents/secrets/**"``, ``"page_id=123"``,
|
||||
``"*"``). Anchored at start AND end.
|
||||
action: One of ``"allow"`` / ``"deny"`` / ``"ask"``.
|
||||
"""
|
||||
|
||||
permission: str
|
||||
pattern: str
|
||||
action: RuleAction
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ruleset:
|
||||
"""A list of rules with an associated origin used for debugging."""
|
||||
|
||||
rules: list[Rule] = field(default_factory=list)
|
||||
origin: str = "unknown" # e.g. "defaults", "global", "space", "thread", "runtime"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Wildcard matcher
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
_GLOB_TOKEN = re.compile(r"\*\*|\*|[^*]+")
|
||||
|
||||
|
||||
def _wildcard_to_regex(pattern: str) -> re.Pattern[str]:
|
||||
"""Translate an opencode-style wildcard pattern to a compiled regex.
|
||||
|
||||
Rules:
|
||||
- ``**`` matches any sequence of any characters (including separators).
|
||||
- ``*`` matches any sequence of characters that does **not** include
|
||||
the path separator ``/`` — same as glob.
|
||||
- All other characters match literally.
|
||||
- The pattern is anchored at both ends (``^...$``).
|
||||
"""
|
||||
parts: list[str] = ["^"]
|
||||
for token in _GLOB_TOKEN.findall(pattern):
|
||||
if token == "**":
|
||||
parts.append(r".*")
|
||||
elif token == "*":
|
||||
parts.append(r"[^/]*")
|
||||
else:
|
||||
parts.append(re.escape(token))
|
||||
parts.append("$")
|
||||
return re.compile("".join(parts))
|
||||
|
||||
|
||||
_REGEX_CACHE: dict[str, re.Pattern[str]] = {}
|
||||
|
||||
|
||||
def wildcard_match(value: str, pattern: str) -> bool:
|
||||
"""Return True if ``value`` matches the wildcard ``pattern``.
|
||||
|
||||
Special case: a bare ``"*"`` pattern matches any value, including
|
||||
those containing ``/`` separators. This mirrors opencode's
|
||||
``Wildcard.match`` short-circuit and matches the convention that
|
||||
``pattern="*"`` means "any pattern" in permission rules.
|
||||
"""
|
||||
if pattern == "*":
|
||||
return True
|
||||
compiled = _REGEX_CACHE.get(pattern)
|
||||
if compiled is None:
|
||||
compiled = _wildcard_to_regex(pattern)
|
||||
_REGEX_CACHE[pattern] = compiled
|
||||
return compiled.match(value) is not None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Evaluator
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def evaluate(
|
||||
permission: str,
|
||||
pattern: str,
|
||||
*rulesets: Ruleset | Iterable[Rule],
|
||||
) -> Rule:
|
||||
"""Find the last rule matching ``(permission, pattern)`` from ``rulesets``.
|
||||
|
||||
Mirrors opencode ``permission/evaluate.ts:9-15`` precisely:
|
||||
- Flatten rulesets in argument order.
|
||||
- Walk the flat list **in reverse**.
|
||||
- First reverse-match wins (i.e. the last specified rule wins).
|
||||
- When no rule matches, default to ``Rule(permission, "*", "ask")``.
|
||||
|
||||
Args:
|
||||
permission: The permission identifier being requested
|
||||
(e.g. tool name, ``"edit"``, ``"doom_loop"``).
|
||||
pattern: The request-specific pattern (e.g. file path,
|
||||
primary arg value). Use ``"*"`` when no specific pattern
|
||||
applies.
|
||||
*rulesets: Layered rulesets, applied earliest to latest. Later
|
||||
rulesets override earlier ones.
|
||||
|
||||
Returns:
|
||||
The matched :class:`Rule`, or the default ask fallback.
|
||||
"""
|
||||
flat: list[Rule] = []
|
||||
for rs in rulesets:
|
||||
if isinstance(rs, Ruleset):
|
||||
flat.extend(rs.rules)
|
||||
else:
|
||||
flat.extend(rs)
|
||||
|
||||
for rule in reversed(flat):
|
||||
if wildcard_match(permission, rule.permission) and wildcard_match(
|
||||
pattern, rule.pattern
|
||||
):
|
||||
return rule
|
||||
|
||||
return Rule(permission=permission, pattern="*", action="ask")
|
||||
|
||||
|
||||
def evaluate_many(
|
||||
permission: str,
|
||||
patterns: Iterable[str],
|
||||
*rulesets: Ruleset | Iterable[Rule],
|
||||
) -> list[Rule]:
|
||||
"""Evaluate ``permission`` against each of ``patterns`` (multi-pattern AND).
|
||||
|
||||
Returns the list of resolved rules in the same order as ``patterns``.
|
||||
The caller is responsible for combining the results — opencode-style
|
||||
multi-pattern AND collapses ``deny`` first, then ``ask``, then
|
||||
``allow``.
|
||||
"""
|
||||
return [evaluate(permission, p, *rulesets) for p in patterns]
|
||||
|
||||
|
||||
def aggregate_action(rules: Iterable[Rule]) -> RuleAction:
|
||||
"""Collapse a list of per-pattern rules into one action.
|
||||
|
||||
Order:
|
||||
1. If any rule is ``deny`` -> ``deny``.
|
||||
2. Else if any rule is ``ask`` -> ``ask``.
|
||||
3. Else if at least one rule is ``allow`` -> ``allow``.
|
||||
4. Else (empty input) -> ``ask`` (safe default mirroring ``evaluate``).
|
||||
|
||||
Mirrors opencode's behavior in ``permission/index.ts:180-272``.
|
||||
"""
|
||||
saw_ask = False
|
||||
saw_allow = False
|
||||
for rule in rules:
|
||||
if rule.action == "deny":
|
||||
return "deny"
|
||||
if rule.action == "ask":
|
||||
saw_ask = True
|
||||
elif rule.action == "allow":
|
||||
saw_allow = True
|
||||
if saw_ask:
|
||||
return "ask"
|
||||
if saw_allow:
|
||||
return "allow"
|
||||
return "ask"
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Rule",
|
||||
"RuleAction",
|
||||
"Ruleset",
|
||||
"aggregate_action",
|
||||
"evaluate",
|
||||
"evaluate_many",
|
||||
"wildcard_match",
|
||||
]
|
||||
157
surfsense_backend/app/agents/new_chat/plugin_loader.py
Normal file
157
surfsense_backend/app/agents/new_chat/plugin_loader.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""Entry-point based plugin loader for SurfSense agent middleware.
|
||||
|
||||
The realization in the Tier 6 plan: LangChain's :class:`AgentMiddleware` ABC
|
||||
already covers the practical surface most plugins need (``before_agent`` /
|
||||
``before_model`` / ``wrap_tool_call`` / their async counterparts), so a
|
||||
SurfSense-specific plugin protocol is unnecessary.
|
||||
|
||||
A plugin is therefore just an installable Python package that registers a
|
||||
factory callable under the ``surfsense.plugins`` entry-point group:
|
||||
|
||||
.. code-block:: toml
|
||||
|
||||
# in a plugin package's pyproject.toml
|
||||
[project.entry-points."surfsense.plugins"]
|
||||
year_substituter = "my_plugin:make_middleware"
|
||||
|
||||
The factory has the signature ``Callable[[PluginContext], AgentMiddleware]``.
|
||||
It receives a small, sanitized :class:`PluginContext` with the IDs and the
|
||||
LLM the plugin is allowed to talk to — and **never** raw secrets, DB
|
||||
sessions, or other connectors.
|
||||
|
||||
## Trust model
|
||||
|
||||
Plugins are loaded **only if** their entry-point ``name`` appears in
|
||||
``allowed_plugins`` (admin-controlled, sourced from
|
||||
``global_llm_config.yaml`` or :func:`load_allowed_plugin_names_from_env`).
|
||||
There is **no env-driven auto-load**. A plugin failure is logged and
|
||||
isolated; it does not break agent construction.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterable
|
||||
from importlib.metadata import entry_points
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
PLUGIN_ENTRY_POINT_GROUP = "surfsense.plugins"
|
||||
|
||||
|
||||
class PluginContext(dict):
|
||||
"""Sanitized DI bag handed to each plugin factory.
|
||||
|
||||
Backed by ``dict`` so plugins can inspect the keys they care about
|
||||
without coupling to a concrete dataclass shape. Required keys:
|
||||
|
||||
* ``search_space_id`` (int)
|
||||
* ``user_id`` (str | None)
|
||||
* ``thread_visibility`` (:class:`app.db.ChatVisibility`)
|
||||
* ``llm`` (:class:`langchain_core.language_models.BaseChatModel`)
|
||||
|
||||
The context **never** carries DB sessions, raw secrets, or other
|
||||
connectors. If a future plugin genuinely needs DB access, that
|
||||
integration goes through a rate-limited service interface, not
|
||||
through this bag.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
*,
|
||||
search_space_id: int,
|
||||
user_id: str | None,
|
||||
thread_visibility: ChatVisibility,
|
||||
llm: BaseChatModel,
|
||||
) -> PluginContext:
|
||||
return cls(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
thread_visibility=thread_visibility,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
|
||||
def load_plugin_middlewares(
|
||||
ctx: PluginContext,
|
||||
allowed_plugin_names: Iterable[str],
|
||||
) -> list[AgentMiddleware]:
|
||||
"""Discover, allowlist-filter, and instantiate plugin middleware.
|
||||
|
||||
For each entry-point in :data:`PLUGIN_ENTRY_POINT_GROUP` whose name is
|
||||
in ``allowed_plugin_names``, load the factory and call it with ``ctx``.
|
||||
The factory's return value must be an :class:`AgentMiddleware` instance;
|
||||
anything else is logged and skipped.
|
||||
|
||||
Errors are isolated — a plugin that raises during ``ep.load()`` or
|
||||
factory invocation is logged at ``ERROR`` and ignored. Agent
|
||||
construction continues with whatever plugins did succeed.
|
||||
"""
|
||||
allowed = {name for name in allowed_plugin_names if name}
|
||||
if not allowed:
|
||||
return []
|
||||
|
||||
out: list[AgentMiddleware] = []
|
||||
try:
|
||||
eps = entry_points(group=PLUGIN_ENTRY_POINT_GROUP)
|
||||
except Exception: # pragma: no cover - defensive (entry_points is robust)
|
||||
logger.exception("Failed to enumerate plugin entry points")
|
||||
return []
|
||||
|
||||
for ep in eps:
|
||||
if ep.name not in allowed:
|
||||
logger.info("Skipping non-allowlisted plugin %s", ep.name)
|
||||
continue
|
||||
try:
|
||||
factory = ep.load()
|
||||
except Exception:
|
||||
logger.exception("Failed to load plugin %s", ep.name)
|
||||
continue
|
||||
try:
|
||||
mw = factory(ctx)
|
||||
except Exception:
|
||||
logger.exception("Plugin %s factory raised", ep.name)
|
||||
continue
|
||||
if not isinstance(mw, AgentMiddleware):
|
||||
logger.warning(
|
||||
"Plugin %s returned %s, expected AgentMiddleware; skipping",
|
||||
ep.name,
|
||||
type(mw).__name__,
|
||||
)
|
||||
continue
|
||||
out.append(mw)
|
||||
logger.info("Loaded plugin %s as %s", ep.name, type(mw).__name__)
|
||||
return out
|
||||
|
||||
|
||||
def load_allowed_plugin_names_from_env() -> set[str]:
|
||||
"""Read ``SURFSENSE_ALLOWED_PLUGINS`` (comma-separated) into a set.
|
||||
|
||||
Provided as a thin convenience for deployments that don't surface plugins
|
||||
through ``global_llm_config.yaml`` yet. Whitespace is stripped and empty
|
||||
entries are dropped.
|
||||
"""
|
||||
raw = os.environ.get("SURFSENSE_ALLOWED_PLUGINS", "").strip()
|
||||
if not raw:
|
||||
return set()
|
||||
return {token.strip() for token in raw.split(",") if token.strip()}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PLUGIN_ENTRY_POINT_GROUP",
|
||||
"PluginContext",
|
||||
"load_allowed_plugin_names_from_env",
|
||||
"load_plugin_middlewares",
|
||||
]
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
"""Reference plugins bundled with SurfSense.
|
||||
|
||||
These plugins are intentionally small and demonstrative. They are NOT
|
||||
auto-loaded — they ship as examples that a deployment can opt into via
|
||||
``global_llm_config.yaml`` or ``SURFSENSE_ALLOWED_PLUGINS``.
|
||||
"""
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
"""Reference plugin: substitute ``{{year}}`` in tool descriptions.
|
||||
|
||||
Mirrors the OpenCode ``chat.system.transform`` example. Demonstrates the
|
||||
:meth:`AgentMiddleware.awrap_tool_call` hook -- the plugin sees every tool
|
||||
invocation and can rewrite the request *or* the result. This particular
|
||||
plugin is read-only and only transforms the *description* the user might
|
||||
see in error messages (no request mutation).
|
||||
|
||||
The plugin is built as a factory function so the entry-point loader can
|
||||
inject :class:`PluginContext` (containing the agent's LLM, search-space
|
||||
ID, etc.). The factory signature
|
||||
``Callable[[PluginContext], AgentMiddleware]`` is the only contract --
|
||||
SurfSense doesn't define a custom plugin protocol on top of LangChain's
|
||||
:class:`AgentMiddleware`.
|
||||
|
||||
Wire-up in ``pyproject.toml`` (illustrative; the in-repo plugin doesn't
|
||||
need this -- it's already on the import path)::
|
||||
|
||||
[project.entry-points."surfsense.plugins"]
|
||||
year_substituter = "app.agents.new_chat.plugins.year_substituter:make_middleware"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.new_chat.plugin_loader import PluginContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _YearSubstituterMiddleware(AgentMiddleware):
|
||||
"""Replace ``{{year}}`` in the result text with the current UTC year."""
|
||||
|
||||
tools = ()
|
||||
|
||||
def __init__(self, year: int | None = None) -> None:
|
||||
super().__init__()
|
||||
self._year = str(year if year is not None else datetime.now(UTC).year)
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[
|
||||
[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]
|
||||
],
|
||||
) -> ToolMessage | Command[Any]:
|
||||
result = await handler(request)
|
||||
try:
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
if isinstance(result, ToolMessage) and isinstance(result.content, str):
|
||||
if "{{year}}" in result.content:
|
||||
new_text = result.content.replace("{{year}}", self._year)
|
||||
result = ToolMessage(
|
||||
content=new_text,
|
||||
tool_call_id=result.tool_call_id,
|
||||
id=result.id,
|
||||
name=result.name,
|
||||
status=result.status,
|
||||
artifact=result.artifact,
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive
|
||||
logger.exception("year_substituter plugin failed; passing original result")
|
||||
return result
|
||||
|
||||
|
||||
def make_middleware(ctx: PluginContext) -> AgentMiddleware:
|
||||
"""Plugin factory used by :func:`load_plugin_middlewares`."""
|
||||
# Plugin is intentionally small so it has no state to threading-protect
|
||||
# and ignores ``ctx`` beyond demonstrating that the loader passes it in.
|
||||
_ = ctx
|
||||
return _YearSubstituterMiddleware()
|
||||
|
||||
|
||||
__all__ = ["make_middleware"]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""SurfSense agent prompt fragments.
|
||||
|
||||
The prompt is composed at runtime by :mod:`composer` from the markdown
|
||||
fragments under ``base/``, ``providers/``, ``tools/``, ``examples/``, and
|
||||
``routing/``. ``system_prompt.py`` is now a thin wrapper that delegates
|
||||
to :func:`composer.compose_system_prompt`.
|
||||
"""
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
|
||||
|
||||
Today's date (UTC): {resolved_today}
|
||||
|
||||
When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math.
|
||||
|
||||
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
You are SurfSense, a reasoning and acting AI agent designed to answer questions in this team space using the team's shared knowledge base.
|
||||
|
||||
In this team thread, each message is prefixed with **[DisplayName of the author]**. Use this to attribute and reference the author of anything in the discussion (who asked a question, made a suggestion, or contributed an idea) and to cite who said what in your answers.
|
||||
|
||||
Today's date (UTC): {resolved_today}
|
||||
|
||||
When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVER use backtick code spans or Unicode symbols for math.
|
||||
|
||||
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
<citation_instructions>
|
||||
IMPORTANT: Citations are DISABLED for this configuration.
|
||||
|
||||
DO NOT include any citations in your responses. Specifically:
|
||||
1. Do NOT use the [citation:chunk_id] format anywhere in your response.
|
||||
2. Do NOT reference document IDs, chunk IDs, or source IDs.
|
||||
3. Simply provide the information naturally without any citation markers.
|
||||
4. Write your response as if you're having a normal conversation, incorporating the information from your knowledge seamlessly.
|
||||
|
||||
When answering questions based on documents from the knowledge base:
|
||||
- Present the information directly and confidently
|
||||
- Do not mention that information comes from specific documents or chunks
|
||||
- Integrate facts naturally into your response without attribution markers
|
||||
|
||||
Your goal is to provide helpful, informative answers in a clean, readable format without any citation notation.
|
||||
</citation_instructions>
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
<citation_instructions>
|
||||
CRITICAL CITATION REQUIREMENTS:
|
||||
|
||||
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `<chunk id='...'>` tag inside `<document_content>`.
|
||||
2. Make sure ALL factual statements from the documents have proper citations.
|
||||
3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2].
|
||||
4. You MUST use the exact chunk_id values from the `<chunk id='...'>` attributes. Do not create your own citation numbers.
|
||||
5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value.
|
||||
6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags.
|
||||
7. Do not return citations as clickable links.
|
||||
8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only.
|
||||
9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting.
|
||||
10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `<chunk id='...'>` tags.
|
||||
11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up.
|
||||
|
||||
<document_structure_example>
|
||||
The documents you receive are structured like this:
|
||||
|
||||
**Knowledge base documents (numeric chunk IDs):**
|
||||
<document>
|
||||
<document_metadata>
|
||||
<document_id>42</document_id>
|
||||
<document_type>GITHUB_CONNECTOR</document_type>
|
||||
<title><![CDATA[Some repo / file / issue title]]></title>
|
||||
<url><![CDATA[https://example.com]]></url>
|
||||
<metadata_json><![CDATA[{{"any":"other metadata"}}]]></metadata_json>
|
||||
</document_metadata>
|
||||
|
||||
<document_content>
|
||||
<chunk id='123'><![CDATA[First chunk text...]]></chunk>
|
||||
<chunk id='124'><![CDATA[Second chunk text...]]></chunk>
|
||||
</document_content>
|
||||
</document>
|
||||
|
||||
**Web search results (URL chunk IDs):**
|
||||
<document>
|
||||
<document_metadata>
|
||||
<document_type>WEB_SEARCH</document_type>
|
||||
<title><![CDATA[Some web search result]]></title>
|
||||
<url><![CDATA[https://example.com/article]]></url>
|
||||
</document_metadata>
|
||||
|
||||
<document_content>
|
||||
<chunk id='https://example.com/article'><![CDATA[Content from web search...]]></chunk>
|
||||
</document_content>
|
||||
</document>
|
||||
|
||||
IMPORTANT: You MUST cite using the EXACT chunk ids from the `<chunk id='...'>` tags.
|
||||
- For knowledge base documents, chunk ids are numeric (e.g. 123, 124) or prefixed (e.g. doc-45).
|
||||
- For live web search results, chunk ids are URLs (e.g. https://example.com/article).
|
||||
Do NOT cite document_id. Always use the chunk id.
|
||||
</document_structure_example>
|
||||
|
||||
<citation_format>
|
||||
- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `<chunk id='...'>` tag
|
||||
- Citations should appear at the end of the sentence containing the information they support
|
||||
- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
|
||||
- No need to return references section. Just citations in answer.
|
||||
- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format
|
||||
- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only
|
||||
- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess
|
||||
- Copy the EXACT chunk id from the XML - if it says `<chunk id='doc-123'>`, use [citation:doc-123]
|
||||
- If the chunk id is a URL like `<chunk id='https://example.com/page'>`, use [citation:https://example.com/page]
|
||||
</citation_format>
|
||||
|
||||
<citation_examples>
|
||||
CORRECT citation formats:
|
||||
- [citation:5] (numeric chunk ID from knowledge base)
|
||||
- [citation:doc-123] (for Surfsense documentation chunks)
|
||||
- [citation:https://example.com/article] (URL chunk ID from web search results)
|
||||
- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] (multiple citations)
|
||||
|
||||
INCORRECT citation formats (DO NOT use):
|
||||
- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense))
|
||||
- Using parentheses around brackets: ([citation:5])
|
||||
- Using hyperlinked text: [link to source 5](https://example.com)
|
||||
- Using footnote style: ... library¹
|
||||
- Making up source IDs when source_id is unknown
|
||||
- Using old IEEE format: [1], [2], [3]
|
||||
- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5]
|
||||
</citation_examples>
|
||||
|
||||
<citation_output_example>
|
||||
Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5].
|
||||
|
||||
According to web search results, the key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:https://docs.python.org/3/library/asyncio.html]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources.
|
||||
|
||||
However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead.
|
||||
</citation_output_example>
|
||||
</citation_instructions>
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
<knowledge_base_only_policy>
|
||||
CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||
- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs.
|
||||
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission.
|
||||
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
|
||||
1. Inform the user that you could not find relevant information in their knowledge base.
|
||||
2. Ask the user: "Would you like me to answer from my general knowledge instead?"
|
||||
3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes.
|
||||
- This policy does NOT apply to:
|
||||
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
|
||||
* Formatting, summarization, or analysis of content already present in the conversation
|
||||
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||
</knowledge_base_only_policy>
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
<knowledge_base_only_policy>
|
||||
CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||
- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs.
|
||||
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission.
|
||||
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
|
||||
1. Inform the team that you could not find relevant information in the shared knowledge base.
|
||||
2. Ask: "Would you like me to answer from my general knowledge instead?"
|
||||
3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes.
|
||||
- This policy does NOT apply to:
|
||||
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
|
||||
* Formatting, summarization, or analysis of content already present in the conversation
|
||||
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||
</knowledge_base_only_policy>
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the user (role, interests, preferences, projects,
|
||||
background, or standing instructions)? If yes, you MUST call update_memory
|
||||
alongside your normal response — do not defer this to a later turn.
|
||||
</memory_protocol>
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||
or key facts)? If yes, you MUST call update_memory alongside your normal response —
|
||||
do not defer this to a later turn.
|
||||
</memory_protocol>
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
<parameter_resolution>
|
||||
Some service tools require identifiers or context you do not have (account IDs,
|
||||
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
|
||||
IDs or technical identifiers — they cannot memorise them.
|
||||
|
||||
Instead, follow this discovery pattern:
|
||||
1. Call a listing/discovery tool to find available options.
|
||||
2. ONE result → use it silently, no question to the user.
|
||||
3. MULTIPLE results → present the options by their display names and let the
|
||||
user choose. Never show raw UUIDs — always use friendly names.
|
||||
|
||||
Discovery tools by level:
|
||||
- Which account/workspace? → get_connected_accounts("<service>")
|
||||
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
|
||||
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
|
||||
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
|
||||
- Which channel? → slack_search_channels
|
||||
- Which base? → list_bases
|
||||
- Which table? → list_tables_for_base (after resolving baseId)
|
||||
- Which task? → clickup_search
|
||||
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
|
||||
|
||||
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
|
||||
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
|
||||
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
|
||||
If there is only one option at each step, use it silently. If multiple, present
|
||||
friendly names.
|
||||
|
||||
Chain discovery when needed — e.g. for Airtable records: list_bases → pick
|
||||
base → list_tables_for_base → pick table → list_records_for_table.
|
||||
|
||||
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
|
||||
the same service, tool names are prefixed to avoid collisions — e.g.
|
||||
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
|
||||
Each prefixed tool's description starts with [Account: <display_name>] so you
|
||||
know which account it targets. Use get_connected_accounts("<service>") to see
|
||||
the full list of accounts with their connector IDs and display names.
|
||||
When only one account is connected, tools have their normal unprefixed names.
|
||||
</parameter_resolution>
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
<tool_routing>
|
||||
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||
say "I don't see it in the knowledge base" or ask the user if they want you to check.
|
||||
Ignore any knowledge base results for these services.
|
||||
|
||||
When to use which tool:
|
||||
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||
- Real-time public web data → call web_search
|
||||
- Reading a specific webpage → call scrape_webpage
|
||||
</tool_routing>
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
<tool_routing>
|
||||
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||
say "I don't see it in the knowledge base" or ask if they want you to check.
|
||||
Ignore any knowledge base results for these services.
|
||||
|
||||
When to use which tool:
|
||||
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||
- Real-time public web data → call web_search
|
||||
- Reading a specific webpage → call scrape_webpage
|
||||
</tool_routing>
|
||||
359
surfsense_backend/app/agents/new_chat/prompts/composer.py
Normal file
359
surfsense_backend/app/agents/new_chat/prompts/composer.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
"""
|
||||
Prompt composer for the SurfSense ``new_chat`` agent.
|
||||
|
||||
This module assembles the agent's system prompt from the markdown fragments
|
||||
under :mod:`app.agents.new_chat.prompts`. It replaces the monolithic
|
||||
``system_prompt.py`` with a clean, fragment-based composition:
|
||||
|
||||
::
|
||||
|
||||
prompts/
|
||||
base/ # agent identity, KB policy, tool routing, …
|
||||
providers/ # provider-specific tweaks (anthropic, gpt5, …)
|
||||
tools/ # one ``<name>.md`` per tool
|
||||
examples/ # one ``<name>.md`` per tool with call examples
|
||||
routing/ # connector-specific routing notes (linear, slack, …)
|
||||
|
||||
Tier 3a in the OpenCode-port plan.
|
||||
|
||||
Backwards compatibility
|
||||
=======================
|
||||
|
||||
``system_prompt.py`` re-exports :func:`compose_system_prompt` and wraps it
|
||||
in functions with the same signatures as the legacy
|
||||
``build_surfsense_system_prompt`` / ``build_configurable_system_prompt`` so
|
||||
existing call sites do not change.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from datetime import UTC, datetime
|
||||
from importlib import resources
|
||||
|
||||
from app.db import ChatVisibility
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Provider variant detection
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
ProviderVariant = str # "anthropic" | "openai_reasoning" | "openai_classic" | "google" | "default"
|
||||
|
||||
_OPENAI_REASONING_RE = re.compile(r"\b(gpt-5|o\d|o-)", re.IGNORECASE)
|
||||
_OPENAI_CLASSIC_RE = re.compile(r"\bgpt-4", re.IGNORECASE)
|
||||
_ANTHROPIC_RE = re.compile(r"\bclaude\b", re.IGNORECASE)
|
||||
_GOOGLE_RE = re.compile(r"\bgemini\b", re.IGNORECASE)
|
||||
|
||||
|
||||
def detect_provider_variant(model_name: str | None) -> ProviderVariant:
|
||||
"""Pick a provider-specific prompt variant from a model id string.
|
||||
|
||||
Heuristic match on the model id; returns ``"default"`` when nothing
|
||||
matches so the composer can fall back to the empty placeholder file.
|
||||
"""
|
||||
if not model_name:
|
||||
return "default"
|
||||
name = model_name.strip()
|
||||
if _OPENAI_REASONING_RE.search(name):
|
||||
return "openai_reasoning"
|
||||
if _OPENAI_CLASSIC_RE.search(name):
|
||||
return "openai_classic"
|
||||
if _ANTHROPIC_RE.search(name):
|
||||
return "anthropic"
|
||||
if _GOOGLE_RE.search(name):
|
||||
return "google"
|
||||
return "default"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Fragment loading
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
_PROMPTS_PACKAGE = "app.agents.new_chat.prompts"
|
||||
|
||||
|
||||
def _read_fragment(subpath: str) -> str:
|
||||
"""Read a fragment file from the ``prompts/`` resource tree.
|
||||
|
||||
Returns the raw contents stripped of any single trailing newline so
|
||||
composition can append explicit separators without compounding blank
|
||||
lines. Missing files return an empty string so optional fragments
|
||||
(e.g. provider hints) act as no-ops.
|
||||
"""
|
||||
parts = subpath.split("/")
|
||||
try:
|
||||
ref = resources.files(_PROMPTS_PACKAGE).joinpath(*parts)
|
||||
if not ref.is_file():
|
||||
return ""
|
||||
text = ref.read_text(encoding="utf-8")
|
||||
except (FileNotFoundError, ModuleNotFoundError):
|
||||
return ""
|
||||
if text.endswith("\n"):
|
||||
text = text[:-1]
|
||||
return text
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tool ordering + memory variant resolution
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Ordered for reading flow: fundamentals first, then artifact generators,
|
||||
# then memory at the end (mirrors the legacy ``_ALL_TOOL_NAMES_ORDERED``).
|
||||
ALL_TOOL_NAMES_ORDERED: tuple[str, ...] = (
|
||||
"search_surfsense_docs",
|
||||
"web_search",
|
||||
"generate_podcast",
|
||||
"generate_video_presentation",
|
||||
"generate_report",
|
||||
"generate_resume",
|
||||
"generate_image",
|
||||
"scrape_webpage",
|
||||
"update_memory",
|
||||
)
|
||||
|
||||
|
||||
_MEMORY_VARIANT_TOOLS: frozenset[str] = frozenset({"update_memory"})
|
||||
|
||||
|
||||
def _tool_fragment_path(tool_name: str, variant: str) -> str:
|
||||
"""Resolve a tool's instruction fragment path.
|
||||
|
||||
Tools listed in :data:`_MEMORY_VARIANT_TOOLS` switch on the conversation
|
||||
visibility and load ``tools/<name>_<variant>.md``; everything else
|
||||
falls back to ``tools/<name>.md``.
|
||||
"""
|
||||
if tool_name in _MEMORY_VARIANT_TOOLS:
|
||||
return f"tools/{tool_name}_{variant}.md"
|
||||
return f"tools/{tool_name}.md"
|
||||
|
||||
|
||||
def _example_fragment_path(tool_name: str, variant: str) -> str:
|
||||
if tool_name in _MEMORY_VARIANT_TOOLS:
|
||||
return f"examples/{tool_name}_{variant}.md"
|
||||
return f"examples/{tool_name}.md"
|
||||
|
||||
|
||||
def _format_tool_label(tool_name: str) -> str:
|
||||
return tool_name.replace("_", " ").title()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Section builders
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_system_instructions(
|
||||
*,
|
||||
visibility: ChatVisibility,
|
||||
resolved_today: str,
|
||||
) -> str:
|
||||
"""Reconstruct the legacy ``<system_instruction>`` block from fragments."""
|
||||
variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private"
|
||||
|
||||
sections = [
|
||||
_read_fragment(f"base/agent_{variant}.md"),
|
||||
_read_fragment(f"base/kb_only_policy_{variant}.md"),
|
||||
_read_fragment(f"base/tool_routing_{variant}.md"),
|
||||
_read_fragment("base/parameter_resolution.md"),
|
||||
_read_fragment(f"base/memory_protocol_{variant}.md"),
|
||||
]
|
||||
body = "\n\n".join(s for s in sections if s)
|
||||
block = f"\n<system_instruction>\n{body}\n\n</system_instruction>\n"
|
||||
return block.format(resolved_today=resolved_today)
|
||||
|
||||
|
||||
def _build_mcp_routing_block(
|
||||
mcp_connector_tools: dict[str, list[str]] | None,
|
||||
) -> str:
|
||||
"""Emit the ``<mcp_tool_routing>`` block when at least one MCP server is wired."""
|
||||
if not mcp_connector_tools:
|
||||
return ""
|
||||
lines: list[str] = [
|
||||
"\n<mcp_tool_routing>",
|
||||
"You also have direct tools from these user-connected MCP servers.",
|
||||
"Their data is NEVER in the knowledge base — call their tools directly.",
|
||||
"",
|
||||
]
|
||||
for server_name, tool_names in mcp_connector_tools.items():
|
||||
lines.append(f"- {server_name} → {', '.join(tool_names)}")
|
||||
lines.append("</mcp_tool_routing>\n")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _build_tools_section(
|
||||
*,
|
||||
visibility: ChatVisibility,
|
||||
enabled_tool_names: set[str] | None,
|
||||
disabled_tool_names: set[str] | None,
|
||||
) -> str:
|
||||
"""Reconstruct the ``<tools>`` block + ``<tool_call_examples>`` block."""
|
||||
variant = "team" if visibility == ChatVisibility.SEARCH_SPACE else "private"
|
||||
|
||||
parts: list[str] = []
|
||||
preamble = _read_fragment("tools/_preamble.md")
|
||||
if preamble:
|
||||
parts.append(preamble + "\n")
|
||||
|
||||
examples: list[str] = []
|
||||
|
||||
for tool_name in ALL_TOOL_NAMES_ORDERED:
|
||||
if enabled_tool_names is not None and tool_name not in enabled_tool_names:
|
||||
continue
|
||||
|
||||
instruction = _read_fragment(_tool_fragment_path(tool_name, variant))
|
||||
if instruction:
|
||||
parts.append(instruction + "\n")
|
||||
|
||||
example = _read_fragment(_example_fragment_path(tool_name, variant))
|
||||
if example:
|
||||
examples.append(example + "\n")
|
||||
|
||||
known_disabled = (
|
||||
set(disabled_tool_names) & set(ALL_TOOL_NAMES_ORDERED)
|
||||
if disabled_tool_names
|
||||
else set()
|
||||
)
|
||||
if known_disabled:
|
||||
disabled_list = ", ".join(
|
||||
_format_tool_label(n)
|
||||
for n in ALL_TOOL_NAMES_ORDERED
|
||||
if n in known_disabled
|
||||
)
|
||||
parts.append(
|
||||
"\n"
|
||||
"DISABLED TOOLS (by user):\n"
|
||||
f"The following tools are available in SurfSense but have been disabled by the user for this session: {disabled_list}.\n"
|
||||
"You do NOT have access to these tools and MUST NOT claim you can use them.\n"
|
||||
"If the user asks about a capability provided by a disabled tool, let them know the relevant tool\n"
|
||||
"is currently disabled and they can re-enable it.\n"
|
||||
)
|
||||
|
||||
parts.append("\n</tools>\n")
|
||||
|
||||
if examples:
|
||||
parts.append("<tool_call_examples>")
|
||||
parts.extend(examples)
|
||||
parts.append("</tool_call_examples>\n")
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _build_provider_block(provider_variant: ProviderVariant) -> str:
|
||||
"""Optional provider-tuned hints. Empty for ``"default"``."""
|
||||
if not provider_variant or provider_variant == "default":
|
||||
return ""
|
||||
text = _read_fragment(f"providers/{provider_variant}.md")
|
||||
return f"\n{text}\n" if text else ""
|
||||
|
||||
|
||||
def _build_routing_block(connector_routing: Iterable[str] | None) -> str:
|
||||
if not connector_routing:
|
||||
return ""
|
||||
fragments: list[str] = []
|
||||
for name in connector_routing:
|
||||
text = _read_fragment(f"routing/{name}.md")
|
||||
if text:
|
||||
fragments.append(text)
|
||||
if not fragments:
|
||||
return ""
|
||||
return "\n" + "\n\n".join(fragments) + "\n"
|
||||
|
||||
|
||||
def _build_citation_block(citations_enabled: bool) -> str:
|
||||
fragment = (
|
||||
_read_fragment("base/citations_on.md")
|
||||
if citations_enabled
|
||||
else _read_fragment("base/citations_off.md")
|
||||
)
|
||||
return f"\n{fragment}\n" if fragment else ""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Public API
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compose_system_prompt(
|
||||
*,
|
||||
today: datetime | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
enabled_tool_names: set[str] | None = None,
|
||||
disabled_tool_names: set[str] | None = None,
|
||||
mcp_connector_tools: dict[str, list[str]] | None = None,
|
||||
custom_system_instructions: str | None = None,
|
||||
use_default_system_instructions: bool = True,
|
||||
citations_enabled: bool = True,
|
||||
provider_variant: ProviderVariant | None = None,
|
||||
model_name: str | None = None,
|
||||
connector_routing: Iterable[str] | None = None,
|
||||
) -> str:
|
||||
"""Assemble the SurfSense system prompt from disk fragments.
|
||||
|
||||
Args:
|
||||
today: Optional clock injection for tests.
|
||||
thread_visibility: Private vs shared (team) — drives memory wording
|
||||
and a few base block variants.
|
||||
enabled_tool_names: When provided, only these tools' instructions
|
||||
are included; ``None`` keeps the legacy "include everything"
|
||||
behavior.
|
||||
disabled_tool_names: User-disabled tools (note appended to prompt).
|
||||
mcp_connector_tools: ``{server_name: [tool_names...]}`` to inject
|
||||
an explicit MCP routing block.
|
||||
custom_system_instructions: Free-form instructions that override
|
||||
the default ``<system_instruction>`` block (legacy support
|
||||
for ``NewLLMConfig.system_instructions``).
|
||||
use_default_system_instructions: When ``custom_system_instructions``
|
||||
is empty/None, fall back to defaults (legacy semantics).
|
||||
citations_enabled: Include ``citations_on.md`` (true) or
|
||||
``citations_off.md`` (false).
|
||||
provider_variant: Explicit provider variant override
|
||||
(``"anthropic" | "openai_reasoning" | "openai_classic" | "google" | "default"``).
|
||||
When ``None``, falls back to :func:`detect_provider_variant`
|
||||
on ``model_name``.
|
||||
model_name: Used to auto-detect ``provider_variant`` when not
|
||||
provided explicitly.
|
||||
connector_routing: Optional list of routing fragment names
|
||||
(``["linear", "slack", ...]``) to include from
|
||||
``prompts/routing/``.
|
||||
|
||||
Returns:
|
||||
The fully composed system prompt string.
|
||||
"""
|
||||
resolved_today = (today or datetime.now(UTC)).astimezone(UTC).date().isoformat()
|
||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||
|
||||
if custom_system_instructions and custom_system_instructions.strip():
|
||||
sys_block = custom_system_instructions.format(resolved_today=resolved_today)
|
||||
elif use_default_system_instructions:
|
||||
sys_block = _build_system_instructions(
|
||||
visibility=visibility, resolved_today=resolved_today
|
||||
)
|
||||
else:
|
||||
sys_block = ""
|
||||
|
||||
sys_block += _build_mcp_routing_block(mcp_connector_tools)
|
||||
|
||||
if provider_variant is None:
|
||||
provider_variant = detect_provider_variant(model_name)
|
||||
sys_block += _build_provider_block(provider_variant)
|
||||
sys_block += _build_routing_block(connector_routing)
|
||||
|
||||
tools_block = _build_tools_section(
|
||||
visibility=visibility,
|
||||
enabled_tool_names=enabled_tool_names,
|
||||
disabled_tool_names=disabled_tool_names,
|
||||
)
|
||||
citation_block = _build_citation_block(citations_enabled)
|
||||
|
||||
return sys_block + tools_block + citation_block
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ALL_TOOL_NAMES_ORDERED",
|
||||
"ProviderVariant",
|
||||
"compose_system_prompt",
|
||||
"detect_provider_variant",
|
||||
]
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
|
||||
- User: "Generate an image of a cat"
|
||||
- Call: `generate_image(prompt="A fluffy orange tabby cat sitting on a windowsill, bathed in warm golden sunlight, soft bokeh background with green houseplants, photorealistic style, cozy atmosphere")`
|
||||
- The generated image will automatically be displayed in the chat.
|
||||
- User: "Draw me a logo for a coffee shop called Bean Dream"
|
||||
- Call: `generate_image(prompt="Minimalist modern logo design for a coffee shop called 'Bean Dream', featuring a stylized coffee bean with dream-like swirls of steam, clean vector style, warm brown and cream color palette, white background, professional branding")`
|
||||
- The generated image will automatically be displayed in the chat.
|
||||
- User: "Show me this image: https://example.com/image.png"
|
||||
- Simply include it in your response using markdown: ``
|
||||
- User uploads an image file and asks: "What is this image about?"
|
||||
- The user's uploaded image is already visible in the chat.
|
||||
- Simply analyze the image content and respond directly.
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
|
||||
- User: "Give me a podcast about AI trends based on what we discussed"
|
||||
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
|
||||
- User: "Create a podcast summary of this conversation"
|
||||
- Call: `generate_podcast(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
|
||||
- User: "Make a podcast about quantum computing"
|
||||
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\n\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")`
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
|
||||
- User: "Generate a report about AI trends"
|
||||
- Call: `generate_report(topic="AI Trends Report", source_strategy="kb_search", search_queries=["AI trends recent developments", "artificial intelligence industry trends", "AI market growth and predictions"], report_style="detailed")`
|
||||
- WHY: Has creation verb "generate" → call the tool. No prior discussion → use kb_search.
|
||||
- User: "Write a research report from this conversation"
|
||||
- Call: `generate_report(topic="Research Report", source_strategy="conversation", source_content="Complete conversation summary:\n\n...", report_style="deep_research")`
|
||||
- WHY: Has creation verb "write" → call the tool. Conversation has the content → use source_strategy="conversation".
|
||||
- User: (after a report on Climate Change was generated) "Add a section about carbon capture technologies"
|
||||
- Call: `generate_report(topic="Climate Crisis: Causes, Impacts, and Solutions", source_strategy="conversation", source_content="[summary of conversation context if any]", parent_report_id=<previous_report_id>, user_instructions="Add a new section about carbon capture technologies")`
|
||||
- WHY: Has modification verb "add" + specific deliverable target → call the tool with parent_report_id.
|
||||
- User: (after a report was generated) "What else could we add to have more depth?"
|
||||
- Do NOT call generate_report. Answer in chat with suggestions.
|
||||
- WHY: No creation/modification verb directed at producing a deliverable.
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
|
||||
- User: "Build me a resume. I'm John Doe, engineer at Acme Corp..."
|
||||
- Call: `generate_resume(user_info="John Doe, engineer at Acme Corp...", max_pages=1)`
|
||||
- WHY: Has creation verb "build" + resume → call the tool.
|
||||
- User: "Create my CV with this info: [experience, education, skills]"
|
||||
- Call: `generate_resume(user_info="[experience, education, skills]", max_pages=1)`
|
||||
- User: "Build me a resume" (and there is a resume/CV document in the conversation context)
|
||||
- Extract the FULL content from the document in context, then call:
|
||||
`generate_resume(user_info="Name: John Doe\nEmail: john@example.com\n\nExperience:\n- Senior Engineer at Acme Corp (2020-2024)\n Led team of 5...\n\nEducation:\n- BS Computer Science, MIT (2016-2020)\n\nSkills: Python, TypeScript, AWS...", max_pages=1)`
|
||||
- WHY: Document content is available in context — extract ALL of it into user_info. Do NOT ignore referenced documents.
|
||||
- User: (after resume generated) "Change my title to Senior Engineer"
|
||||
- Call: `generate_resume(user_info="", user_instructions="Change the job title to Senior Engineer", parent_report_id=<previous_report_id>, max_pages=1)`
|
||||
- WHY: Modification verb "change" + refers to existing resume → set parent_report_id.
|
||||
- User: (after resume generated) "Make this 2 pages and expand projects"
|
||||
- Call: `generate_resume(user_info="", user_instructions="Expand projects and keep this to at most 2 pages", parent_report_id=<previous_report_id>, max_pages=2)`
|
||||
- WHY: Explicit page increase request → set max_pages to 2.
|
||||
- User: "How should I structure my resume?"
|
||||
- Do NOT call generate_resume. Answer in chat with advice.
|
||||
- WHY: No creation/modification verb.
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
|
||||
- User: "Give me a presentation about AI trends based on what we discussed"
|
||||
- First search for relevant content, then call: `generate_video_presentation(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", video_title="AI Trends Presentation")`
|
||||
- User: "Create slides summarizing this conversation"
|
||||
- Call: `generate_video_presentation(source_content="Complete conversation summary:\n\nUser asked about [topic 1]:\n[Your detailed response]\n\nUser then asked about [topic 2]:\n[Your detailed response]\n\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")`
|
||||
- User: "Make a video presentation about quantum computing"
|
||||
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\n\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")`
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
|
||||
- User: "Check out https://dev.to/some-article"
|
||||
- Call: `scrape_webpage(url="https://dev.to/some-article")`
|
||||
- Respond with a structured analysis — key points, takeaways.
|
||||
- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends"
|
||||
- Call: `scrape_webpage(url="https://example.com/blog/ai-trends")`
|
||||
- Respond with a thorough summary using headings and bullet points.
|
||||
- User: (after discussing https://example.com/stats) "Can you get the live data from that page?"
|
||||
- Call: `scrape_webpage(url="https://example.com/stats")`
|
||||
- IMPORTANT: Always attempt scraping first. Never refuse before trying the tool.
|
||||
- User: "https://example.com/blog/weekend-recipes"
|
||||
- Call: `scrape_webpage(url="https://example.com/blog/weekend-recipes")`
|
||||
- When a user sends just a URL with no instructions, scrape it and provide a concise summary of the content.
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
|
||||
- User: "How do I install SurfSense?"
|
||||
- Call: `search_surfsense_docs(query="installation setup")`
|
||||
- User: "What connectors does SurfSense support?"
|
||||
- Call: `search_surfsense_docs(query="available connectors integrations")`
|
||||
- User: "How do I set up the Notion connector?"
|
||||
- Call: `search_surfsense_docs(query="Notion connector setup configuration")`
|
||||
- User: "How do I use Docker to run SurfSense?"
|
||||
- Call: `search_surfsense_docs(query="Docker installation setup")`
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
|
||||
- <user_name>Alex</user_name>, <user_memory> is empty. User: "I'm a space enthusiast, explain astrophage to me"
|
||||
- The user casually shared a durable fact. Use their first name in the entry, short neutral heading:
|
||||
update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n")
|
||||
- User: "Remember that I prefer concise answers over detailed explanations"
|
||||
- Durable preference. Merge with existing memory, add a new heading:
|
||||
update_memory(updated_memory="## Interests & background\n- (2025-03-15) [fact] Alex is a space enthusiast\n\n## Response style\n- (2025-03-15) [pref] Alex prefers concise answers over detailed explanations\n")
|
||||
- User: "I actually moved to Tokyo last month"
|
||||
- Updated fact, date prefix reflects when recorded:
|
||||
update_memory(updated_memory="## Interests & background\n...\n\n## Personal context\n- (2025-03-15) [fact] Alex lives in Tokyo (previously London)\n...")
|
||||
- User: "I'm a freelance photographer working on a nature documentary"
|
||||
- Durable background info under a fitting heading:
|
||||
update_memory(updated_memory="...\n\n## Current focus\n- (2025-03-15) [fact] Alex is a freelance photographer\n- (2025-03-15) [fact] Alex is working on a nature documentary\n")
|
||||
- User: "Always respond in bullet points"
|
||||
- Standing instruction:
|
||||
update_memory(updated_memory="...\n\n## Response style\n- (2025-03-15) [instr] Always respond to Alex in bullet points\n")
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
|
||||
- User: "Let's remember that we decided to do weekly standup meetings on Mondays"
|
||||
- Durable team decision:
|
||||
update_memory(updated_memory="- (2025-03-15) [fact] Weekly standup meetings on Mondays\n...")
|
||||
- User: "Our office is in downtown Seattle, 5th floor"
|
||||
- Durable team fact:
|
||||
update_memory(updated_memory="- (2025-03-15) [fact] Office location: downtown Seattle, 5th floor\n...")
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
|
||||
- User: "What's the current USD to INR exchange rate?"
|
||||
- Call: `web_search(query="current USD to INR exchange rate")`
|
||||
- Then answer using the returned web results with citations.
|
||||
- User: "What's the latest news about AI?"
|
||||
- Call: `web_search(query="latest AI news today")`
|
||||
- User: "What's the weather in New York?"
|
||||
- Call: `web_search(query="weather New York today")`
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
<provider_hints>
|
||||
You are running on an Anthropic Claude model. Use XML tags liberally to structure
|
||||
intermediate reasoning when the task is complex. Prefer step-by-step plans inside
|
||||
`<thinking>` blocks before producing the final answer.
|
||||
</provider_hints>
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
<provider_hints>
|
||||
You are running on a Google Gemini model. Prefer concise, structured responses.
|
||||
When using tools, follow the function-calling protocol and avoid verbose preludes.
|
||||
</provider_hints>
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
<provider_hints>
|
||||
You are running on a classic OpenAI chat model (GPT-4 family). Use direct
|
||||
function-calling for tools. When editing files, use the standard `edit_file`
|
||||
or `write_file` tools rather than diff-based patches.
|
||||
</provider_hints>
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
<provider_hints>
|
||||
You are running on an OpenAI reasoning model (o-series / GPT-5+). Be terse and
|
||||
direct in your responses. When editing files, prefer the `apply_patch` tool format
|
||||
where available. Avoid restating the user request before answering.
|
||||
</provider_hints>
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
<tools>
|
||||
You have access to the following tools:
|
||||
|
||||
IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it.
|
||||
Do NOT claim you can do something if the corresponding tool is not listed.
|
||||
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
|
||||
- generate_image: Generate images from text descriptions using AI image models.
|
||||
- Use this when the user asks you to create, generate, draw, design, or make an image.
|
||||
- Trigger phrases: "generate an image of", "create a picture of", "draw me", "make an image", "design a logo", "create artwork"
|
||||
- Args:
|
||||
- prompt: A detailed text description of the image to generate. Be specific about subject, style, colors, composition, and mood.
|
||||
- n: Number of images to generate (1-4, default: 1)
|
||||
- Returns: A dictionary with the generated image metadata. The image will automatically be displayed in the chat.
|
||||
- IMPORTANT: Write a detailed, descriptive prompt for best results. Don't just pass the user's words verbatim -
|
||||
expand and improve the prompt with specific details about style, lighting, composition, and mood.
|
||||
- If the user's request is vague (e.g., "make me an image of a cat"), enhance the prompt with artistic details.
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
|
||||
- generate_podcast: Generate an audio podcast from provided content.
|
||||
- Use this when the user asks to create, generate, or make a podcast.
|
||||
- Trigger phrases: "give me a podcast about", "create a podcast", "generate a podcast", "make a podcast", "turn this into a podcast"
|
||||
- Args:
|
||||
- source_content: The text content to convert into a podcast. This MUST be comprehensive and include:
|
||||
* If discussing the current conversation: Include a detailed summary of the FULL chat history (all user questions and your responses)
|
||||
* If based on knowledge base search: Include the key findings and insights from the search results
|
||||
* You can combine both: conversation context + search results for richer podcasts
|
||||
* The more detailed the source_content, the better the podcast quality
|
||||
- podcast_title: Optional title for the podcast (default: "SurfSense Podcast")
|
||||
- user_prompt: Optional instructions for podcast style/format (e.g., "Make it casual and fun")
|
||||
- Returns: A task_id for tracking. The podcast will be generated in the background.
|
||||
- IMPORTANT: Only one podcast can be generated at a time. If a podcast is already being generated, the tool will return status "already_generating".
|
||||
- After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes).
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
|
||||
- generate_report: Generate or revise a structured Markdown report artifact.
|
||||
- WHEN TO CALL THIS TOOL — the message must contain a creation or modification VERB directed at producing a deliverable:
|
||||
* Creation verbs: write, create, generate, draft, produce, summarize into, turn into, make
|
||||
* Modification verbs: revise, update, expand, add (a section), rewrite, make (it shorter/longer/formal)
|
||||
* Example triggers: "generate a report about...", "write a document on...", "add a section about budget", "make the report shorter", "rewrite in formal tone"
|
||||
- WHEN NOT TO CALL THIS TOOL (answer in chat instead):
|
||||
* Questions or discussion about the report: "What can we add?", "What's missing?", "Is the data accurate?", "How could this be improved?"
|
||||
* Suggestions or brainstorming: "What other topics could be covered?", "What else could be added?", "What would make this better?"
|
||||
* Asking for explanations: "Can you explain section 2?", "Why did you include that?", "What does this part mean?"
|
||||
* Quick follow-ups or critiques: "Is the conclusion strong enough?", "Are there any gaps?", "What about the competitors?"
|
||||
* THE TEST: Does the message contain a creation/modification VERB (from the list above) directed at producing or changing a deliverable? If NO verb → answer conversationally in chat. Do NOT assume the user wants a revision just because a report exists in the conversation.
|
||||
- IMPORTANT FORMAT RULE: Reports are ALWAYS generated in Markdown.
|
||||
- Args:
|
||||
- topic: Short title for the report (max ~8 words).
|
||||
- source_content: The text content to base the report on.
|
||||
* For source_strategy="conversation" or "provided": Include a comprehensive summary of the relevant content.
|
||||
* For source_strategy="kb_search": Can be empty or minimal — the tool handles searching internally.
|
||||
* For source_strategy="auto": Include what you have; the tool searches KB if it's not enough.
|
||||
- source_strategy: Controls how the tool collects source material. One of:
|
||||
* "conversation" — The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content.
|
||||
* "kb_search" — The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries.
|
||||
* "auto" — Use source_content if sufficient, otherwise fall back to internal KB search using search_queries.
|
||||
* "provided" — Use only what is in source_content (default, backward-compatible).
|
||||
- search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated.
|
||||
- report_style: Controls report depth. Options: "detailed" (DEFAULT), "deep_research", "brief".
|
||||
Use "brief" ONLY when the user explicitly asks for a short/concise/one-page report (e.g., "one page", "keep it short", "brief report", "500 words"). Default to "detailed" for all other requests.
|
||||
- user_instructions: Optional specific instructions (e.g., "focus on financial impacts", "include recommendations"). When revising (parent_report_id set), describe WHAT TO CHANGE. If the user mentions a length preference (e.g., "one page", "500 words", "2 pages"), include that VERBATIM here AND set report_style="brief".
|
||||
- parent_report_id: Set this to the report_id from a previous generate_report result when the user wants to MODIFY an existing report. Do NOT set it for new reports or questions about reports.
|
||||
- Returns: A dictionary with status "ready" or "failed", report_id, title, and word_count.
|
||||
- The report is generated immediately in Markdown and displayed inline in the chat.
|
||||
- Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report.
|
||||
- SOURCE STRATEGY DECISION (HIGH PRIORITY — follow this exactly):
|
||||
* If the conversation already has substantive Q&A / discussion on the topic → use source_strategy="conversation" with a comprehensive summary as source_content.
|
||||
* If the user wants a report on a topic not yet discussed → use source_strategy="kb_search" with targeted search_queries.
|
||||
* If you have some content but might need more → use source_strategy="auto" with both source_content and search_queries.
|
||||
* When revising an existing report (parent_report_id set) and the conversation has relevant context → use source_strategy="conversation". The revision will use the previous report content plus your source_content.
|
||||
* NEVER run a separate KB lookup step and then pass those results to generate_report. The tool handles KB search internally.
|
||||
- AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat.
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
|
||||
- generate_resume: Generate or revise a professional resume as a Typst document.
|
||||
- WHEN TO CALL: The user asks to create, build, generate, write, or draft a resume or CV.
|
||||
Also when they ask to modify, update, or revise an existing resume from this conversation.
|
||||
- WHEN NOT TO CALL: General career advice, resume tips, cover letters, or reviewing
|
||||
a resume without making changes. For cover letters, use generate_report instead.
|
||||
- The tool produces Typst source code that is compiled to a PDF preview automatically.
|
||||
- PAGE POLICY:
|
||||
- Default behavior is ONE PAGE. For new resume creation, set max_pages=1 unless the user explicitly asks for more.
|
||||
- If the user requests a longer resume (e.g., "make it 2 pages"), set max_pages to that value.
|
||||
- Args:
|
||||
- user_info: The user's resume content — work experience, education, skills, contact
|
||||
info, etc. Can be structured or unstructured text.
|
||||
CRITICAL: user_info must be COMPREHENSIVE. Do NOT just pass the user's raw message.
|
||||
You MUST gather and consolidate ALL available information:
|
||||
* Content from referenced/mentioned documents (e.g., uploaded resumes, CVs, LinkedIn profiles)
|
||||
that appear in the conversation context — extract and include their FULL content.
|
||||
* Information the user shared across multiple messages in the conversation.
|
||||
* Any relevant details from knowledge base search results in the context.
|
||||
The more complete the user_info, the better the resume. Include names, contact info,
|
||||
work experience with dates, education, skills, projects, certifications — everything available.
|
||||
- user_instructions: Optional style or content preferences (e.g. "emphasize leadership",
|
||||
"keep it to one page"). For revisions, describe what to change.
|
||||
- parent_report_id: Set this when the user wants to MODIFY an existing resume from
|
||||
this conversation. Use the report_id from a previous generate_resume result.
|
||||
- max_pages: Maximum resume length in pages (integer 1-5). Default is 1.
|
||||
- Returns: Dict with status, report_id, title, and content_type.
|
||||
- After calling: Give a brief confirmation. Do NOT paste resume content in chat. Do NOT mention report_id or any internal IDs — the resume card is shown automatically.
|
||||
- VERSIONING: Same rules as generate_report — set parent_report_id for modifications
|
||||
of an existing resume, leave as None for new resumes.
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
|
||||
- generate_video_presentation: Generate a video presentation from provided content.
|
||||
- Use this when the user asks to create a video, presentation, slides, or slide deck.
|
||||
- Trigger phrases: "give me a presentation", "create slides", "generate a video", "make a slide deck", "turn this into a presentation"
|
||||
- Args:
|
||||
- source_content: The text content to turn into a presentation. The more detailed, the better.
|
||||
- video_title: Optional title (default: "SurfSense Presentation")
|
||||
- user_prompt: Optional style instructions (e.g., "Make it technical and detailed")
|
||||
- After calling this tool, inform the user that generation has started and they will see the presentation when it's ready.
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
|
||||
- scrape_webpage: Scrape and extract the main content from a webpage.
|
||||
- Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage.
|
||||
- CRITICAL — WHEN TO USE (always attempt scraping, never refuse before trying):
|
||||
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
|
||||
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
|
||||
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
|
||||
* When `/documents/` knowledge-base data is insufficient and the user wants more
|
||||
- Trigger scenarios:
|
||||
* "Read this article and summarize it"
|
||||
* "What does this page say about X?"
|
||||
* "Summarize this blog post for me"
|
||||
* "Tell me the key points from this article"
|
||||
* "What's in this webpage?"
|
||||
* "Can you analyze this article?"
|
||||
* "Can you get the live table/data from [URL]?"
|
||||
* "Scrape it" / "Can you scrape that?" (referring to a previously mentioned URL)
|
||||
* "Fetch the content from [URL]"
|
||||
* "Pull the data from that page"
|
||||
- Args:
|
||||
- url: The URL of the webpage to scrape (must be HTTP/HTTPS)
|
||||
- max_length: Maximum content length to return (default: 50000 chars)
|
||||
- Returns: The page title, description, full content (in markdown), word count, and metadata
|
||||
- After scraping, provide a comprehensive, well-structured summary with key takeaways using headings or bullet points.
|
||||
- Reference the source using markdown links [descriptive text](url) — never bare URLs.
|
||||
- IMAGES: The scraped content may contain image URLs in markdown format like ``.
|
||||
* When you find relevant/important images in the scraped content, include them in your response using standard markdown image syntax: ``.
|
||||
* This makes your response more visual and engaging.
|
||||
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
|
||||
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
|
||||
- search_surfsense_docs: Search the official SurfSense documentation.
|
||||
- Use this tool when the user asks anything about SurfSense itself (the application they are using).
|
||||
- Args:
|
||||
- query: The search query about SurfSense
|
||||
- top_k: Number of documentation chunks to retrieve (default: 10)
|
||||
- Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123])
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
- update_memory: Update your personal memory document about the user.
|
||||
- Your current memory is already in <user_memory> in your context. The `chars` and
|
||||
`limit` attributes show your current usage and the maximum allowed size.
|
||||
- This is your curated long-term memory — the distilled essence of what you know about
|
||||
the user, not raw conversation logs.
|
||||
- Call update_memory when:
|
||||
* The user explicitly asks to remember or forget something
|
||||
* The user shares durable facts or preferences that will matter in future conversations
|
||||
- The user's first name is provided in <user_name>. Use it in memory entries
|
||||
instead of "the user" (e.g. "{name} works at..." not "The user works at...").
|
||||
Do not store the name itself as a separate memory entry.
|
||||
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
||||
session logistics, or things that only matter for the current task.
|
||||
- Args:
|
||||
- updated_memory: The FULL updated markdown document (not a diff).
|
||||
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
||||
Treat every update as a curation pass — consolidate, don't just append.
|
||||
- Every bullet MUST use this format: - (YYYY-MM-DD) [marker] text
|
||||
Markers:
|
||||
[fact] — durable facts (role, background, projects, tools, expertise)
|
||||
[pref] — preferences (response style, languages, formats, tools)
|
||||
[instr] — standing instructions (always/never do, response rules)
|
||||
- Keep it concise and well under the character limit shown in <user_memory>.
|
||||
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
||||
natural. Do NOT include the user's name in headings. Organize by context — e.g.
|
||||
who they are, what they're focused on, how they prefer things. Create, split, or
|
||||
merge headings freely as the memory grows.
|
||||
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- During consolidation, prioritize keeping: [instr] > [pref] > [fact].
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
|
||||
- update_memory: Update the team's shared memory document for this search space.
|
||||
- Your current team memory is already in <team_memory> in your context. The `chars`
|
||||
and `limit` attributes show current usage and the maximum allowed size.
|
||||
- This is the team's curated long-term memory — decisions, conventions, key facts.
|
||||
- NEVER store personal memory in team memory (e.g. personal bio, individual
|
||||
preferences, or user-only standing instructions).
|
||||
- Call update_memory when:
|
||||
* A team member explicitly asks to remember or forget something
|
||||
* The conversation surfaces durable team decisions, conventions, or facts
|
||||
that will matter in future conversations
|
||||
- Do not store short-lived or ephemeral info: one-off questions, greetings,
|
||||
session logistics, or things that only matter for the current task.
|
||||
- Args:
|
||||
- updated_memory: The FULL updated markdown document (not a diff).
|
||||
Merge new facts with existing ones, update contradictions, remove outdated entries.
|
||||
Treat every update as a curation pass — consolidate, don't just append.
|
||||
- Every bullet MUST use this format: - (YYYY-MM-DD) [fact] text
|
||||
Team memory uses ONLY the [fact] marker. Never use [pref] or [instr] in team memory.
|
||||
- Keep it concise and well under the character limit shown in <team_memory>.
|
||||
- Every entry MUST be under a `##` heading. Keep heading names short (2-3 words) and
|
||||
natural. Organize by context — e.g. what the team decided, current architecture,
|
||||
active processes. Create, split, or merge headings freely as the memory grows.
|
||||
- Each entry MUST be a single bullet point. Be descriptive but concise — include relevant
|
||||
details and context rather than just a few words.
|
||||
- During consolidation, prioritize keeping: decisions/conventions > key facts > current priorities.
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
|
||||
- web_search: Search the web for real-time information using all configured search engines.
|
||||
- Use this for current events, news, prices, weather, public facts, or any question requiring
|
||||
up-to-date information from the internet.
|
||||
- This tool dispatches to all configured search engines (SearXNG, Tavily, Linkup, Baidu) in
|
||||
parallel and merges the results.
|
||||
- IMPORTANT (REAL-TIME / PUBLIC WEB QUERIES): For questions that require current public web data
|
||||
(e.g., live exchange rates, stock prices, breaking news, weather, current events), you MUST call
|
||||
`web_search` instead of answering from memory.
|
||||
- For these real-time/public web queries, DO NOT answer from memory and DO NOT say you lack internet
|
||||
access before attempting a web search.
|
||||
- If the search returns no relevant results, explain that web sources did not return enough
|
||||
data and ask the user if they want you to retry with a refined query.
|
||||
- Args:
|
||||
- query: The search query - use specific, descriptive terms
|
||||
- top_k: Number of results to retrieve (default: 10, max: 50)
|
||||
- If search snippets are insufficient for the user's question, use `scrape_webpage` on the most relevant result URL for full content.
|
||||
- When presenting results, reference sources as markdown links [descriptive text](url) — never bare URLs.
|
||||
7
surfsense_backend/app/agents/new_chat/skills/__init__.py
Normal file
7
surfsense_backend/app/agents/new_chat/skills/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""SurfSense built-in agent skills (Anthropic Skills format).
|
||||
|
||||
Each subdirectory corresponds to one skill and contains a ``SKILL.md`` file
|
||||
with YAML frontmatter (name, description, allowed_tools) plus markdown
|
||||
instructions. The :class:`BuiltinSkillsBackend` exposes them to the
|
||||
deepagents :class:`SkillsMiddleware`.
|
||||
"""
|
||||
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
---
|
||||
name: email-drafting
|
||||
description: Draft an email matching the user's voice, with structured intent and CTA
|
||||
allowed-tools: search_surfsense_docs
|
||||
---
|
||||
|
||||
# Email drafting
|
||||
|
||||
## When to use this skill
|
||||
"Draft an email to ...", "reply to this thread", "write a follow-up to X". Plain "summarize the email" is **not** in scope — that's a comprehension task.
|
||||
|
||||
## Voice
|
||||
Search the KB for prior emails from the user to similar audiences (same recipient, same topic class). Mirror tone, opening style, sign-off, and length distribution. If there is no precedent, default to: warm, direct, no filler, short paragraphs, one clear ask.
|
||||
|
||||
## Required structure
|
||||
Every draft includes, in this order:
|
||||
|
||||
1. **Subject line** — concrete, ≤ 8 words, no clickbait, no `Re:` unless replying.
|
||||
2. **Opening (1 sentence)** — context the recipient already shares; never restate what they wrote unless the thread is long.
|
||||
3. **Body** — the actual point in one short paragraph. Bullets only if there are >3 discrete items.
|
||||
4. **Single explicit CTA** — what you want the recipient to do, with a soft deadline if relevant.
|
||||
5. **Sign-off** — match the user's prior closing style.
|
||||
|
||||
## Always offer alternatives
|
||||
End your message with: "Want me to make it shorter, more formal, or add a different angle?" — give the user one obvious next step.
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
---
|
||||
name: kb-research
|
||||
description: Structured approach to finding and synthesizing information from the user's knowledge base
|
||||
allowed-tools: search_surfsense_docs, scrape_webpage, read_file, ls_tree, grep, web_search
|
||||
---
|
||||
|
||||
# Knowledge-base research
|
||||
|
||||
## When to use this skill
|
||||
- The user asks "find/look up/research" something specifically inside their knowledge base.
|
||||
- The user references documents, notes, repos, or connector data they expect to exist already.
|
||||
- A multi-document synthesis is required (e.g., "summarize what we've discussed about X across all my notes").
|
||||
|
||||
## Plan
|
||||
1. Decompose the user's question into 2-4 specific, citation-worthy sub-questions.
|
||||
2. For each sub-question, run **one** targeted KB search (focused on terms the user would have written, not synonyms). Open the most relevant 2-3 documents fully via `read_file` if their excerpts are too short.
|
||||
3. Use `grep` to find supporting passages in long files instead of re-reading them end to end.
|
||||
4. Cite every claim with `[citation:chunk_id]` exactly as the chunk tag specifies.
|
||||
|
||||
## What good output looks like
|
||||
- Short paragraphs with inline citations.
|
||||
- Quoted phrases when wording matters.
|
||||
- An explicit "Not found in your knowledge base" callout when a sub-question has no support — never fabricate.
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
---
|
||||
name: meeting-prep
|
||||
description: Pull together briefing materials before a scheduled meeting
|
||||
allowed-tools: search_surfsense_docs, web_search, scrape_webpage, read_file
|
||||
---
|
||||
|
||||
# Meeting preparation
|
||||
|
||||
## When to use this skill
|
||||
The user mentions an upcoming meeting, call, or interview and asks you to "prep", "brief me", "pull background", or "what do I need to know about X before tomorrow".
|
||||
|
||||
## Output structure
|
||||
Always produce these sections (omit any with no signal — don't pad):
|
||||
|
||||
1. **Attendees & context** — who's in the room, their roles, what they care about. Pull from KB notes about prior interactions; supplement with public profile facts via `web_search` when names or companies are unfamiliar.
|
||||
2. **Open threads** — outstanding action items, unresolved decisions, last-mentioned blockers from prior conversation history.
|
||||
3. **Recent moves** — within the last 30 days: relevant launches, hires, news. Cite KB chunks when present, otherwise external sources.
|
||||
4. **Suggested questions** — 3-5 questions the user could ask, tailored to the open threads and the attendees' likely priorities.
|
||||
|
||||
## Source ordering
|
||||
- Always check the user's KB **first** for prior meeting notes, internal docs, or Slack threads about these attendees.
|
||||
- Only fall back to `web_search` for *publicly verifiable* facts — never to fabricate a participant's preferences or relationships.
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
---
|
||||
name: report-writing
|
||||
description: How to scope, draft, and revise a Markdown report artifact via generate_report
|
||||
allowed-tools: generate_report, search_surfsense_docs, read_file
|
||||
---
|
||||
|
||||
# Report writing
|
||||
|
||||
## When to use this skill
|
||||
The user explicitly requests a deliverable: "write a report on …", "draft a memo", "produce a brief", "expand the previous report". A creation or modification verb pointed at an artifact is required (see `generate_report`'s when-to-call rules).
|
||||
|
||||
## Decision flow
|
||||
1. **Source strategy.** Decide which `source_strategy` fits:
|
||||
- `conversation` — substantive Q&A on the topic already in chat.
|
||||
- `kb_search` — fresh topic; supply 1–5 precise `search_queries`.
|
||||
- `auto` — partial conversation context; let the tool fall back.
|
||||
- `provided` — verbatim source text only.
|
||||
2. **Style.** Default to `report_style="detailed"` unless the user explicitly asks for "brief", "one page", "500 words".
|
||||
3. **Revisions.** When modifying an existing report from this conversation, set `parent_report_id` and put the change list in `user_instructions` ("add carbon-capture section", "tighten conclusion").
|
||||
4. **Never paste the report back into chat** after `generate_report` returns — confirm and let the artifact card render itself.
|
||||
|
||||
## Hooks for KB-only mode
|
||||
If `kb_search`/`auto` returns no results, do **not** silently switch to general knowledge. Surface the gap in your confirmation message.
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
---
|
||||
name: slack-summary
|
||||
description: Distill a Slack channel or thread into actionable summary
|
||||
allowed-tools: search_surfsense_docs
|
||||
---
|
||||
|
||||
# Slack summarization
|
||||
|
||||
## When to use this skill
|
||||
The user asks to summarize Slack ("what happened in #eng-platform this week", "what did Alice say about the launch", "catch me up on the design channel").
|
||||
|
||||
## Required inputs
|
||||
Confirm before searching:
|
||||
- **Which channel(s) or thread(s)?** Don't guess if ambiguous.
|
||||
- **What time window?** Default to the last 7 days when not specified, but say so.
|
||||
|
||||
## Output shape
|
||||
Produce three concise sections:
|
||||
1. **Key decisions** — explicit choices that were made, with the deciding message cited.
|
||||
2. **Open questions** — things asked but not answered, with the asking message cited.
|
||||
3. **Action items** — `@mention` who owes what by when, *only if explicitly stated*. Don't invent assignees.
|
||||
|
||||
## What not to do
|
||||
- Never produce a chronological play-by-play of every message — distill.
|
||||
- Never quote private messages without flagging them as such.
|
||||
- If the channel was empty in the time window, say so — don't fabricate filler.
|
||||
26
surfsense_backend/app/agents/new_chat/subagents/__init__.py
Normal file
26
surfsense_backend/app/agents/new_chat/subagents/__init__.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""Specialized user-facing subagents for the SurfSense agent.
|
||||
|
||||
Each subagent is a :class:`deepagents.SubAgent` typed-dict spec passed to
|
||||
:class:`deepagents.SubAgentMiddleware`, which materializes them as ephemeral
|
||||
runnables invoked via the ``task`` tool.
|
||||
|
||||
Per-subagent permission rules are injected as a
|
||||
:class:`PermissionMiddleware` entry inside the subagent's ``middleware``
|
||||
field, mirroring opencode ``tool/task.ts`` which seeds child sessions with
|
||||
deny rules for tools the parent does not want them touching (e.g.
|
||||
``task``/``todowrite`` recursion, write tools for read-only research roles).
|
||||
"""
|
||||
|
||||
from .config import (
|
||||
build_connector_negotiator_subagent,
|
||||
build_explore_subagent,
|
||||
build_report_writer_subagent,
|
||||
build_specialized_subagents,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"build_connector_negotiator_subagent",
|
||||
"build_explore_subagent",
|
||||
"build_report_writer_subagent",
|
||||
"build_specialized_subagents",
|
||||
]
|
||||
427
surfsense_backend/app/agents/new_chat/subagents/config.py
Normal file
427
surfsense_backend/app/agents/new_chat/subagents/config.py
Normal file
|
|
@ -0,0 +1,427 @@
|
|||
"""Builders for specialized SurfSense subagents.
|
||||
|
||||
Each subagent is built from three pieces:
|
||||
|
||||
1. A name + description + system prompt (the user-facing contract for
|
||||
when ``task`` should delegate to this role).
|
||||
2. A filtered tool list (subset of the parent's bound tools).
|
||||
3. A :class:`PermissionMiddleware` instance carrying a deny ruleset that
|
||||
prevents the subagent from acting outside its scope (e.g. an
|
||||
explore-only role cannot mutate state).
|
||||
|
||||
Skill sources (``/skills/builtin/`` + ``/skills/space/``) are inherited
|
||||
from the parent unconditionally — every subagent benefits from the same
|
||||
authored guidance documents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.agents.new_chat.middleware.skills_backends import default_skills_sources
|
||||
from app.agents.new_chat.permissions import Rule, Ruleset
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deepagents import SubAgent
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool name constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Read-only tools that ``explore`` is permitted to use. Names match the
|
||||
# tools provided by the deepagents ``FilesystemMiddleware`` (``ls``, ``read_file``,
|
||||
# ``glob``, ``grep``) plus the SurfSense-side read tools.
|
||||
EXPLORE_READ_TOOLS: frozenset[str] = frozenset(
|
||||
{
|
||||
"search_surfsense_docs",
|
||||
"web_search",
|
||||
"scrape_webpage",
|
||||
"read_file",
|
||||
"ls",
|
||||
"glob",
|
||||
"grep",
|
||||
}
|
||||
)
|
||||
|
||||
# Tools ``report_writer`` may call. The set is intentionally narrow so the
|
||||
# subagent doesn't drift into tangential research; if richer source-gathering
|
||||
# is needed, the parent should hand off to ``explore`` first.
|
||||
REPORT_WRITER_TOOLS: frozenset[str] = frozenset(
|
||||
{
|
||||
"search_surfsense_docs",
|
||||
"read_file",
|
||||
"generate_report",
|
||||
}
|
||||
)
|
||||
|
||||
# Wildcard patterns that match write tools we deny by default in read-only
|
||||
# subagents. Anchored at start AND end via :func:`Rule` semantics. We use
|
||||
# substring-style ``*verb*`` patterns because connector tool names typically
|
||||
# put the verb in the middle (``linear_create_issue``, ``slack_send_message``,
|
||||
# ``notion_update_page``); strict suffix patterns (``*_create``) miss those.
|
||||
#
|
||||
# A handful of canonical exact-match names is appended so that bare verbs
|
||||
# (``edit``, ``write``) are also blocked even when a connector dropped the
|
||||
# usual prefix.
|
||||
WRITE_TOOL_DENY_PATTERNS: tuple[str, ...] = (
|
||||
"*create*",
|
||||
"*update*",
|
||||
"*delete*",
|
||||
"*send*",
|
||||
"*write*",
|
||||
"*edit*",
|
||||
"*move*",
|
||||
"*mkdir*",
|
||||
"*upload*",
|
||||
"edit_file",
|
||||
"write_file",
|
||||
"move_file",
|
||||
"mkdir",
|
||||
"update_memory",
|
||||
"update_memory_team",
|
||||
"update_memory_private",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Tool names that are NOT in the registry's ``tools`` list because they
|
||||
# are provided dynamically by middleware at compile time. We don't pass
|
||||
# them through ``_filter_tools`` (the actual ``BaseTool`` instances live
|
||||
# inside the middleware), but we do exempt them from the "missing" warning
|
||||
# below — operators were seeing spurious noise like
|
||||
# ``missing: ['glob', 'grep', 'ls', 'read_file']`` even though those
|
||||
# tools are reachable via :class:`SurfSenseFilesystemMiddleware` once the
|
||||
# subagent is compiled.
|
||||
_MIDDLEWARE_PROVIDED_TOOL_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
"ls",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"glob",
|
||||
"grep",
|
||||
"execute",
|
||||
"write_todos",
|
||||
"task",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _filter_tools(
|
||||
tools: Sequence[BaseTool],
|
||||
allowed_names: Iterable[str],
|
||||
) -> list[BaseTool]:
|
||||
"""Return only tools whose ``name`` appears in ``allowed_names``.
|
||||
|
||||
Tools are looked up by exact name. Names matching
|
||||
:data:`_MIDDLEWARE_PROVIDED_TOOL_NAMES` are intentionally absent from
|
||||
``tools`` (they're injected by middleware at compile time) and are
|
||||
silently excluded from the "missing" warning so operators don't see
|
||||
false positives every build.
|
||||
"""
|
||||
allowed = set(allowed_names)
|
||||
selected = [t for t in tools if t.name in allowed]
|
||||
missing = sorted(
|
||||
(allowed - {t.name for t in selected}) - _MIDDLEWARE_PROVIDED_TOOL_NAMES
|
||||
)
|
||||
if missing:
|
||||
logger.info(
|
||||
"Subagent build: %d/%d registry tools available; missing: %s",
|
||||
len(selected),
|
||||
len(allowed - _MIDDLEWARE_PROVIDED_TOOL_NAMES),
|
||||
missing,
|
||||
)
|
||||
return selected
|
||||
|
||||
|
||||
def _read_only_deny_rules() -> list[Rule]:
|
||||
"""Synthesize a list of deny rules covering common write-tool patterns."""
|
||||
return [
|
||||
Rule(permission=pattern, pattern="*", action="deny")
|
||||
for pattern in WRITE_TOOL_DENY_PATTERNS
|
||||
]
|
||||
|
||||
|
||||
def _build_permission_middleware(deny_rules: list[Rule], origin: str):
|
||||
"""Construct a :class:`PermissionMiddleware` seeded with ``deny_rules``.
|
||||
|
||||
Imported lazily because the middleware module pulls in interrupt/HITL
|
||||
machinery we don't want at import time of this config file.
|
||||
"""
|
||||
from app.agents.new_chat.middleware.permission import PermissionMiddleware
|
||||
|
||||
return PermissionMiddleware(
|
||||
rulesets=[Ruleset(rules=deny_rules, origin=origin)],
|
||||
)
|
||||
|
||||
|
||||
def _wrap_with_subagent_essentials(
|
||||
custom_middleware: list,
|
||||
*,
|
||||
agent_tools: Sequence[BaseTool],
|
||||
extra_middleware: Sequence[Any] | None = None,
|
||||
):
|
||||
"""Compose the final middleware list for a specialized subagent.
|
||||
|
||||
Order, outer to inner:
|
||||
|
||||
1. ``extra_middleware`` — provided by the caller (typically the parent
|
||||
agent's ``SurfSenseFilesystemMiddleware`` and ``TodoListMiddleware``)
|
||||
so the subagent inherits the parent's filesystem/todo view. These
|
||||
run **before** the subagent-local middleware so their tools are
|
||||
wired up before permissioning kicks in.
|
||||
2. ``custom_middleware`` — subagent-local rules (e.g. permission deny
|
||||
lists).
|
||||
3. :class:`PatchToolCallsMiddleware` — normalizes tool-call shapes.
|
||||
4. :class:`DedupHITLToolCallsMiddleware` — collapses duplicate HITL
|
||||
calls using metadata declared at registry time.
|
||||
|
||||
Without ``extra_middleware`` the subagent will only have the registry
|
||||
tools listed in its ``tools`` field — meaning ``read_file``, ``ls``,
|
||||
``grep``, etc. won't exist. Always pass ``extra_middleware`` from the
|
||||
parent unless you specifically want a sandboxed subagent.
|
||||
"""
|
||||
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
|
||||
from app.agents.new_chat.middleware import DedupHITLToolCallsMiddleware
|
||||
|
||||
return [
|
||||
*(extra_middleware or []),
|
||||
*custom_middleware,
|
||||
PatchToolCallsMiddleware(),
|
||||
DedupHITLToolCallsMiddleware(agent_tools=list(agent_tools)),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
EXPLORE_SYSTEM_PROMPT = """You are the **explore** subagent for SurfSense.
|
||||
|
||||
## Your job
|
||||
Conduct read-only research across the user's knowledge base, the web, and any documents the parent agent has surfaced. Return a synthesized answer with explicit citations — never speculate beyond the sources you have actually inspected.
|
||||
|
||||
## Tools available
|
||||
- `search_surfsense_docs` — fast hybrid search over the user's knowledge base.
|
||||
- `web_search` — only when the user's KB clearly does not contain the answer.
|
||||
- `scrape_webpage` — to read a URL the user or the search results provided.
|
||||
- `read_file`, `ls`, `glob`, `grep` — to inspect specific documents or trees the parent has flagged.
|
||||
|
||||
## Rules
|
||||
- Read-only. You cannot create, edit, delete, send, or move anything.
|
||||
- Cite every claim. Use `[citation:chunk_id]` exactly as the chunk tag specifies.
|
||||
- If a sub-question has no support in the inspected sources, say so explicitly. Do not fabricate.
|
||||
- Return the most useful synthesis in your single final message. The parent agent will not be able to follow up.
|
||||
"""
|
||||
|
||||
|
||||
REPORT_WRITER_SYSTEM_PROMPT = """You are the **report_writer** subagent for SurfSense.
|
||||
|
||||
## Your job
|
||||
Produce a single high-quality report deliverable using `generate_report`. The parent has already gathered (or knows where to gather) the underlying sources.
|
||||
|
||||
## Workflow
|
||||
1. **Outline first.** Before calling `generate_report`, write a one-paragraph outline of the sections you plan to produce. Confirm the outline reflects the parent's instructions.
|
||||
2. **Source resolution.** Decide whether to call `search_surfsense_docs` and `read_file` for any final-checks, or whether the parent's earlier tool calls already cover the source set.
|
||||
3. **One report.** Call `generate_report` exactly once with `source_strategy` chosen per the topic and chat history (see the `report-writing` skill).
|
||||
4. **Confirm.** End with a one-sentence summary in your final message — never paste the report back into chat; the artifact card renders itself.
|
||||
"""
|
||||
|
||||
|
||||
CONNECTOR_NEGOTIATOR_SYSTEM_PROMPT = """You are the **connector_negotiator** subagent for SurfSense.
|
||||
|
||||
## Your job
|
||||
Coordinate cross-connector workflows: chains where the result of one service's tool feeds into another's. Common shapes include "find Linear issues mentioned in last week's Slack messages", "draft a Gmail reply citing a Notion doc", or "list Linear tickets opened by the same person who filed Jira FOO-123".
|
||||
|
||||
## Workflow
|
||||
1. **Plan.** Identify the connector hops needed and the order they should run in. Write a short plan in your first message.
|
||||
2. **Verify access.** Use `get_connected_accounts` to confirm the relevant connectors are actually wired up before issuing tool calls. If a connector is missing, stop and report — do not fabricate.
|
||||
3. **Execute.** Run each hop, citing IDs (issue keys, message ts, page IDs) in your scratch notes so the parent can audit.
|
||||
4. **Hand back.** Return a structured summary with the final answer plus the chain of evidence (issue → message → page, etc.).
|
||||
|
||||
## Caveats
|
||||
- If a hop fails, do not retry blindly — return the partial result and explain.
|
||||
- Mutating tools (create, update, delete, send) require parent permission; you are NOT cleared to call them on your own.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subagent builders
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_explore_subagent(
|
||||
*,
|
||||
tools: Sequence[BaseTool],
|
||||
model: BaseChatModel | None = None,
|
||||
extra_middleware: Sequence[Any] | None = None,
|
||||
) -> SubAgent:
|
||||
"""Build the read-only ``explore`` subagent spec.
|
||||
|
||||
Pass ``extra_middleware`` (typically the parent's filesystem + todo
|
||||
middleware) so the subagent can actually use ``read_file``, ``ls``,
|
||||
``grep``, ``glob`` — which its system prompt promises but which only
|
||||
exist when their middleware is mounted.
|
||||
"""
|
||||
from deepagents import SubAgent # noqa: F401 (TypedDict for type clarity)
|
||||
|
||||
selected_tools = _filter_tools(tools, EXPLORE_READ_TOOLS)
|
||||
deny_rules = _read_only_deny_rules()
|
||||
permission_mw = _build_permission_middleware(
|
||||
deny_rules, origin="subagent_explore"
|
||||
)
|
||||
|
||||
spec: dict = {
|
||||
"name": "explore",
|
||||
"description": (
|
||||
"Read-only research across the user's knowledge base and the web. "
|
||||
"Use when the parent needs deeply-cited synthesis without "
|
||||
"modifying anything."
|
||||
),
|
||||
"system_prompt": EXPLORE_SYSTEM_PROMPT,
|
||||
"tools": selected_tools,
|
||||
"middleware": _wrap_with_subagent_essentials(
|
||||
[permission_mw],
|
||||
agent_tools=selected_tools,
|
||||
extra_middleware=extra_middleware,
|
||||
),
|
||||
"skills": default_skills_sources(),
|
||||
}
|
||||
if model is not None:
|
||||
spec["model"] = model
|
||||
return spec # type: ignore[return-value]
|
||||
|
||||
|
||||
def build_report_writer_subagent(
|
||||
*,
|
||||
tools: Sequence[BaseTool],
|
||||
model: BaseChatModel | None = None,
|
||||
extra_middleware: Sequence[Any] | None = None,
|
||||
) -> SubAgent:
|
||||
"""Build the ``report_writer`` subagent spec.
|
||||
|
||||
Read-only deny ruleset still applies — the subagent should call
|
||||
``generate_report`` and nothing else mutating. ``generate_report``
|
||||
creates a report artifact via a backend service and is intentionally
|
||||
**not** denied.
|
||||
|
||||
Pass ``extra_middleware`` (typically the parent's filesystem + todo
|
||||
middleware) so the subagent can run ``read_file`` for source-checks
|
||||
before calling ``generate_report``.
|
||||
"""
|
||||
selected_tools = _filter_tools(tools, REPORT_WRITER_TOOLS)
|
||||
deny_rules = _read_only_deny_rules()
|
||||
permission_mw = _build_permission_middleware(
|
||||
deny_rules, origin="subagent_report_writer"
|
||||
)
|
||||
|
||||
spec: dict = {
|
||||
"name": "report_writer",
|
||||
"description": (
|
||||
"Produce a single Markdown report artifact via generate_report, "
|
||||
"using the outline-then-fill protocol. Use when the parent has "
|
||||
"decided a deliverable is needed."
|
||||
),
|
||||
"system_prompt": REPORT_WRITER_SYSTEM_PROMPT,
|
||||
"tools": selected_tools,
|
||||
"middleware": _wrap_with_subagent_essentials(
|
||||
[permission_mw],
|
||||
agent_tools=selected_tools,
|
||||
extra_middleware=extra_middleware,
|
||||
),
|
||||
"skills": default_skills_sources(),
|
||||
}
|
||||
if model is not None:
|
||||
spec["model"] = model
|
||||
return spec # type: ignore[return-value]
|
||||
|
||||
|
||||
def build_connector_negotiator_subagent(
|
||||
*,
|
||||
tools: Sequence[BaseTool],
|
||||
model: BaseChatModel | None = None,
|
||||
extra_middleware: Sequence[Any] | None = None,
|
||||
) -> SubAgent:
|
||||
"""Build the ``connector_negotiator`` subagent spec.
|
||||
|
||||
Inherits all MCP / connector tools the parent has plus
|
||||
``get_connected_accounts``. Read-only by default; permission rules deny
|
||||
write/mutation patterns. The parent agent re-asks for permission if a
|
||||
connector mutation is genuinely needed.
|
||||
|
||||
Pass ``extra_middleware`` (typically the parent's filesystem + todo
|
||||
middleware) so this subagent shares the parent's filesystem view when
|
||||
citing evidence across hops.
|
||||
"""
|
||||
parent_tool_names = {t.name for t in tools}
|
||||
allowed: set[str] = set()
|
||||
if "get_connected_accounts" in parent_tool_names:
|
||||
allowed.add("get_connected_accounts")
|
||||
# Inherit anything that smells connector- or MCP-related but is not a
|
||||
# bulk-write API. Heuristic: keep all parent tools; rely on the deny
|
||||
# ruleset to block mutation patterns. This mirrors the plan: "all
|
||||
# MCP/connector tools the parent has".
|
||||
for name in parent_tool_names:
|
||||
allowed.add(name)
|
||||
selected_tools = _filter_tools(tools, allowed)
|
||||
|
||||
deny_rules = _read_only_deny_rules()
|
||||
permission_mw = _build_permission_middleware(
|
||||
deny_rules, origin="subagent_connector_negotiator"
|
||||
)
|
||||
|
||||
spec: dict = {
|
||||
"name": "connector_negotiator",
|
||||
"description": (
|
||||
"Coordinate read-only chains across connectors (Slack → Linear, "
|
||||
"Notion → Gmail, etc.). Returns a structured summary with the "
|
||||
"evidence chain. Cannot mutate connector state."
|
||||
),
|
||||
"system_prompt": CONNECTOR_NEGOTIATOR_SYSTEM_PROMPT,
|
||||
"tools": selected_tools,
|
||||
"middleware": _wrap_with_subagent_essentials(
|
||||
[permission_mw],
|
||||
agent_tools=selected_tools,
|
||||
extra_middleware=extra_middleware,
|
||||
),
|
||||
"skills": default_skills_sources(),
|
||||
}
|
||||
if model is not None:
|
||||
spec["model"] = model
|
||||
return spec # type: ignore[return-value]
|
||||
|
||||
|
||||
def build_specialized_subagents(
|
||||
*,
|
||||
tools: Sequence[BaseTool],
|
||||
model: BaseChatModel | None = None,
|
||||
extra_middleware: Sequence[Any] | None = None,
|
||||
) -> list[SubAgent]:
|
||||
"""Return the canonical list of specialized subagents to register.
|
||||
|
||||
Order matters only for the order they appear in the ``task`` tool
|
||||
description — most useful first.
|
||||
"""
|
||||
return [
|
||||
build_explore_subagent(
|
||||
tools=tools, model=model, extra_middleware=extra_middleware
|
||||
),
|
||||
build_report_writer_subagent(
|
||||
tools=tools, model=model, extra_middleware=extra_middleware
|
||||
),
|
||||
build_connector_negotiator_subagent(
|
||||
tools=tools, model=model, extra_middleware=extra_middleware
|
||||
),
|
||||
]
|
||||
File diff suppressed because it is too large
Load diff
52
surfsense_backend/app/agents/new_chat/tools/invalid_tool.py
Normal file
52
surfsense_backend/app/agents/new_chat/tools/invalid_tool.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"""
|
||||
The ``invalid`` fallback tool.
|
||||
|
||||
When the model emits a tool call whose name doesn't match any registered
|
||||
tool, :class:`ToolCallNameRepairMiddleware` rewrites the call to ``invalid``
|
||||
with the original name and a parser/validation error string. This tool's
|
||||
execution then returns that error to the model so it can self-correct.
|
||||
|
||||
Mirrors ``opencode/packages/opencode/src/tool/invalid.ts``. Tier 1.6 in
|
||||
the OpenCode-port plan.
|
||||
|
||||
Critically, the :class:`ToolDefinition` for this tool is **excluded** from
|
||||
the system-prompt tool list and from ``LLMToolSelectorMiddleware`` selection
|
||||
(see ``ToolDefinition.always_include`` filtering in the registry) — the
|
||||
model never advertises ``invalid`` as a callable. It only ever shows up
|
||||
in the tool registry so LangGraph can dispatch the rewritten call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
INVALID_TOOL_NAME = "invalid"
|
||||
INVALID_TOOL_DESCRIPTION = "Do not use"
|
||||
|
||||
|
||||
def _format_invalid_message(tool: str | None, error: str | None) -> str:
|
||||
"""Return the user-visible error string. Mirrors ``invalid.ts``."""
|
||||
name = tool or "<unknown>"
|
||||
detail = error or "(no error message provided)"
|
||||
return (
|
||||
f"The arguments provided to the tool `{name}` are invalid: {detail}\n"
|
||||
f"Read the tool's docstring carefully and try again with valid arguments."
|
||||
)
|
||||
|
||||
|
||||
@tool(name_or_callable=INVALID_TOOL_NAME, description=INVALID_TOOL_DESCRIPTION)
|
||||
def invalid_tool(tool: str | None = None, error: str | None = None) -> str:
|
||||
"""Return a human-readable explanation of a tool-call validation failure.
|
||||
|
||||
Activated only when :class:`ToolCallNameRepairMiddleware` rewrites a
|
||||
failed tool call to ``invalid`` with the original tool name and the
|
||||
error message produced during validation.
|
||||
"""
|
||||
return _format_invalid_message(tool, error)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"INVALID_TOOL_DESCRIPTION",
|
||||
"INVALID_TOOL_NAME",
|
||||
"invalid_tool",
|
||||
]
|
||||
|
|
@ -43,6 +43,9 @@ from typing import Any
|
|||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||
wrap_dedup_key_by_arg_name,
|
||||
)
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .confluence import (
|
||||
|
|
@ -125,6 +128,14 @@ class ToolDefinition:
|
|||
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
||||
required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``)
|
||||
that must be in ``available_connectors`` for the tool to be enabled.
|
||||
dedup_key: Optional callable that maps a tool's ``args`` dict to a
|
||||
string signature used by :class:`DedupHITLToolCallsMiddleware`
|
||||
to drop duplicate calls. Replaces the legacy hardcoded
|
||||
``_NATIVE_HITL_TOOL_DEDUP_KEYS`` map (Tier 2.3 in the
|
||||
OpenCode-port plan).
|
||||
reverse: Optional callable that, given the tool's ``(args, result)``,
|
||||
returns a ``ReverseDescriptor`` describing the inverse tool
|
||||
invocation. Consumed by the snapshot/revert pipeline (Tier 5).
|
||||
|
||||
"""
|
||||
|
||||
|
|
@ -135,6 +146,8 @@ class ToolDefinition:
|
|||
enabled_by_default: bool = True
|
||||
hidden: bool = False
|
||||
required_connector: str | None = None
|
||||
dedup_key: Callable[[dict[str, Any]], str] | None = None
|
||||
reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -288,6 +301,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("title"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_notion_page",
|
||||
|
|
@ -299,6 +313,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("page_title"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_notion_page",
|
||||
|
|
@ -310,6 +325,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("page_title"),
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE DRIVE TOOLS - create files, delete files
|
||||
|
|
@ -325,6 +341,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_DRIVE_FILE",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("file_name"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_google_drive_file",
|
||||
|
|
@ -336,6 +353,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_DRIVE_FILE",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("file_name"),
|
||||
),
|
||||
# =========================================================================
|
||||
# DROPBOX TOOLS - create and trash files
|
||||
|
|
@ -351,6 +369,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DROPBOX_FILE",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("file_name"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_dropbox_file",
|
||||
|
|
@ -362,6 +381,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DROPBOX_FILE",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("file_name"),
|
||||
),
|
||||
# =========================================================================
|
||||
# ONEDRIVE TOOLS - create and trash files
|
||||
|
|
@ -377,6 +397,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="ONEDRIVE_FILE",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("file_name"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_onedrive_file",
|
||||
|
|
@ -388,6 +409,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="ONEDRIVE_FILE",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("file_name"),
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE CALENDAR TOOLS - search, create, update, delete events
|
||||
|
|
@ -414,6 +436,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("title"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_calendar_event",
|
||||
|
|
@ -425,6 +448,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_calendar_event",
|
||||
|
|
@ -436,6 +460,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"),
|
||||
),
|
||||
# =========================================================================
|
||||
# GMAIL TOOLS - search, read, create drafts, update drafts, send, trash
|
||||
|
|
@ -473,6 +498,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("subject"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_gmail_email",
|
||||
|
|
@ -484,6 +510,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("subject"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="trash_gmail_email",
|
||||
|
|
@ -495,6 +522,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("email_subject_or_id"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_gmail_draft",
|
||||
|
|
@ -506,6 +534,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("draft_subject_or_id"),
|
||||
),
|
||||
# =========================================================================
|
||||
# CONFLUENCE TOOLS - create, update, delete pages
|
||||
|
|
@ -521,6 +550,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("title"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_confluence_page",
|
||||
|
|
@ -532,6 +562,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"),
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_confluence_page",
|
||||
|
|
@ -543,6 +574,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"),
|
||||
),
|
||||
# =========================================================================
|
||||
# DISCORD TOOLS - list channels, read messages, send messages
|
||||
|
|
@ -755,6 +787,24 @@ def build_tools(
|
|||
|
||||
# Create the tool
|
||||
tool = tool_def.factory(dependencies)
|
||||
# Propagate the registry-level metadata so middleware (e.g.
|
||||
# ``DedupHITLToolCallsMiddleware``) and the action-log/revert
|
||||
# pipeline can pick the resolvers up via ``tool.metadata`` without
|
||||
# re-importing :data:`BUILTIN_TOOLS`.
|
||||
if tool_def.dedup_key is not None or tool_def.reverse is not None:
|
||||
existing_meta = getattr(tool, "metadata", None) or {}
|
||||
merged_meta = dict(existing_meta)
|
||||
if tool_def.dedup_key is not None:
|
||||
merged_meta.setdefault("dedup_key", tool_def.dedup_key)
|
||||
if tool_def.reverse is not None:
|
||||
merged_meta.setdefault("reverse", tool_def.reverse)
|
||||
try:
|
||||
tool.metadata = merged_meta
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Tool %s rejected metadata mutation; relying on registry lookup",
|
||||
tool_def.name,
|
||||
)
|
||||
tools.append(tool)
|
||||
|
||||
# Add any additional custom tools
|
||||
|
|
|
|||
|
|
@ -2250,6 +2250,202 @@ else:
|
|||
)
|
||||
|
||||
|
||||
class AgentActionLog(BaseModel):
|
||||
"""Append-only audit trail of every tool call dispatched by the agent.
|
||||
|
||||
One row per ``ToolMessage`` produced; written by ``ActionLogMiddleware``
|
||||
in its ``aafter_tool`` hook. Rows are referenced by the
|
||||
``/api/threads/{thread_id}/revert/{action_id}`` route to look up an
|
||||
action's stored ``reverse_descriptor`` and replay it.
|
||||
|
||||
The table is intentionally narrow: large tool outputs are NOT stored
|
||||
here. Result text lives in the langgraph checkpoint; this row only
|
||||
keeps a short ``result_id`` (the LangChain ``ToolMessage.id`` or a
|
||||
spilled-content path) for correlation.
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_action_log"
|
||||
|
||||
thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
turn_id = Column(String(64), nullable=True, index=True)
|
||||
message_id = Column(String(128), nullable=True, index=True)
|
||||
tool_name = Column(String(255), nullable=False, index=True)
|
||||
args = Column(JSONB, nullable=True)
|
||||
result_id = Column(String(255), nullable=True)
|
||||
reversible = Column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
)
|
||||
reverse_descriptor = Column(JSONB, nullable=True)
|
||||
error = Column(JSONB, nullable=True)
|
||||
reverse_of = Column(
|
||||
Integer,
|
||||
ForeignKey("agent_action_log.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
created_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
server_default=text("(now() AT TIME ZONE 'utc')"),
|
||||
index=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_action_log_thread_created", "thread_id", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
class DocumentRevision(BaseModel):
|
||||
"""Snapshot of a :class:`Document` row taken before a mutating tool call.
|
||||
|
||||
Written by :class:`KnowledgeBasePersistenceMiddleware` (or its safety-net
|
||||
`commit_staged_filesystem_state`) ahead of any NOTE / FILE / EXTENSION
|
||||
document write. The row is referenced by ``/revert/{action_id}`` to
|
||||
restore the original content in place.
|
||||
"""
|
||||
|
||||
__tablename__ = "document_revisions"
|
||||
|
||||
document_id = Column(
|
||||
Integer,
|
||||
ForeignKey("documents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
content_before = Column(Text, nullable=True)
|
||||
title_before = Column(String, nullable=True)
|
||||
folder_id_before = Column(Integer, nullable=True)
|
||||
chunks_before = Column(JSONB, nullable=True)
|
||||
metadata_before = Column("metadata_before", JSONB, nullable=True)
|
||||
created_by_turn_id = Column(String(64), nullable=True, index=True)
|
||||
agent_action_id = Column(
|
||||
Integer,
|
||||
ForeignKey("agent_action_log.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
created_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
server_default=text("(now() AT TIME ZONE 'utc')"),
|
||||
index=True,
|
||||
)
|
||||
|
||||
|
||||
class FolderRevision(BaseModel):
|
||||
"""Snapshot of a :class:`Folder` row taken before a mkdir / move."""
|
||||
|
||||
__tablename__ = "folder_revisions"
|
||||
|
||||
folder_id = Column(
|
||||
Integer,
|
||||
ForeignKey("folders.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
name_before = Column(String(255), nullable=True)
|
||||
parent_id_before = Column(Integer, nullable=True)
|
||||
position_before = Column(String(50), nullable=True)
|
||||
created_by_turn_id = Column(String(64), nullable=True, index=True)
|
||||
agent_action_id = Column(
|
||||
Integer,
|
||||
ForeignKey("agent_action_log.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
created_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
server_default=text("(now() AT TIME ZONE 'utc')"),
|
||||
index=True,
|
||||
)
|
||||
|
||||
|
||||
class AgentPermissionRule(BaseModel):
|
||||
"""Persistent permission rule consumed by :class:`PermissionMiddleware`.
|
||||
|
||||
Scoped at one of: search-space-wide (``user_id`` and ``thread_id`` NULL),
|
||||
user-wide (``user_id`` set, ``thread_id`` NULL), or per-thread
|
||||
(``thread_id`` set). Loaded at agent build time and converted to
|
||||
:class:`Rule` instances inside the agent factory.
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_permission_rules"
|
||||
|
||||
search_space_id = Column(
|
||||
Integer,
|
||||
ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
permission = Column(String(255), nullable=False)
|
||||
pattern = Column(String(255), nullable=False, default="*", server_default="*")
|
||||
action = Column(String(16), nullable=False) # allow / deny / ask
|
||||
created_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
server_default=text("(now() AT TIME ZONE 'utc')"),
|
||||
index=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"search_space_id",
|
||||
"user_id",
|
||||
"thread_id",
|
||||
"permission",
|
||||
"pattern",
|
||||
"action",
|
||||
name="uq_agent_permission_rules_scope",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(Base, TimestampMixin):
|
||||
"""
|
||||
Stores refresh tokens for user session management.
|
||||
|
|
|
|||
7
surfsense_backend/app/observability/__init__.py
Normal file
7
surfsense_backend/app/observability/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""SurfSense observability surface.
|
||||
|
||||
The single user-visible API right now is :mod:`otel`, which exposes a
|
||||
small wrapper around the optional ``opentelemetry`` instrumentation. The
|
||||
wrapper is a no-op when OTEL is not configured, so importing it from
|
||||
performance-critical paths is safe.
|
||||
"""
|
||||
319
surfsense_backend/app/observability/otel.py
Normal file
319
surfsense_backend/app/observability/otel.py
Normal file
|
|
@ -0,0 +1,319 @@
|
|||
"""
|
||||
OpenTelemetry instrumentation helpers for the SurfSense agent stack.
|
||||
|
||||
Tier 3b in the OpenCode-port plan.
|
||||
|
||||
Goals
|
||||
=====
|
||||
|
||||
- Provide one tiny, ergonomic API for the spans listed in the plan
|
||||
(``tool.call``, ``model.call``, ``kb.search``, ``kb.persist``,
|
||||
``compaction.run``, ``interrupt.raised``, ``permission.asked``).
|
||||
- Keep span **names** low-cardinality (``tool.call`` rather than
|
||||
``tool.call.<name>``); tool name lives in the ``tool.name`` attribute
|
||||
so dashboards aggregate cleanly.
|
||||
- Default to **no-op** behavior unless ``OTEL_EXPORTER_OTLP_ENDPOINT`` is
|
||||
set, OR an external SDK has installed a real ``TracerProvider`` already
|
||||
(e.g. via the ``opentelemetry-instrument`` agent).
|
||||
- Coexist with LangSmith: we never disable LangSmith tracing; we add OTel
|
||||
alongside.
|
||||
- Gracefully degrade if the ``opentelemetry-api`` package is missing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Lazy/optional OpenTelemetry import
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from opentelemetry import trace as _ot_trace
|
||||
from opentelemetry.trace import (
|
||||
Span as _OtSpan,
|
||||
Status as _OtStatus,
|
||||
StatusCode as _OtStatusCode,
|
||||
)
|
||||
|
||||
_OTEL_AVAILABLE = True
|
||||
except ImportError: # pragma: no cover — optional dep
|
||||
_ot_trace = None # type: ignore[assignment]
|
||||
_OtSpan = Any # type: ignore[assignment, misc]
|
||||
_OtStatus = Any # type: ignore[assignment, misc]
|
||||
_OtStatusCode = Any # type: ignore[assignment, misc]
|
||||
_OTEL_AVAILABLE = False
|
||||
|
||||
|
||||
_INSTRUMENTATION_NAME = "surfsense.new_chat"
|
||||
_INSTRUMENTATION_VERSION = "0.1.0"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_enabled() -> bool:
|
||||
"""Return True if OTel spans should actually be emitted."""
|
||||
if not _OTEL_AVAILABLE:
|
||||
return False
|
||||
# Honor an explicit kill-switch first.
|
||||
if os.environ.get("SURFSENSE_DISABLE_OTEL", "").lower() in {"1", "true", "yes"}:
|
||||
return False
|
||||
# Treat a configured endpoint as the canonical "OTel is wired up" signal.
|
||||
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
|
||||
return True
|
||||
# Or honor an external SDK that already installed a non-default TracerProvider.
|
||||
if _ot_trace is not None:
|
||||
try:
|
||||
provider = _ot_trace.get_tracer_provider()
|
||||
# The default proxy provider has no real exporter wired up.
|
||||
type_name = type(provider).__name__
|
||||
if type_name not in {"ProxyTracerProvider", "NoOpTracerProvider"}:
|
||||
return True
|
||||
except Exception: # pragma: no cover — defensive
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
_ENABLED: bool = _resolve_enabled()
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
"""Return True if instrumentation is actively emitting spans."""
|
||||
return _ENABLED
|
||||
|
||||
|
||||
def _get_tracer():
|
||||
if not _OTEL_AVAILABLE:
|
||||
return None
|
||||
try:
|
||||
return _ot_trace.get_tracer(_INSTRUMENTATION_NAME, _INSTRUMENTATION_VERSION)
|
||||
except Exception: # pragma: no cover — defensive
|
||||
return None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# No-op span used when OTel is disabled (avoids a None check at every call site)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _NoopSpan:
|
||||
"""A lightweight stand-in that mimics the subset of ``Span`` we use."""
|
||||
|
||||
def set_attribute(self, key: str, value: Any) -> None:
|
||||
return None
|
||||
|
||||
def set_attributes(self, attributes: dict[str, Any]) -> None:
|
||||
return None
|
||||
|
||||
def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None:
|
||||
return None
|
||||
|
||||
def record_exception(self, exception: BaseException) -> None:
|
||||
return None
|
||||
|
||||
def set_status(self, status: Any) -> None:
|
||||
return None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Public span helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextmanager
|
||||
def span(
|
||||
name: str,
|
||||
*,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
) -> Iterator[Any]:
|
||||
"""Generic span context manager.
|
||||
|
||||
Yields the underlying span (or a :class:`_NoopSpan` when disabled)
|
||||
so callers can attach attributes/events incrementally.
|
||||
|
||||
On exception, the span records the error via :meth:`record_exception`
|
||||
and sets ``StatusCode.ERROR``; the exception is then re-raised.
|
||||
"""
|
||||
if not _ENABLED:
|
||||
yield _NoopSpan()
|
||||
return
|
||||
|
||||
tracer = _get_tracer()
|
||||
if tracer is None: # pragma: no cover — defensive
|
||||
yield _NoopSpan()
|
||||
return
|
||||
|
||||
with tracer.start_as_current_span(name) as sp:
|
||||
if attributes:
|
||||
try:
|
||||
sp.set_attributes(attributes)
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
try:
|
||||
yield sp
|
||||
except BaseException as exc:
|
||||
try:
|
||||
sp.record_exception(exc)
|
||||
sp.set_status(_OtStatus(_OtStatusCode.ERROR, str(exc)))
|
||||
except Exception: # pragma: no cover — defensive
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Domain-specific shortcuts (mirror the plan's enumerated span list)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def tool_call_span(
|
||||
tool_name: str,
|
||||
*,
|
||||
input_size: int | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Span for an individual tool execution.
|
||||
|
||||
Span name is the constant ``tool.call`` (low-cardinality); the tool
|
||||
identifier lives in the ``tool.name`` attribute.
|
||||
"""
|
||||
attrs: dict[str, Any] = {"tool.name": tool_name}
|
||||
if input_size is not None:
|
||||
attrs["tool.input.size"] = int(input_size)
|
||||
if extra:
|
||||
attrs.update(extra)
|
||||
return span("tool.call", attributes=attrs)
|
||||
|
||||
|
||||
def model_call_span(
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Span around a single ``astream`` / ``ainvoke`` call to the LLM."""
|
||||
attrs: dict[str, Any] = {}
|
||||
if model_id:
|
||||
attrs["model.id"] = model_id
|
||||
if provider:
|
||||
attrs["model.provider"] = provider
|
||||
if extra:
|
||||
attrs.update(extra)
|
||||
return span("model.call", attributes=attrs)
|
||||
|
||||
|
||||
def kb_search_span(
|
||||
*,
|
||||
search_space_id: int | None = None,
|
||||
query_chars: int | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Span around knowledge-base search routines."""
|
||||
attrs: dict[str, Any] = {}
|
||||
if search_space_id is not None:
|
||||
attrs["search_space.id"] = int(search_space_id)
|
||||
if query_chars is not None:
|
||||
attrs["query.chars"] = int(query_chars)
|
||||
if extra:
|
||||
attrs.update(extra)
|
||||
return span("kb.search", attributes=attrs)
|
||||
|
||||
|
||||
def kb_persist_span(
|
||||
*,
|
||||
document_type: str | None = None,
|
||||
document_id: int | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Span around knowledge-base persistence operations (NOTE/EXTENSION/FILE)."""
|
||||
attrs: dict[str, Any] = {}
|
||||
if document_type:
|
||||
attrs["document.type"] = document_type
|
||||
if document_id is not None:
|
||||
attrs["document.id"] = int(document_id)
|
||||
if extra:
|
||||
attrs.update(extra)
|
||||
return span("kb.persist", attributes=attrs)
|
||||
|
||||
|
||||
def compaction_span(
|
||||
*,
|
||||
reason: str | None = None,
|
||||
messages_in: int | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Span around the compaction (summarization) middleware run."""
|
||||
attrs: dict[str, Any] = {}
|
||||
if reason:
|
||||
attrs["compaction.reason"] = reason
|
||||
if messages_in is not None:
|
||||
attrs["compaction.messages.in"] = int(messages_in)
|
||||
if extra:
|
||||
attrs.update(extra)
|
||||
return span("compaction.run", attributes=attrs)
|
||||
|
||||
|
||||
def interrupt_span(
|
||||
*,
|
||||
interrupt_type: str,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Span recording an interrupt being raised (HITL or permission_ask)."""
|
||||
attrs: dict[str, Any] = {"interrupt.type": interrupt_type}
|
||||
if extra:
|
||||
attrs.update(extra)
|
||||
return span("interrupt.raised", attributes=attrs)
|
||||
|
||||
|
||||
def permission_asked_span(
|
||||
*,
|
||||
permission: str,
|
||||
pattern: str | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Span recording a permission ask (PermissionMiddleware)."""
|
||||
attrs: dict[str, Any] = {"permission.permission": permission}
|
||||
if pattern:
|
||||
attrs["permission.pattern"] = pattern
|
||||
if extra:
|
||||
attrs.update(extra)
|
||||
return span("permission.asked", attributes=attrs)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Test/utility hooks
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def reload_for_tests() -> bool:
|
||||
"""Re-evaluate :data:`_ENABLED` from the current environment.
|
||||
|
||||
Tests that toggle ``OTEL_EXPORTER_OTLP_ENDPOINT`` or
|
||||
``SURFSENSE_DISABLE_OTEL`` can call this to reset cached state.
|
||||
Returns the new value of :func:`is_enabled`.
|
||||
"""
|
||||
global _ENABLED
|
||||
_ENABLED = _resolve_enabled()
|
||||
return _ENABLED
|
||||
|
||||
|
||||
__all__ = [
|
||||
"compaction_span",
|
||||
"interrupt_span",
|
||||
"is_enabled",
|
||||
"kb_persist_span",
|
||||
"kb_search_span",
|
||||
"model_call_span",
|
||||
"permission_asked_span",
|
||||
"reload_for_tests",
|
||||
"span",
|
||||
"tool_call_span",
|
||||
]
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
from fastapi import APIRouter
|
||||
|
||||
from .agent_action_log_route import router as agent_action_log_router
|
||||
from .agent_flags_route import router as agent_flags_router
|
||||
from .agent_permissions_route import router as agent_permissions_router
|
||||
from .agent_revert_route import router as agent_revert_router
|
||||
from .airtable_add_connector_route import (
|
||||
router as airtable_add_connector_router,
|
||||
)
|
||||
|
|
@ -66,6 +70,12 @@ router.include_router(documents_router)
|
|||
router.include_router(folders_router)
|
||||
router.include_router(notes_router)
|
||||
router.include_router(new_chat_router) # Chat with assistant-ui persistence
|
||||
router.include_router(agent_revert_router) # POST /threads/{id}/revert/{action_id}
|
||||
router.include_router(agent_action_log_router) # GET /threads/{id}/actions
|
||||
router.include_router(
|
||||
agent_permissions_router
|
||||
) # CRUD for /searchspaces/{id}/agent/permissions/rules
|
||||
router.include_router(agent_flags_router) # GET /agent/flags
|
||||
router.include_router(sandbox_router) # Sandbox file downloads (Daytona)
|
||||
router.include_router(chat_comments_router)
|
||||
router.include_router(podcasts_router) # Podcast task status and audio
|
||||
|
|
|
|||
186
surfsense_backend/app/routes/agent_action_log_route.py
Normal file
186
surfsense_backend/app/routes/agent_action_log_route.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
"""``GET /api/threads/{thread_id}/actions``: list agent action-log entries.
|
||||
|
||||
Pairs with ``POST /api/threads/{thread_id}/revert/{action_id}`` (see
|
||||
``agent_revert_route.py``). The action log is the read-side surface for
|
||||
the audit/undo UI: it returns a paginated list of every tool call
|
||||
recorded by :class:`ActionLogMiddleware` against the thread, plus
|
||||
metadata about whether the action is reversible and whether it has
|
||||
already been reverted.
|
||||
|
||||
The route is gated by the same ``SURFSENSE_ENABLE_ACTION_LOG`` flag that
|
||||
controls the middleware. When the flag is off the endpoint returns 503
|
||||
so the UI can detect "this deployment doesn't have the action log
|
||||
enabled" without 404-ing on a missing route.
|
||||
|
||||
The list is ordered DESC by ``created_at`` (newest first) so the
|
||||
revert UI can render a familiar reverse-chronological feed without an
|
||||
additional client-side sort.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AgentActionRead(BaseModel):
|
||||
"""One row of the action log surfaced to the client."""
|
||||
|
||||
id: int
|
||||
thread_id: int
|
||||
user_id: str | None
|
||||
search_space_id: int
|
||||
tool_name: str
|
||||
args: dict[str, Any] | None
|
||||
result_id: str | None
|
||||
reversible: bool
|
||||
reverse_descriptor: dict[str, Any] | None
|
||||
error: dict[str, Any] | None
|
||||
reverse_of: int | None
|
||||
reverted_by_action_id: int | None
|
||||
is_revert_action: bool
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AgentActionListResponse(BaseModel):
|
||||
"""Paginated list response for the action log."""
|
||||
|
||||
items: list[AgentActionRead]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _flag_guard() -> None:
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack or not flags.enable_action_log:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Action log is not available on this deployment. Flip "
|
||||
"SURFSENSE_ENABLE_ACTION_LOG to enable it."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/threads/{thread_id}/actions",
|
||||
response_model=AgentActionListResponse,
|
||||
)
|
||||
async def list_thread_actions(
|
||||
thread_id: int,
|
||||
page: int = Query(0, ge=0),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> AgentActionListResponse:
|
||||
"""List agent actions for a thread, newest first.
|
||||
|
||||
Authorization:
|
||||
* Caller must be a member of the thread's search space with
|
||||
``CHATS_READ`` permission.
|
||||
|
||||
Pagination:
|
||||
* ``page`` is 0-indexed.
|
||||
* ``page_size`` defaults to 50, max 200.
|
||||
"""
|
||||
|
||||
_flag_guard()
|
||||
|
||||
thread = await session.get(NewChatThread, thread_id)
|
||||
if thread is None:
|
||||
raise HTTPException(status_code=404, detail="Thread not found.")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to view this thread's action log.",
|
||||
)
|
||||
|
||||
total_stmt = select(func.count(AgentActionLog.id)).where(
|
||||
AgentActionLog.thread_id == thread_id
|
||||
)
|
||||
total = (await session.execute(total_stmt)).scalar_one()
|
||||
|
||||
rows_stmt = (
|
||||
select(AgentActionLog)
|
||||
.where(AgentActionLog.thread_id == thread_id)
|
||||
.order_by(AgentActionLog.created_at.desc(), AgentActionLog.id.desc())
|
||||
.offset(page * page_size)
|
||||
.limit(page_size)
|
||||
)
|
||||
rows = (await session.execute(rows_stmt)).scalars().all()
|
||||
|
||||
# Build a reverse_of -> revert_action_id map so the UI can render
|
||||
# "Reverted" badges on actions that have already been undone.
|
||||
if rows:
|
||||
original_ids = [r.id for r in rows]
|
||||
reverts_stmt = select(AgentActionLog.id, AgentActionLog.reverse_of).where(
|
||||
AgentActionLog.reverse_of.in_(original_ids)
|
||||
)
|
||||
reverts = (await session.execute(reverts_stmt)).all()
|
||||
revert_map: dict[int, int] = {orig: rev for rev, orig in reverts}
|
||||
else:
|
||||
revert_map = {}
|
||||
|
||||
items = [
|
||||
AgentActionRead(
|
||||
id=row.id,
|
||||
thread_id=row.thread_id,
|
||||
user_id=str(row.user_id) if row.user_id is not None else None,
|
||||
search_space_id=row.search_space_id,
|
||||
tool_name=row.tool_name,
|
||||
args=row.args,
|
||||
result_id=row.result_id,
|
||||
reversible=bool(row.reversible),
|
||||
reverse_descriptor=row.reverse_descriptor,
|
||||
error=row.error,
|
||||
reverse_of=row.reverse_of,
|
||||
reverted_by_action_id=revert_map.get(row.id),
|
||||
is_revert_action=row.reverse_of is not None,
|
||||
created_at=row.created_at,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return AgentActionListResponse(
|
||||
items=items,
|
||||
total=int(total),
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
has_more=(page + 1) * page_size < int(total),
|
||||
)
|
||||
71
surfsense_backend/app/routes/agent_flags_route.py
Normal file
71
surfsense_backend/app/routes/agent_flags_route.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""``GET /api/agent/flags``: read-only feature-flag status.
|
||||
|
||||
Surfaces :class:`AgentFeatureFlags` to the frontend so the UI can:
|
||||
|
||||
* Render conditional surfaces (e.g. show the action-log button only when
|
||||
``enable_action_log`` is on).
|
||||
* Display an admin diagnostics card so operators can verify which
|
||||
middleware tier is active without shelling into the box.
|
||||
|
||||
The endpoint is *read-only*. Flipping flags requires an env-var change
|
||||
plus a process restart — by design, since the values are baked into the
|
||||
agent factory at build time. The route does not require any special
|
||||
permission (any authenticated user can see them) since the flag values
|
||||
do not leak data, and the UI surfaces are conditionally rendered based
|
||||
on them anyway.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.agents.new_chat.feature_flags import AgentFeatureFlags, get_flags
|
||||
from app.db import User
|
||||
from app.users import current_active_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AgentFeatureFlagsRead(BaseModel):
|
||||
"""Mirror of :class:`AgentFeatureFlags`. Updated together with it."""
|
||||
|
||||
disable_new_agent_stack: bool
|
||||
|
||||
enable_context_editing: bool
|
||||
enable_compaction_v2: bool
|
||||
enable_retry_after: bool
|
||||
enable_model_fallback: bool
|
||||
enable_model_call_limit: bool
|
||||
enable_tool_call_limit: bool
|
||||
enable_tool_call_repair: bool
|
||||
enable_doom_loop: bool
|
||||
|
||||
enable_permission: bool
|
||||
enable_busy_mutex: bool
|
||||
enable_llm_tool_selector: bool
|
||||
|
||||
enable_skills: bool
|
||||
enable_specialized_subagents: bool
|
||||
enable_kb_planner_runnable: bool
|
||||
|
||||
enable_action_log: bool
|
||||
enable_revert_route: bool
|
||||
|
||||
enable_plugin_loader: bool
|
||||
|
||||
enable_otel: bool
|
||||
|
||||
@classmethod
|
||||
def from_flags(cls, flags: AgentFeatureFlags) -> "AgentFeatureFlagsRead":
|
||||
# asdict() avoids missing-field bugs when AgentFeatureFlags grows.
|
||||
return cls(**asdict(flags))
|
||||
|
||||
|
||||
@router.get("/agent/flags", response_model=AgentFeatureFlagsRead)
|
||||
async def get_agent_flags(
|
||||
_user: User = Depends(current_active_user),
|
||||
) -> AgentFeatureFlagsRead:
|
||||
return AgentFeatureFlagsRead.from_flags(get_flags())
|
||||
280
surfsense_backend/app/routes/agent_permissions_route.py
Normal file
280
surfsense_backend/app/routes/agent_permissions_route.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""CRUD for :class:`app.db.AgentPermissionRule`.
|
||||
|
||||
Surfaces the permission rules consumed by
|
||||
:class:`PermissionMiddleware`. Rules are scoped at one of three levels:
|
||||
|
||||
* **Search-space wide** — both ``user_id`` and ``thread_id`` are NULL.
|
||||
* **Per-user** — ``user_id`` set, ``thread_id`` NULL.
|
||||
* **Per-thread** — ``thread_id`` set (``user_id`` typically NULL).
|
||||
|
||||
The middleware reads these rows at agent build time (see
|
||||
``chat_deepagent.py``). UI lets a search-space owner curate them so
|
||||
the agent can ask for approval / auto-deny / auto-allow specific
|
||||
tool patterns.
|
||||
|
||||
The route group is gated by ``SURFSENSE_ENABLE_PERMISSION``: when off
|
||||
all endpoints return 503 so the UI can render a "feature not enabled"
|
||||
empty state without breaking on a missing route.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.db import (
|
||||
AgentPermissionRule,
|
||||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_ACTION_VALUES: tuple[str, ...] = ("allow", "deny", "ask")
|
||||
_PERMISSION_PATTERN = re.compile(r"^[a-zA-Z0-9_:.\-*]+$")
|
||||
|
||||
|
||||
class AgentPermissionRuleRead(BaseModel):
|
||||
id: int
|
||||
search_space_id: int
|
||||
user_id: str | None
|
||||
thread_id: int | None
|
||||
permission: str
|
||||
pattern: str
|
||||
action: Literal["allow", "deny", "ask"]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AgentPermissionRuleCreate(BaseModel):
|
||||
permission: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
description="Tool / capability the rule targets, e.g. 'tool:create_linear_issue'.",
|
||||
)
|
||||
pattern: str = Field(
|
||||
"*",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
description="Wildcard pattern (e.g. '*' or 'production-*') applied to the matched tool argument.",
|
||||
)
|
||||
action: Literal["allow", "deny", "ask"]
|
||||
user_id: str | None = None
|
||||
thread_id: int | None = None
|
||||
|
||||
|
||||
class AgentPermissionRuleUpdate(BaseModel):
|
||||
pattern: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
action: Literal["allow", "deny", "ask"] | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _flag_guard() -> None:
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack or not flags.enable_permission:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Agent permission rules are not enabled on this deployment. "
|
||||
"Flip SURFSENSE_ENABLE_PERMISSION to enable them."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _validate_permission_string(value: str) -> str:
|
||||
if not _PERMISSION_PATTERN.match(value):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"permission must contain only alphanumerics, '.', '_', ':', '-', "
|
||||
"or '*' wildcards."
|
||||
),
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def _to_read(row: AgentPermissionRule) -> AgentPermissionRuleRead:
|
||||
return AgentPermissionRuleRead(
|
||||
id=row.id,
|
||||
search_space_id=row.search_space_id,
|
||||
user_id=str(row.user_id) if row.user_id is not None else None,
|
||||
thread_id=row.thread_id,
|
||||
permission=row.permission,
|
||||
pattern=row.pattern,
|
||||
action=row.action, # type: ignore[arg-type]
|
||||
created_at=row.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def _ensure_search_space_membership_admin(
|
||||
session: AsyncSession, user: User, search_space_id: int
|
||||
) -> None:
|
||||
"""Curating agent rules == "settings" administration on the space."""
|
||||
space = await session.get(SearchSpace, search_space_id)
|
||||
if space is None:
|
||||
raise HTTPException(status_code=404, detail="Search space not found.")
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.SETTINGS_UPDATE.value,
|
||||
"You don't have permission to manage agent permission rules in this space.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/searchspaces/{search_space_id}/agent/permissions/rules",
|
||||
response_model=list[AgentPermissionRuleRead],
|
||||
)
|
||||
async def list_rules(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> list[AgentPermissionRuleRead]:
|
||||
_flag_guard()
|
||||
await _ensure_search_space_membership_admin(session, user, search_space_id)
|
||||
|
||||
stmt = (
|
||||
select(AgentPermissionRule)
|
||||
.where(AgentPermissionRule.search_space_id == search_space_id)
|
||||
.order_by(AgentPermissionRule.created_at.desc(), AgentPermissionRule.id.desc())
|
||||
)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [_to_read(r) for r in rows]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/searchspaces/{search_space_id}/agent/permissions/rules",
|
||||
response_model=AgentPermissionRuleRead,
|
||||
status_code=201,
|
||||
)
|
||||
async def create_rule(
|
||||
search_space_id: int,
|
||||
payload: AgentPermissionRuleCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> AgentPermissionRuleRead:
|
||||
_flag_guard()
|
||||
await _ensure_search_space_membership_admin(session, user, search_space_id)
|
||||
|
||||
permission = _validate_permission_string(payload.permission.strip())
|
||||
pattern = payload.pattern.strip() or "*"
|
||||
|
||||
if payload.thread_id is not None:
|
||||
thread = await session.get(NewChatThread, payload.thread_id)
|
||||
if thread is None or thread.search_space_id != search_space_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Thread not found in this search space.",
|
||||
)
|
||||
|
||||
row = AgentPermissionRule(
|
||||
search_space_id=search_space_id,
|
||||
user_id=payload.user_id,
|
||||
thread_id=payload.thread_id,
|
||||
permission=permission,
|
||||
pattern=pattern,
|
||||
action=payload.action,
|
||||
)
|
||||
session.add(row)
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(
|
||||
"An identical rule already exists for this scope. Update the "
|
||||
"existing rule instead."
|
||||
),
|
||||
)
|
||||
await session.refresh(row)
|
||||
return _to_read(row)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/searchspaces/{search_space_id}/agent/permissions/rules/{rule_id}",
|
||||
response_model=AgentPermissionRuleRead,
|
||||
)
|
||||
async def update_rule(
|
||||
search_space_id: int,
|
||||
rule_id: int,
|
||||
payload: AgentPermissionRuleUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> AgentPermissionRuleRead:
|
||||
_flag_guard()
|
||||
await _ensure_search_space_membership_admin(session, user, search_space_id)
|
||||
|
||||
row = await session.get(AgentPermissionRule, rule_id)
|
||||
if row is None or row.search_space_id != search_space_id:
|
||||
raise HTTPException(status_code=404, detail="Rule not found.")
|
||||
|
||||
if payload.pattern is not None:
|
||||
row.pattern = payload.pattern.strip() or "*"
|
||||
if payload.action is not None:
|
||||
row.action = payload.action
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Update would create a duplicate rule for this scope.",
|
||||
)
|
||||
await session.refresh(row)
|
||||
return _to_read(row)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/searchspaces/{search_space_id}/agent/permissions/rules/{rule_id}",
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_rule(
|
||||
search_space_id: int,
|
||||
rule_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> None:
|
||||
_flag_guard()
|
||||
await _ensure_search_space_membership_admin(session, user, search_space_id)
|
||||
|
||||
row = await session.get(AgentPermissionRule, rule_id)
|
||||
if row is None or row.search_space_id != search_space_id:
|
||||
raise HTTPException(status_code=404, detail="Rule not found.")
|
||||
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
return None
|
||||
122
surfsense_backend/app/routes/agent_revert_route.py
Normal file
122
surfsense_backend/app/routes/agent_revert_route.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""POST ``/api/threads/{thread_id}/revert/{action_id}``: undo an agent action.
|
||||
|
||||
Per the Tier 5 plan, the route ships **before** the UI lights up the per-message
|
||||
"Undo from here" affordance. To prevent accidental usage during the gap we
|
||||
return ``503 Service Unavailable`` until the
|
||||
``SURFSENSE_ENABLE_REVERT_ROUTE`` flag flips. Once enabled, the route runs:
|
||||
|
||||
1. Authentication via :func:`current_active_user`.
|
||||
2. Action lookup; 404 if the action does not belong to the thread.
|
||||
3. Authorization via :func:`app.services.revert_service.can_revert`.
|
||||
4. Revert dispatch via :func:`app.services.revert_service.revert_action`.
|
||||
5. Idempotent on retries: if the same action is reverted twice the second
|
||||
call returns 409 ``"already reverted"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.feature_flags import get_flags
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.services.revert_service import (
|
||||
RevertOutcome,
|
||||
can_revert,
|
||||
load_action,
|
||||
load_thread,
|
||||
revert_action,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/revert/{action_id}")
|
||||
async def revert_agent_action(
|
||||
thread_id: int,
|
||||
action_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> dict:
|
||||
flags = get_flags()
|
||||
if flags.disable_new_agent_stack or not flags.enable_revert_route:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"Revert is not available on this deployment yet. The route "
|
||||
"ships before the UI; flip SURFSENSE_ENABLE_REVERT_ROUTE to "
|
||||
"enable it."
|
||||
),
|
||||
)
|
||||
|
||||
thread = await load_thread(session, thread_id=thread_id)
|
||||
if thread is None:
|
||||
raise HTTPException(status_code=404, detail="Thread not found.")
|
||||
|
||||
action = await load_action(session, action_id=action_id, thread_id=thread_id)
|
||||
if action is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Action not found or does not belong to this thread.",
|
||||
)
|
||||
|
||||
# Idempotency: if a successful revert already exists, return 409.
|
||||
existing_revert = await session.execute(
|
||||
select(AgentActionLog).where(AgentActionLog.reverse_of == action.id)
|
||||
)
|
||||
if existing_revert.scalars().first() is not None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This action has already been reverted.",
|
||||
)
|
||||
|
||||
if not can_revert(
|
||||
requester_user_id=str(user.id) if user is not None else None,
|
||||
action=action,
|
||||
is_admin=False, # role lookup is done by RBAC layer; default conservative
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You are not allowed to revert this action.",
|
||||
)
|
||||
|
||||
outcome: RevertOutcome
|
||||
try:
|
||||
outcome = await revert_action(
|
||||
session,
|
||||
action=action,
|
||||
requester_user_id=str(user.id) if user is not None else None,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Revert dispatch raised for action_id=%s", action_id)
|
||||
await session.rollback()
|
||||
raise HTTPException(status_code=500, detail="Internal error during revert.")
|
||||
|
||||
if outcome.status == "ok":
|
||||
await session.commit()
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": outcome.message,
|
||||
"new_action_id": outcome.new_action_id,
|
||||
}
|
||||
|
||||
await session.rollback()
|
||||
|
||||
if outcome.status == "not_found" or outcome.status == "tool_unavailable":
|
||||
raise HTTPException(status_code=409, detail=outcome.message)
|
||||
if outcome.status == "permission_denied":
|
||||
raise HTTPException(status_code=403, detail=outcome.message)
|
||||
if outcome.status == "reverse_not_implemented":
|
||||
raise HTTPException(status_code=501, detail=outcome.message)
|
||||
# not_reversible
|
||||
raise HTTPException(status_code=409, detail=outcome.message)
|
||||
279
surfsense_backend/app/services/revert_service.py
Normal file
279
surfsense_backend/app/services/revert_service.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""Revert service for the SurfSense agent action log.
|
||||
|
||||
Implements the actual revert workflow used by
|
||||
``POST /api/threads/{thread_id}/revert/{action_id}``. The route handler is a
|
||||
thin auth + flag wrapper around the functions defined here.
|
||||
|
||||
Operation outcomes mirror the plan:
|
||||
|
||||
* **KB-owned actions** (NOTE / FILE / FOLDER mutations): restore from
|
||||
:class:`app.db.DocumentRevision` / :class:`app.db.FolderRevision` rows
|
||||
written before the original mutation.
|
||||
* **Connector-owned actions with a declared ``reverse_descriptor``**: invoke
|
||||
the inverse tool through the agent's normal permission stack (NOT
|
||||
bypassed). Out of scope for this PR — returns ``REVERSE_NOT_IMPLEMENTED``.
|
||||
* **Anything else** (deprecated tool / no descriptor / schema drift):
|
||||
returns ``NOT_REVERSIBLE`` and the route surfaces it as 409.
|
||||
|
||||
A successful revert appends a NEW row to ``agent_action_log`` with
|
||||
``reverse_of=<original_action_id>`` and the requesting user's
|
||||
``user_id``, preserving an auditable chain.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
DocumentRevision,
|
||||
FolderRevision,
|
||||
NewChatThread,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
RevertOutcomeStatus = Literal[
|
||||
"ok",
|
||||
"not_reversible",
|
||||
"not_found",
|
||||
"permission_denied",
|
||||
"tool_unavailable",
|
||||
"reverse_not_implemented",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RevertOutcome:
|
||||
"""Structured result of :func:`revert_action`."""
|
||||
|
||||
status: RevertOutcomeStatus
|
||||
message: str
|
||||
new_action_id: int | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookup helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def load_action(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
action_id: int,
|
||||
thread_id: int,
|
||||
) -> AgentActionLog | None:
|
||||
"""Load the action_log row for ``action_id`` if it belongs to the thread."""
|
||||
stmt = select(AgentActionLog).where(
|
||||
AgentActionLog.id == action_id,
|
||||
AgentActionLog.thread_id == thread_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def load_thread(
|
||||
session: AsyncSession, *, thread_id: int
|
||||
) -> NewChatThread | None:
|
||||
stmt = select(NewChatThread).where(NewChatThread.id == thread_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def can_revert(
|
||||
*,
|
||||
requester_user_id: str | None,
|
||||
action: AgentActionLog,
|
||||
is_admin: bool,
|
||||
) -> bool:
|
||||
"""Return True iff the requester is allowed to revert this action.
|
||||
|
||||
The plan's rule: "requester must be the original `user_id` on the
|
||||
action, or hold the search-space admin role." Anonymous actions
|
||||
(``action.user_id is None``) can only be reverted by admins.
|
||||
"""
|
||||
if is_admin:
|
||||
return True
|
||||
if action.user_id is None:
|
||||
return False
|
||||
return str(action.user_id) == str(requester_user_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Revert paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _restore_document_revision(
|
||||
session: AsyncSession, *, action: AgentActionLog
|
||||
) -> RevertOutcome:
|
||||
"""Restore the most recent :class:`DocumentRevision` for ``action``."""
|
||||
stmt = (
|
||||
select(DocumentRevision)
|
||||
.where(DocumentRevision.agent_action_id == action.id)
|
||||
.order_by(DocumentRevision.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
revision = result.scalars().first()
|
||||
if revision is None:
|
||||
return RevertOutcome(
|
||||
status="not_reversible",
|
||||
message="No document_revisions row tied to this action.",
|
||||
)
|
||||
|
||||
from app.db import Document # late import to avoid cycles at module load
|
||||
|
||||
doc = await session.get(Document, revision.document_id)
|
||||
if doc is None:
|
||||
return RevertOutcome(
|
||||
status="tool_unavailable",
|
||||
message="Original document has been deleted; revert cannot proceed.",
|
||||
)
|
||||
|
||||
if revision.content_before is not None:
|
||||
doc.content = revision.content_before
|
||||
if revision.title_before is not None:
|
||||
doc.title = revision.title_before
|
||||
if revision.folder_id_before is not None:
|
||||
doc.folder_id = revision.folder_id_before
|
||||
doc.updated_at = datetime.now(UTC)
|
||||
return RevertOutcome(status="ok", message="Document restored from snapshot.")
|
||||
|
||||
|
||||
async def _restore_folder_revision(
|
||||
session: AsyncSession, *, action: AgentActionLog
|
||||
) -> RevertOutcome:
|
||||
stmt = (
|
||||
select(FolderRevision)
|
||||
.where(FolderRevision.agent_action_id == action.id)
|
||||
.order_by(FolderRevision.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
revision = result.scalars().first()
|
||||
if revision is None:
|
||||
return RevertOutcome(
|
||||
status="not_reversible",
|
||||
message="No folder_revisions row tied to this action.",
|
||||
)
|
||||
|
||||
from app.db import Folder
|
||||
|
||||
folder = await session.get(Folder, revision.folder_id)
|
||||
if folder is None:
|
||||
return RevertOutcome(
|
||||
status="tool_unavailable",
|
||||
message="Original folder has been deleted; revert cannot proceed.",
|
||||
)
|
||||
|
||||
if revision.name_before is not None:
|
||||
folder.name = revision.name_before
|
||||
if revision.parent_id_before is not None:
|
||||
folder.parent_id = revision.parent_id_before
|
||||
if revision.position_before is not None:
|
||||
folder.position = revision.position_before
|
||||
folder.updated_at = datetime.now(UTC)
|
||||
return RevertOutcome(status="ok", message="Folder restored from snapshot.")
|
||||
|
||||
|
||||
# Tool-name prefixes that route to KB document / folder revert paths. Kept
|
||||
# as data so a future PR adding new KB-owned tools doesn't have to touch
|
||||
# this module's control flow.
|
||||
_DOC_TOOL_PREFIXES: tuple[str, ...] = (
|
||||
"edit_file",
|
||||
"write_file",
|
||||
"update_memory",
|
||||
"create_note",
|
||||
"update_note",
|
||||
"delete_note",
|
||||
)
|
||||
_FOLDER_TOOL_PREFIXES: tuple[str, ...] = (
|
||||
"mkdir",
|
||||
"move_file",
|
||||
"rename_folder",
|
||||
"delete_folder",
|
||||
)
|
||||
|
||||
|
||||
async def revert_action(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
action: AgentActionLog,
|
||||
requester_user_id: str | None,
|
||||
) -> RevertOutcome:
|
||||
"""Execute the revert for ``action`` and return a structured outcome.
|
||||
|
||||
The function does **not** commit — the caller is expected to commit on
|
||||
success or roll back on failure. A new ``agent_action_log`` row is
|
||||
added to the session on success with ``reverse_of=action.id``.
|
||||
"""
|
||||
tool_name = (action.tool_name or "").lower()
|
||||
|
||||
if tool_name.startswith(_DOC_TOOL_PREFIXES):
|
||||
outcome = await _restore_document_revision(session, action=action)
|
||||
elif tool_name.startswith(_FOLDER_TOOL_PREFIXES):
|
||||
outcome = await _restore_folder_revision(session, action=action)
|
||||
elif action.reverse_descriptor:
|
||||
# Connector-owned reversibles run through the normal permission
|
||||
# stack; out of scope for this PR — the route returns 503 anyway
|
||||
# until UI ships, so 501-style "not implemented" is fine.
|
||||
return RevertOutcome(
|
||||
status="reverse_not_implemented",
|
||||
message=(
|
||||
"Connector-action revert is not yet implemented. The "
|
||||
"reverse_descriptor is stored; future work will replay it "
|
||||
"through PermissionMiddleware."
|
||||
),
|
||||
)
|
||||
else:
|
||||
return RevertOutcome(
|
||||
status="not_reversible",
|
||||
message=(
|
||||
f"Tool {action.tool_name!r} is not reversible: no document "
|
||||
"revision and no reverse_descriptor."
|
||||
),
|
||||
)
|
||||
|
||||
if outcome.status != "ok":
|
||||
return outcome
|
||||
|
||||
new_row = AgentActionLog(
|
||||
thread_id=action.thread_id,
|
||||
user_id=requester_user_id,
|
||||
search_space_id=action.search_space_id,
|
||||
turn_id=None,
|
||||
message_id=None,
|
||||
tool_name=f"_revert:{action.tool_name}",
|
||||
args={"reverted_action_id": action.id},
|
||||
result_id=None,
|
||||
reversible=False,
|
||||
reverse_descriptor=None,
|
||||
error=None,
|
||||
reverse_of=action.id,
|
||||
)
|
||||
session.add(new_row)
|
||||
await session.flush()
|
||||
outcome.new_action_id = new_row.id
|
||||
return outcome
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RevertOutcome",
|
||||
"can_revert",
|
||||
"load_action",
|
||||
"load_thread",
|
||||
"revert_action",
|
||||
]
|
||||
|
|
@ -33,7 +33,7 @@ F = TypeVar("F", bound=Callable)
|
|||
def _is_retryable(exc: BaseException) -> bool:
|
||||
if isinstance(exc, ConnectorError):
|
||||
return exc.retryable
|
||||
return bool(isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)))
|
||||
return bool(isinstance(exc, httpx.TimeoutException | httpx.ConnectError))
|
||||
|
||||
|
||||
def build_retry(
|
||||
|
|
|
|||
146
surfsense_backend/tests/integration/harness/__init__.py
Normal file
146
surfsense_backend/tests/integration/harness/__init__.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
"""
|
||||
Integration test harness for the SurfSense agent stack.
|
||||
|
||||
The plan calls for an ``LLMToolEmulator``-backed harness for end-to-end
|
||||
replay of ``stream_new_chat``. The currently-installed langchain version
|
||||
does not expose ``LLMToolEmulator``, so this harness builds the equivalent
|
||||
on top of :class:`langchain_core.language_models.fake_chat_models.FakeMessagesListChatModel`.
|
||||
|
||||
The harness lets a test author script a sequence of model responses
|
||||
(text + optional tool calls) and replay them against the new_chat agent
|
||||
graph. Tools are stubbed via ``StubToolSpec`` -> ``langchain_core.tools.tool``
|
||||
decorator and execute deterministic Python callbacks.
|
||||
|
||||
Used by:
|
||||
- ``tests/integration/agents/new_chat/test_feature_flag_smoke.py`` to
|
||||
confirm the kill-switch path produces identical-shape output regardless
|
||||
of which middleware flags are toggled.
|
||||
- Future per-tier PRs to record golden transcripts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.fake_chat_models import (
|
||||
FakeMessagesListChatModel,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
|
||||
|
||||
class _ToolBindingFakeChatModel(FakeMessagesListChatModel):
|
||||
"""Adapter so the harness model can pretend it understands ``bind_tools``.
|
||||
|
||||
The base ``FakeMessagesListChatModel`` raises ``NotImplementedError`` from
|
||||
``bind_tools``, but ``langchain.agents.create_agent`` always calls
|
||||
``bind_tools`` to attach the tool registry. We don't actually need the
|
||||
fake to honor the tool schema — it's already scripted to emit the right
|
||||
tool calls — so we return self.
|
||||
"""
|
||||
|
||||
def bind_tools( # type: ignore[override]
|
||||
self,
|
||||
tools: Sequence[Any],
|
||||
*,
|
||||
tool_choice: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubToolSpec:
|
||||
"""A test-mode tool: a name, description, and a deterministic body."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
handler: Callable[..., Any]
|
||||
args_schema: dict[str, Any] | None = None
|
||||
|
||||
def build(self) -> BaseTool:
|
||||
"""Realize as a `langchain_core.tools.BaseTool`."""
|
||||
|
||||
@tool(name_or_callable=self.name, description=self.description)
|
||||
def _stub_tool(**kwargs: Any) -> Any:
|
||||
return self.handler(**kwargs)
|
||||
|
||||
return _stub_tool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptedTurn:
|
||||
"""One scripted assistant turn.
|
||||
|
||||
`text` is the assistant text (may be empty if pure tool call).
|
||||
`tool_calls` is a list of dicts ``{name, args, id}``; if non-empty, the
|
||||
agent will route to those tools and append a follow-up turn.
|
||||
"""
|
||||
|
||||
text: str = ""
|
||||
tool_calls: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def build_scripted_messages(turns: list[ScriptedTurn]) -> list[BaseMessage]:
|
||||
"""Convert :class:`ScriptedTurn` records to AIMessage payloads."""
|
||||
out: list[BaseMessage] = []
|
||||
for turn in turns:
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
for tc in turn.tool_calls:
|
||||
tool_calls.append(
|
||||
{
|
||||
"name": tc["name"],
|
||||
"args": tc.get("args", {}),
|
||||
"id": tc.get("id") or f"call_{uuid.uuid4().hex[:8]}",
|
||||
}
|
||||
)
|
||||
out.append(AIMessage(content=turn.text, tool_calls=tool_calls or []))
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptedHarness:
|
||||
"""Bundle of (model, tools) ready to plug into ``create_agent``."""
|
||||
|
||||
model: _ToolBindingFakeChatModel
|
||||
tools: list[BaseTool]
|
||||
|
||||
|
||||
def build_scripted_harness(
|
||||
*,
|
||||
turns: list[ScriptedTurn],
|
||||
tools: list[StubToolSpec] | None = None,
|
||||
sleep: float | None = None,
|
||||
) -> ScriptedHarness:
|
||||
"""Construct a deterministic agent harness from a script.
|
||||
|
||||
Example::
|
||||
|
||||
harness = build_scripted_harness(
|
||||
turns=[
|
||||
ScriptedTurn(tool_calls=[{"name": "echo", "args": {"x": 1}}]),
|
||||
ScriptedTurn(text="done"),
|
||||
],
|
||||
tools=[
|
||||
StubToolSpec(name="echo", description="echo args", handler=lambda **kw: kw),
|
||||
],
|
||||
)
|
||||
"""
|
||||
messages = build_scripted_messages(turns)
|
||||
model = _ToolBindingFakeChatModel(responses=messages, sleep=sleep)
|
||||
realized_tools = [t.build() for t in (tools or [])]
|
||||
return ScriptedHarness(model=model, tools=realized_tools)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ScriptedHarness",
|
||||
"ScriptedTurn",
|
||||
"StubToolSpec",
|
||||
"build_scripted_harness",
|
||||
"build_scripted_messages",
|
||||
]
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
"""Smoke test: scripted harness drives create_agent end-to-end and produces a tool-call-then-final-text trace."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from tests.integration.harness import (
|
||||
ScriptedTurn,
|
||||
StubToolSpec,
|
||||
build_scripted_harness,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scripted_harness_drives_basic_agent() -> None:
|
||||
harness = build_scripted_harness(
|
||||
turns=[
|
||||
ScriptedTurn(
|
||||
tool_calls=[
|
||||
{"name": "echo", "args": {"x": 1}, "id": "call_1"},
|
||||
]
|
||||
),
|
||||
ScriptedTurn(text="done"),
|
||||
],
|
||||
tools=[
|
||||
StubToolSpec(
|
||||
name="echo",
|
||||
description="Echo args back.",
|
||||
handler=lambda **kwargs: {"echoed": kwargs},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
agent = create_agent(
|
||||
harness.model,
|
||||
system_prompt="You are a test agent.",
|
||||
tools=harness.tools,
|
||||
)
|
||||
|
||||
result = await agent.ainvoke({"messages": [("user", "do the thing")]})
|
||||
messages = result["messages"]
|
||||
final_ai = next(
|
||||
(m for m in reversed(messages) if m.__class__.__name__ == "AIMessage"),
|
||||
None,
|
||||
)
|
||||
assert final_ai is not None
|
||||
assert final_ai.content == "done"
|
||||
tool_messages = [m for m in messages if m.__class__.__name__ == "ToolMessage"]
|
||||
assert len(tool_messages) == 1
|
||||
assert "echoed" in str(tool_messages[0].content)
|
||||
1
surfsense_backend/tests/unit/agents/__init__.py
Normal file
1
surfsense_backend/tests/unit/agents/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
1
surfsense_backend/tests/unit/agents/new_chat/__init__.py
Normal file
1
surfsense_backend/tests/unit/agents/new_chat/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""__init__ stub so pytest discovers the prompts test module."""
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue