mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): move mac-only modules out of the cross-agent shared kernel
app/agents/shared/ is a sibling of anonymous_chat/podcaster/multi_agent_chat/
video_presentation, so it should only hold code shared across 2+ of those
agents. In practice podcaster and video_presentation import nothing from it,
and anonymous_chat needs only context + compaction + retry_after + web_search.
Everything else was multi_agent_chat-only (the boundary just passes through).
Move the multi_agent_chat-only cluster into multi_agent_chat/shared/ (files
moved verbatim via git rename; ~116 import sites rewritten):
errors, feature_flags, filesystem_selection, path_resolver, prompt_caching,
sandbox, llm_config, mention_resolver
middleware/busy_mutex, middleware/kb_persistence
busy_mutex/llm_config/mention_resolver are boundary-only but import the moved
modules, so they were folded in to avoid a backwards shared -> multi_agent_chat
dependency. main_agent builders now import the impls directly; the shared
middleware barrel keeps only the genuinely-shared compaction + retry_after.
Also delete the dead leftover shared/plugins and shared/skills dirs (live
copies already live under main_agent/).
Remaining in app/agents/shared/: context, system_prompt(+prompts), checkpointer,
middleware/{compaction,retry_after,dedup_tool_calls}, tools/. checkpointer and
system_prompt are boundary-only infra pending a dedicated home decision.
This commit is contained in:
parent
c0c4f57f5d
commit
82c5dc5b02
126 changed files with 238 additions and 196 deletions
|
|
@ -14,9 +14,9 @@ from langgraph.types import Checkpointer
|
|||
from app.agents.multi_agent_chat.main_agent.middleware.stack import (
|
||||
build_main_agent_deepagent_middleware,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.context import SurfSenseContextSchema
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from .middleware import ActionLogMiddleware
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ from langchain.agents.middleware import AgentMiddleware
|
|||
from langchain_core.callbacks import adispatch_custom_event
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from app.agents.shared.feature_flags import get_flags
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import get_flags
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - type-only
|
||||
from langchain.agents.middleware.types import ToolCallRequest
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from .middleware import AnonymousDocumentMiddleware
|
||||
|
||||
|
|
|
|||
|
|
@ -24,10 +24,13 @@ from typing import Any
|
|||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
safe_filename,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT, safe_filename
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import BusyMutexMiddleware
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.busy_mutex import BusyMutexMiddleware
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from langchain_core.tools import BaseTool
|
|||
from app.agents.multi_agent_chat.main_agent.context_prune.prune_tool_names import (
|
||||
safe_exclude_tools,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from .middleware import (
|
||||
ClearToolUsesEdit,
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from .middleware import DoomLoopMiddleware
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.middleware import KnowledgeBasePersistenceMiddleware
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.middleware.kb_persistence import (
|
||||
KnowledgeBasePersistenceMiddleware,
|
||||
)
|
||||
|
||||
|
||||
def build_kb_persistence_mw(
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@ from __future__ import annotations
|
|||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.middleware.knowledge_search import (
|
||||
KnowledgePriorityMiddleware,
|
||||
)
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.services.llm_service import get_planner_llm
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from .middleware import KnowledgeTreeMiddleware
|
||||
|
||||
|
|
|
|||
|
|
@ -33,16 +33,16 @@ from langchain_core.messages import SystemMessage
|
|||
from langgraph.runtime import Runtime
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.path_resolver import (
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.db import Document, shielded_async_session
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from .middleware import NoopInjectionMiddleware
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from .middleware import OtelSpanMiddleware
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from typing import Any
|
|||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..plugins.loader import (
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ import logging
|
|||
|
||||
from deepagents.middleware.skills import SkillsMiddleware
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from ..skills.backends import build_skills_backend_factory, default_skills_sources
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.middleware.anthropic_cache import (
|
||||
build_anthropic_cache_mw,
|
||||
)
|
||||
|
|
@ -52,8 +54,6 @@ from app.agents.multi_agent_chat.subagents.builtins.knowledge_base.ask_knowledge
|
|||
from app.agents.multi_agent_chat.subagents.shared.middleware.middleware_stack import (
|
||||
build_subagent_middleware_stack,
|
||||
)
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from .action_log import build_action_log_mw
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from collections.abc import Sequence
|
|||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.flags import enabled
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from .middleware import ToolCallNameRepairMiddleware
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.db import ChatVisibility
|
||||
|
||||
from ..graph.compile_graph_sync import build_compiled_agent_graph_sync
|
||||
|
|
|
|||
|
|
@ -12,9 +12,21 @@ from langchain_core.tools import BaseTool
|
|||
from langgraph.types import Checkpointer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import (
|
||||
AgentFeatureFlags,
|
||||
get_flags,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import (
|
||||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.llm_config import AgentConfig
|
||||
from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.resolver import (
|
||||
build_backend_resolver,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
)
|
||||
from app.agents.multi_agent_chat.subagents import (
|
||||
get_subagents_to_exclude,
|
||||
main_prompt_registry_subagent_lines,
|
||||
|
|
@ -22,10 +34,6 @@ from app.agents.multi_agent_chat.subagents import (
|
|||
from app.agents.multi_agent_chat.subagents.mcp_tools.index import (
|
||||
load_mcp_tools_by_connector,
|
||||
)
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags, get_flags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
from app.agents.shared.llm_config import AgentConfig
|
||||
from app.agents.shared.prompt_caching import apply_litellm_prompt_caching
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.user_tool_allowlist import (
|
||||
|
|
|
|||
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
Typed error taxonomy for the SurfSense agent stack.
|
||||
|
||||
Used by:
|
||||
- :class:`RetryAfterMiddleware` — its ``retry_on`` callable consults
|
||||
the error code to decide whether a retry is appropriate.
|
||||
- :class:`PermissionMiddleware` — emits ``code="permission_denied"``
|
||||
errors when a deny rule trips.
|
||||
- All tools — return :class:`StreamingError` payloads in
|
||||
``ToolMessage.additional_kwargs["error"]`` so the model and the
|
||||
retry/permission layers share a contract.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
ErrorCode = Literal[
|
||||
"rate_limit",
|
||||
"auth",
|
||||
"tool_validation",
|
||||
"tool_runtime",
|
||||
"context_overflow",
|
||||
"provider",
|
||||
"permission_denied",
|
||||
"doom_loop",
|
||||
"busy",
|
||||
"cancelled",
|
||||
]
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
"""Structured error payload attached to ``ToolMessage.additional_kwargs["error"]``.
|
||||
|
||||
Tools and middleware emit this so retry, permission, and routing
|
||||
layers can decide what to do without parsing free-form strings.
|
||||
"""
|
||||
|
||||
code: ErrorCode
|
||||
retryable: bool = False
|
||||
suggestion: str | None = None
|
||||
correlation_id: str | None = None
|
||||
detail: str | None = Field(
|
||||
default=None,
|
||||
description="Free-form additional context. Not surfaced to the model.",
|
||||
)
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
|
||||
class RejectedError(Exception):
|
||||
"""Raised when the user rejects a permission ask without feedback.
|
||||
|
||||
Caught by :class:`PermissionMiddleware`; the agent stops the current
|
||||
tool fan-out and surfaces a user-facing rejection.
|
||||
"""
|
||||
|
||||
def __init__(self, *, tool: str | None = None, pattern: str | None = None) -> None:
|
||||
super().__init__(f"Permission rejected for tool {tool!r}, pattern {pattern!r}")
|
||||
self.tool = tool
|
||||
self.pattern = pattern
|
||||
|
||||
|
||||
class CorrectedError(Exception):
|
||||
"""Raised when the user rejects a permission ask *with* feedback.
|
||||
|
||||
The :class:`PermissionMiddleware` translates the feedback into a
|
||||
synthetic ``ToolMessage`` so the model sees the user's correction
|
||||
and can retry the request differently.
|
||||
"""
|
||||
|
||||
def __init__(self, feedback: str, *, tool: str | None = None) -> None:
|
||||
super().__init__(feedback)
|
||||
self.feedback = feedback
|
||||
self.tool = tool
|
||||
|
||||
|
||||
class BusyError(Exception):
|
||||
"""Raised when a second prompt arrives while the same thread is mid-stream."""
|
||||
|
||||
def __init__(self, request_id: str | None = None) -> None:
|
||||
super().__init__("Thread is busy with another request")
|
||||
self.request_id = request_id
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BusyError",
|
||||
"CorrectedError",
|
||||
"ErrorCode",
|
||||
"RejectedError",
|
||||
"StreamingError",
|
||||
]
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
"""
|
||||
Feature flags for the SurfSense new_chat agent stack.
|
||||
|
||||
These flags gate the newer agent middleware (some ported from OpenCode,
|
||||
some sourced from ``langchain.agents.middleware`` / ``deepagents``, some
|
||||
SurfSense-native). Most shipped agent-stack upgrades default ON so Docker
|
||||
image updates work even when older installs do not have newly introduced
|
||||
environment variables. Risky/experimental integrations stay default OFF,
|
||||
and the master kill-switch can still disable everything new.
|
||||
|
||||
All new middleware checks its flag at agent build time. If the master
|
||||
kill-switch ``SURFSENSE_DISABLE_NEW_AGENT_STACK`` is set, every new
|
||||
middleware is disabled regardless of its individual flag. This gives
|
||||
operators a single switch to revert to pre-port behavior.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Defaults:
|
||||
|
||||
SURFSENSE_ENABLE_CONTEXT_EDITING=true
|
||||
SURFSENSE_ENABLE_COMPACTION_V2=true
|
||||
SURFSENSE_ENABLE_RETRY_AFTER=true
|
||||
SURFSENSE_ENABLE_MODEL_FALLBACK=false
|
||||
SURFSENSE_ENABLE_MODEL_CALL_LIMIT=true
|
||||
SURFSENSE_ENABLE_TOOL_CALL_LIMIT=true
|
||||
SURFSENSE_ENABLE_TOOL_CALL_REPAIR=true
|
||||
SURFSENSE_ENABLE_PERMISSION=true
|
||||
SURFSENSE_ENABLE_DOOM_LOOP=true
|
||||
SURFSENSE_ENABLE_LLM_TOOL_SELECTOR=false # adds a per-turn LLM call
|
||||
|
||||
Master kill-switch (overrides everything else):
|
||||
|
||||
SURFSENSE_DISABLE_NEW_AGENT_STACK=true
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _env_bool(name: str, default: bool) -> bool:
|
||||
"""Parse a boolean env var. Accepts ``1``/``true``/``yes``/``on`` (case-insensitive)."""
|
||||
raw = os.environ.get(name)
|
||||
if raw is None:
|
||||
return default
|
||||
return raw.strip().lower() in ("1", "true", "yes", "on")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentFeatureFlags:
|
||||
"""Resolved feature-flag state for one agent build.
|
||||
|
||||
Constructed via :meth:`from_env`. The dataclass is frozen so it can be
|
||||
safely shared across coroutines.
|
||||
"""
|
||||
|
||||
# Master kill-switch — when true, every flag below resolves to False
|
||||
# regardless of its env value. Used for rapid rollback.
|
||||
disable_new_agent_stack: bool = False
|
||||
|
||||
# Agent quality — context budget, retry/limits, name-repair, doom-loop
|
||||
enable_context_editing: bool = True
|
||||
enable_compaction_v2: bool = True
|
||||
enable_retry_after: bool = True
|
||||
enable_model_fallback: bool = False
|
||||
enable_model_call_limit: bool = True
|
||||
enable_tool_call_limit: bool = True
|
||||
enable_tool_call_repair: bool = True
|
||||
enable_doom_loop: bool = True
|
||||
|
||||
# Safety — permissions, concurrency, tool-set narrowing
|
||||
enable_permission: bool = True
|
||||
enable_busy_mutex: bool = True
|
||||
enable_llm_tool_selector: bool = False # Default OFF — adds per-turn LLM cost
|
||||
|
||||
# Skills + subagents
|
||||
enable_skills: bool = True
|
||||
enable_specialized_subagents: bool = True
|
||||
enable_kb_planner_runnable: bool = True
|
||||
|
||||
# Snapshot / revert
|
||||
enable_action_log: bool = True
|
||||
enable_revert_route: bool = True
|
||||
|
||||
# Plugins
|
||||
enable_plugin_loader: bool = False
|
||||
|
||||
# Observability — OTel (orthogonal; also requires OTEL_EXPORTER_OTLP_ENDPOINT)
|
||||
enable_otel: bool = False
|
||||
|
||||
# Performance — compiled-agent cache (Phase 1 + Phase 2).
|
||||
# When ON, ``create_surfsense_deep_agent`` reuses a previously-compiled
|
||||
# graph if the cache key matches (LLM config + thread + tool surface +
|
||||
# flags + system prompt + filesystem mode). Cuts per-turn agent-build
|
||||
# wall clock from ~4-5s to <50µs on cache hits.
|
||||
#
|
||||
# SAFETY (Phase 2 unblocked this default-on):
|
||||
# All connector mutation tools (``tools/notion``, ``tools/gmail``,
|
||||
# ``tools/google_drive``, ``tools/dropbox``, ``tools/onedrive``,
|
||||
# ``tools/google_calendar``, ``tools/confluence``, ``tools/discord``,
|
||||
# ``tools/teams``, ``tools/luma``, ``connected_accounts``,
|
||||
# ``update_memory``) now acquire fresh
|
||||
# short-lived ``AsyncSession`` instances per call via
|
||||
# :data:`async_session_maker`. The factory still accepts ``db_session``
|
||||
# for registry compatibility but ``del``'s it immediately — see any
|
||||
# of those files' factory docstrings for the rationale. The ``llm``
|
||||
# closure is per-(provider, model, config_id) which is already in
|
||||
# the cache key, so the LLM is safe to share across cached hits of
|
||||
# the same key. The KB priority middleware reads
|
||||
# ``mentioned_document_ids`` from ``runtime.context`` (Phase 1.5),
|
||||
# not its constructor closure, so the same compiled agent serves
|
||||
# turns with different mention lists correctly.
|
||||
#
|
||||
# Rollback: set ``SURFSENSE_ENABLE_AGENT_CACHE=false`` in the
|
||||
# environment if a regression surfaces. The path is exercised by
|
||||
# the ``tests/unit/agents/new_chat/test_agent_cache_*`` suite.
|
||||
enable_agent_cache: bool = True
|
||||
# Phase 1 (deferred — measure first): pre-build & share the
|
||||
# general-purpose subagent ``CompiledSubAgent`` across cold-cache
|
||||
# misses. Only helps when the outer cache MISSES (cache hits already
|
||||
# reuse the entire SubAgentMiddleware-compiled graph). Off by default
|
||||
# until we have data showing cold misses are frequent enough to
|
||||
# justify the extra global state.
|
||||
enable_agent_cache_share_gp_subagent: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> AgentFeatureFlags:
|
||||
"""Read flags from environment.
|
||||
|
||||
Master kill-switch is evaluated first; when set, all other flags
|
||||
force to False.
|
||||
"""
|
||||
master_off = _env_bool("SURFSENSE_DISABLE_NEW_AGENT_STACK", False)
|
||||
if master_off:
|
||||
logger.info(
|
||||
"SURFSENSE_DISABLE_NEW_AGENT_STACK is set: every new agent "
|
||||
"middleware is forced OFF for this build."
|
||||
)
|
||||
return cls(
|
||||
disable_new_agent_stack=True,
|
||||
enable_context_editing=False,
|
||||
enable_compaction_v2=False,
|
||||
enable_retry_after=False,
|
||||
enable_model_fallback=False,
|
||||
enable_model_call_limit=False,
|
||||
enable_tool_call_limit=False,
|
||||
enable_tool_call_repair=False,
|
||||
enable_doom_loop=False,
|
||||
enable_permission=False,
|
||||
enable_busy_mutex=False,
|
||||
enable_llm_tool_selector=False,
|
||||
enable_skills=False,
|
||||
enable_specialized_subagents=False,
|
||||
enable_kb_planner_runnable=False,
|
||||
enable_action_log=False,
|
||||
enable_revert_route=False,
|
||||
enable_plugin_loader=False,
|
||||
enable_otel=False,
|
||||
enable_agent_cache=False,
|
||||
enable_agent_cache_share_gp_subagent=False,
|
||||
)
|
||||
|
||||
return cls(
|
||||
disable_new_agent_stack=False,
|
||||
# Agent quality
|
||||
enable_context_editing=_env_bool("SURFSENSE_ENABLE_CONTEXT_EDITING", True),
|
||||
enable_compaction_v2=_env_bool("SURFSENSE_ENABLE_COMPACTION_V2", True),
|
||||
enable_retry_after=_env_bool("SURFSENSE_ENABLE_RETRY_AFTER", True),
|
||||
enable_model_fallback=_env_bool("SURFSENSE_ENABLE_MODEL_FALLBACK", False),
|
||||
enable_model_call_limit=_env_bool(
|
||||
"SURFSENSE_ENABLE_MODEL_CALL_LIMIT", True
|
||||
),
|
||||
enable_tool_call_limit=_env_bool("SURFSENSE_ENABLE_TOOL_CALL_LIMIT", True),
|
||||
enable_tool_call_repair=_env_bool(
|
||||
"SURFSENSE_ENABLE_TOOL_CALL_REPAIR", True
|
||||
),
|
||||
enable_doom_loop=_env_bool("SURFSENSE_ENABLE_DOOM_LOOP", True),
|
||||
# Safety
|
||||
enable_permission=_env_bool("SURFSENSE_ENABLE_PERMISSION", True),
|
||||
enable_busy_mutex=_env_bool("SURFSENSE_ENABLE_BUSY_MUTEX", True),
|
||||
enable_llm_tool_selector=_env_bool(
|
||||
"SURFSENSE_ENABLE_LLM_TOOL_SELECTOR", False
|
||||
),
|
||||
# Skills + subagents
|
||||
enable_skills=_env_bool("SURFSENSE_ENABLE_SKILLS", True),
|
||||
enable_specialized_subagents=_env_bool(
|
||||
"SURFSENSE_ENABLE_SPECIALIZED_SUBAGENTS", True
|
||||
),
|
||||
enable_kb_planner_runnable=_env_bool(
|
||||
"SURFSENSE_ENABLE_KB_PLANNER_RUNNABLE", True
|
||||
),
|
||||
# Snapshot / revert
|
||||
enable_action_log=_env_bool("SURFSENSE_ENABLE_ACTION_LOG", True),
|
||||
enable_revert_route=_env_bool("SURFSENSE_ENABLE_REVERT_ROUTE", True),
|
||||
# Plugins
|
||||
enable_plugin_loader=_env_bool("SURFSENSE_ENABLE_PLUGIN_LOADER", False),
|
||||
# Observability
|
||||
enable_otel=_env_bool("SURFSENSE_ENABLE_OTEL", False),
|
||||
# Performance
|
||||
enable_agent_cache=_env_bool("SURFSENSE_ENABLE_AGENT_CACHE", True),
|
||||
enable_agent_cache_share_gp_subagent=_env_bool(
|
||||
"SURFSENSE_ENABLE_AGENT_CACHE_SHARE_GP_SUBAGENT", False
|
||||
),
|
||||
)
|
||||
|
||||
def any_new_middleware_enabled(self) -> bool:
|
||||
"""Return True if any new middleware flag is on."""
|
||||
if self.disable_new_agent_stack:
|
||||
return False
|
||||
return any(
|
||||
(
|
||||
self.enable_context_editing,
|
||||
self.enable_compaction_v2,
|
||||
self.enable_retry_after,
|
||||
self.enable_model_fallback,
|
||||
self.enable_model_call_limit,
|
||||
self.enable_tool_call_limit,
|
||||
self.enable_tool_call_repair,
|
||||
self.enable_doom_loop,
|
||||
self.enable_permission,
|
||||
self.enable_busy_mutex,
|
||||
self.enable_llm_tool_selector,
|
||||
self.enable_skills,
|
||||
self.enable_specialized_subagents,
|
||||
self.enable_kb_planner_runnable,
|
||||
self.enable_action_log,
|
||||
self.enable_revert_route,
|
||||
self.enable_plugin_loader,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_flags() -> AgentFeatureFlags:
|
||||
"""Return the resolved feature-flag state from the **current** process environment.
|
||||
|
||||
Intentionally **not** cached: ``load_dotenv`` and operator edits to env vars
|
||||
must affect the next agent build without requiring a full process restart.
|
||||
Cost is negligible (reads ``os.environ`` once per call).
|
||||
"""
|
||||
return AgentFeatureFlags.from_env()
|
||||
|
||||
|
||||
def reload_for_tests() -> AgentFeatureFlags:
|
||||
"""Compatibility helper for tests; equivalent to :func:`get_flags`."""
|
||||
return AgentFeatureFlags.from_env()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentFeatureFlags",
|
||||
"get_flags",
|
||||
"reload_for_tests",
|
||||
]
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
"""Filesystem mode contracts and selection helpers for chat sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class FilesystemMode(StrEnum):
|
||||
"""Supported filesystem backends for agent tool execution."""
|
||||
|
||||
CLOUD = "cloud"
|
||||
DESKTOP_LOCAL_FOLDER = "desktop_local_folder"
|
||||
|
||||
|
||||
class ClientPlatform(StrEnum):
|
||||
"""Client runtime reported by the caller."""
|
||||
|
||||
WEB = "web"
|
||||
DESKTOP = "desktop"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LocalFilesystemMount:
|
||||
"""Canonical mount mapping provided by desktop runtime."""
|
||||
|
||||
mount_id: str
|
||||
root_path: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FilesystemSelection:
|
||||
"""Resolved filesystem selection for a single chat request."""
|
||||
|
||||
mode: FilesystemMode = FilesystemMode.CLOUD
|
||||
client_platform: ClientPlatform = ClientPlatform.WEB
|
||||
local_mounts: tuple[LocalFilesystemMount, ...] = ()
|
||||
|
||||
@property
|
||||
def is_local_mode(self) -> bool:
|
||||
return self.mode == FilesystemMode.DESKTOP_LOCAL_FOLDER
|
||||
|
|
@ -0,0 +1,624 @@
|
|||
"""
|
||||
LLM configuration utilities for SurfSense agents.
|
||||
|
||||
This module provides functions for loading LLM configurations from:
|
||||
1. Auto mode (ID 0) - Uses LiteLLM Router for load balancing
|
||||
2. YAML files (global configs with negative IDs)
|
||||
3. Database NewLLMConfig table (user-created configs with positive IDs)
|
||||
|
||||
It also provides utilities for creating ChatLiteLLM instances and
|
||||
managing prompt configurations.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, ChatResult
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_model_info
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.multi_agent_chat.shared.prompt_caching import (
|
||||
apply_litellm_prompt_caching,
|
||||
)
|
||||
from app.services.llm_router_service import (
|
||||
AUTO_MODE_ID,
|
||||
ChatLiteLLMRouter,
|
||||
LLMRouterService,
|
||||
_sanitize_content,
|
||||
get_auto_mode_llm,
|
||||
is_auto_mode,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_messages(messages: list[BaseMessage]) -> list[BaseMessage]:
|
||||
"""Sanitize content on every message so it is safe for any provider.
|
||||
|
||||
Handles three cross-provider incompatibilities:
|
||||
- List content with provider-specific blocks (e.g. ``thinking``)
|
||||
- List content with bare strings or empty text blocks
|
||||
- AI messages with empty content + tool calls: some providers (Bedrock)
|
||||
convert ``""`` to ``[{"type":"text","text":""}]`` server-side then
|
||||
reject the blank text. The OpenAI spec says ``content`` should be
|
||||
``null`` when an assistant message only carries tool calls.
|
||||
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, list):
|
||||
msg.content = _sanitize_content(msg.content)
|
||||
if (
|
||||
isinstance(msg, AIMessage)
|
||||
and (not msg.content or msg.content == "")
|
||||
and getattr(msg, "tool_calls", None)
|
||||
):
|
||||
msg.content = None # type: ignore[assignment]
|
||||
return messages
|
||||
|
||||
|
||||
class SanitizedChatLiteLLM(ChatLiteLLM):
|
||||
"""ChatLiteLLM subclass that strips provider-specific content blocks
|
||||
(e.g. ``thinking`` from reasoning models) and normalises bare strings
|
||||
in content arrays before forwarding to the underlying provider."""
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
return super()._generate(
|
||||
_sanitize_messages(messages), stop, run_manager, **kwargs
|
||||
)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
async for chunk in super()._astream(
|
||||
_sanitize_messages(messages), stop, run_manager, **kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
# Provider mapping for LiteLLM model string construction.
|
||||
#
|
||||
# Single source of truth lives in
|
||||
# :mod:`app.services.provider_capabilities` so the YAML loader (which
|
||||
# runs during ``app.config`` class-body init) can resolve provider
|
||||
# prefixes without dragging the agent / tools tree into module load
|
||||
# order. Re-exported here under the historical ``PROVIDER_MAP`` name
|
||||
# so existing callers (``llm_router_service``, ``image_gen_router_service``,
|
||||
# tests) keep working unchanged.
|
||||
from app.services.provider_capabilities import ( # noqa: E402
|
||||
_PROVIDER_PREFIX_MAP as PROVIDER_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _attach_model_profile(llm: ChatLiteLLM, model_string: str) -> None:
|
||||
"""Attach a ``profile`` dict to ChatLiteLLM with model context metadata."""
|
||||
try:
|
||||
info = get_model_info(model_string)
|
||||
max_input_tokens = info.get("max_input_tokens")
|
||||
if isinstance(max_input_tokens, int) and max_input_tokens > 0:
|
||||
llm.profile = {
|
||||
"max_input_tokens": max_input_tokens,
|
||||
"max_input_tokens_upper": max_input_tokens,
|
||||
"token_count_model": model_string,
|
||||
"token_count_models": [model_string],
|
||||
}
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""
|
||||
Complete configuration for the SurfSense agent.
|
||||
|
||||
This combines LLM settings with prompt configuration from NewLLMConfig.
|
||||
Supports Auto mode (ID 0) which uses LiteLLM Router for load balancing.
|
||||
"""
|
||||
|
||||
# LLM Model Settings
|
||||
provider: str
|
||||
model_name: str
|
||||
api_key: str
|
||||
api_base: str | None = None
|
||||
custom_provider: str | None = None
|
||||
litellm_params: dict | None = None
|
||||
|
||||
# Prompt Configuration
|
||||
system_instructions: str | None = None
|
||||
use_default_system_instructions: bool = True
|
||||
citations_enabled: bool = True
|
||||
|
||||
# Metadata
|
||||
config_id: int | None = None
|
||||
config_name: str | None = None
|
||||
|
||||
# Auto mode flag
|
||||
is_auto_mode: bool = False
|
||||
|
||||
# Token quota and policy
|
||||
billing_tier: str = "free"
|
||||
is_premium: bool = False
|
||||
anonymous_enabled: bool = False
|
||||
quota_reserve_tokens: int | None = None
|
||||
|
||||
# Capability flag: best-effort True for the chat selector / catalog.
|
||||
# Resolved via :func:`provider_capabilities.derive_supports_image_input`
|
||||
# which prefers OpenRouter's ``architecture.input_modalities`` and
|
||||
# otherwise consults LiteLLM's authoritative model map. Default True
|
||||
# is the conservative-allow stance — the streaming-task safety net
|
||||
# (``is_known_text_only_chat_model``) is the *only* place a False
|
||||
# actually blocks a request. Setting this to False here without an
|
||||
# authoritative source would silently hide vision-capable models
|
||||
# (the regression we're fixing).
|
||||
supports_image_input: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_auto_mode(cls) -> "AgentConfig":
|
||||
"""
|
||||
Create an AgentConfig for Auto mode (LiteLLM Router load balancing).
|
||||
|
||||
Returns:
|
||||
AgentConfig instance configured for Auto mode
|
||||
"""
|
||||
return cls(
|
||||
provider="AUTO",
|
||||
model_name="auto",
|
||||
api_key="", # Not needed for router
|
||||
api_base=None,
|
||||
custom_provider=None,
|
||||
litellm_params=None,
|
||||
system_instructions=None,
|
||||
use_default_system_instructions=True,
|
||||
citations_enabled=True,
|
||||
config_id=AUTO_MODE_ID,
|
||||
config_name="Auto (Fastest)",
|
||||
is_auto_mode=True,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# Auto routes across the configured pool, which usually
|
||||
# contains at least one vision-capable deployment; the router
|
||||
# will surface a 404 from a non-vision deployment as a normal
|
||||
# ``allowed_fails`` event and fail over rather than blocking
|
||||
# the request outright.
|
||||
supports_image_input=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_new_llm_config(cls, config) -> "AgentConfig":
|
||||
"""
|
||||
Create an AgentConfig from a NewLLMConfig database model.
|
||||
|
||||
Args:
|
||||
config: NewLLMConfig database model instance
|
||||
|
||||
Returns:
|
||||
AgentConfig instance
|
||||
"""
|
||||
# Lazy import to avoid pulling provider_capabilities (and its
|
||||
# transitive litellm import) into module-init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
provider_value = (
|
||||
config.provider.value
|
||||
if hasattr(config.provider, "value")
|
||||
else str(config.provider)
|
||||
)
|
||||
litellm_params = config.litellm_params or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
custom_provider=config.custom_provider,
|
||||
litellm_params=config.litellm_params,
|
||||
system_instructions=config.system_instructions,
|
||||
use_default_system_instructions=config.use_default_system_instructions,
|
||||
citations_enabled=config.citations_enabled,
|
||||
config_id=config.id,
|
||||
config_name=config.name,
|
||||
is_auto_mode=False,
|
||||
billing_tier="free",
|
||||
is_premium=False,
|
||||
anonymous_enabled=False,
|
||||
quota_reserve_tokens=None,
|
||||
# BYOK rows have no operator-curated capability flag, so we
|
||||
# ask LiteLLM (default-allow on unknown). The streaming
|
||||
# safety net still blocks if the model is *explicitly*
|
||||
# marked text-only.
|
||||
supports_image_input=derive_supports_image_input(
|
||||
provider=provider_value,
|
||||
model_name=config.model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=config.custom_provider,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
|
||||
"""
|
||||
Create an AgentConfig from a YAML configuration dictionary.
|
||||
|
||||
YAML configs now support the same prompt configuration fields as NewLLMConfig:
|
||||
- system_instructions: Custom system instructions (empty string uses defaults)
|
||||
- use_default_system_instructions: Whether to use default instructions
|
||||
- citations_enabled: Whether citations are enabled
|
||||
|
||||
Args:
|
||||
yaml_config: Configuration dictionary from YAML file
|
||||
|
||||
Returns:
|
||||
AgentConfig instance
|
||||
"""
|
||||
# Lazy import to avoid pulling provider_capabilities (and its
|
||||
# transitive litellm import) into module-init order.
|
||||
from app.services.provider_capabilities import derive_supports_image_input
|
||||
|
||||
# Get system instructions from YAML, default to empty string
|
||||
system_instructions = yaml_config.get("system_instructions", "")
|
||||
|
||||
provider = yaml_config.get("provider", "").upper()
|
||||
model_name = yaml_config.get("model_name", "")
|
||||
custom_provider = yaml_config.get("custom_provider")
|
||||
litellm_params = yaml_config.get("litellm_params") or {}
|
||||
base_model = (
|
||||
litellm_params.get("base_model")
|
||||
if isinstance(litellm_params, dict)
|
||||
else None
|
||||
)
|
||||
|
||||
# Explicit YAML override wins; otherwise derive from LiteLLM /
|
||||
# OpenRouter modalities. The YAML loader already populates this
|
||||
# field, but this method is also called from
|
||||
# ``load_global_llm_config_by_id``'s file fallback (hot reload),
|
||||
# so we re-derive here for safety. The bool() coercion preserves
|
||||
# the loader's behaviour for explicit ``true`` / ``false``
|
||||
# strings that PyYAML may surface.
|
||||
if "supports_image_input" in yaml_config:
|
||||
supports_image_input = bool(yaml_config.get("supports_image_input"))
|
||||
else:
|
||||
supports_image_input = derive_supports_image_input(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
custom_provider=custom_provider,
|
||||
)
|
||||
|
||||
return cls(
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
api_key=yaml_config.get("api_key", ""),
|
||||
api_base=yaml_config.get("api_base"),
|
||||
custom_provider=custom_provider,
|
||||
litellm_params=yaml_config.get("litellm_params"),
|
||||
# Prompt configuration from YAML (with defaults for backwards compatibility)
|
||||
system_instructions=system_instructions if system_instructions else None,
|
||||
use_default_system_instructions=yaml_config.get(
|
||||
"use_default_system_instructions", True
|
||||
),
|
||||
citations_enabled=yaml_config.get("citations_enabled", True),
|
||||
config_id=yaml_config.get("id"),
|
||||
config_name=yaml_config.get("name"),
|
||||
is_auto_mode=False,
|
||||
billing_tier=yaml_config.get("billing_tier", "free"),
|
||||
is_premium=yaml_config.get("billing_tier", "free") == "premium",
|
||||
anonymous_enabled=yaml_config.get("anonymous_enabled", False),
|
||||
quota_reserve_tokens=yaml_config.get("quota_reserve_tokens"),
|
||||
supports_image_input=supports_image_input,
|
||||
)
|
||||
|
||||
|
||||
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
|
||||
"""
|
||||
Load a specific LLM config from global_llm_config.yaml.
|
||||
|
||||
Args:
|
||||
llm_config_id: The id of the config to load (default: -1)
|
||||
|
||||
Returns:
|
||||
LLM config dict or None if not found
|
||||
"""
|
||||
# Get the config file path
|
||||
base_dir = Path(__file__).resolve().parent.parent.parent.parent
|
||||
config_file = base_dir / "app" / "config" / "global_llm_config.yaml"
|
||||
|
||||
# Fallback to example file if main config doesn't exist
|
||||
if not config_file.exists():
|
||||
config_file = base_dir / "app" / "config" / "global_llm_config.example.yaml"
|
||||
if not config_file.exists():
|
||||
print("Error: No global_llm_config.yaml or example file found")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
configs = data.get("global_llm_configs", [])
|
||||
for cfg in configs:
|
||||
if isinstance(cfg, dict) and cfg.get("id") == llm_config_id:
|
||||
return cfg
|
||||
|
||||
print(f"Error: Global LLM config id {llm_config_id} not found")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error loading config: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def load_global_llm_config_by_id(llm_config_id: int) -> dict | None:
|
||||
"""
|
||||
Load a global LLM config by ID, checking in-memory configs first.
|
||||
|
||||
This handles both static YAML configs and dynamically injected configs
|
||||
(e.g. OpenRouter integration models that only exist in memory).
|
||||
|
||||
Args:
|
||||
llm_config_id: The negative ID of the global config to load
|
||||
|
||||
Returns:
|
||||
LLM config dict or None if not found
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == llm_config_id:
|
||||
return cfg
|
||||
# Fallback to YAML file read (covers edge cases like hot-reload)
|
||||
return load_llm_config_from_yaml(llm_config_id)
|
||||
|
||||
|
||||
async def load_new_llm_config_from_db(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""
|
||||
Load a NewLLMConfig from the database by ID.
|
||||
|
||||
Args:
|
||||
session: AsyncSession for database access
|
||||
config_id: The ID of the NewLLMConfig to load
|
||||
|
||||
Returns:
|
||||
AgentConfig instance or None if not found
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from app.db import NewLLMConfig
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
|
||||
)
|
||||
config = result.scalars().first()
|
||||
|
||||
if not config:
|
||||
print(f"Error: NewLLMConfig with id {config_id} not found")
|
||||
return None
|
||||
|
||||
return AgentConfig.from_new_llm_config(config)
|
||||
except Exception as e:
|
||||
print(f"Error loading NewLLMConfig from database: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_llm_config_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> "AgentConfig | None":
|
||||
"""
|
||||
Load the agent LLM configuration for a search space.
|
||||
|
||||
This loads the LLM config based on the search space's agent_llm_id setting:
|
||||
- Positive ID: Load from NewLLMConfig database table
|
||||
- Negative ID: Load from YAML global configs
|
||||
- None: Falls back to first global config (id=-1)
|
||||
|
||||
Args:
|
||||
session: AsyncSession for database access
|
||||
search_space_id: The search space ID
|
||||
|
||||
Returns:
|
||||
AgentConfig instance or None if not found
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from app.db import SearchSpace
|
||||
|
||||
try:
|
||||
# Get the search space to check its agent_llm_id preference
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == search_space_id)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
|
||||
if not search_space:
|
||||
print(f"Error: SearchSpace with id {search_space_id} not found")
|
||||
return None
|
||||
|
||||
# Use agent_llm_id from search space, fallback to -1 (first global config)
|
||||
config_id = (
|
||||
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
|
||||
)
|
||||
|
||||
# Load the config using the unified loader
|
||||
return await load_agent_config(session, config_id, search_space_id)
|
||||
except Exception as e:
|
||||
print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def load_agent_config(
|
||||
session: AsyncSession,
|
||||
config_id: int,
|
||||
search_space_id: int | None = None,
|
||||
) -> "AgentConfig | None":
|
||||
"""
|
||||
Load an agent configuration, supporting Auto mode, YAML, and database configs.
|
||||
|
||||
This is the main entry point for loading configurations:
|
||||
- ID 0: Auto mode (uses LiteLLM Router for load balancing)
|
||||
- Negative IDs: Load from YAML file (global configs)
|
||||
- Positive IDs: Load from NewLLMConfig database table
|
||||
|
||||
Args:
|
||||
session: AsyncSession for database access
|
||||
config_id: The config ID (0 for Auto, negative for YAML, positive for database)
|
||||
search_space_id: Optional search space ID for context
|
||||
|
||||
Returns:
|
||||
AgentConfig instance or None if not found
|
||||
"""
|
||||
# Auto mode (ID 0) - use LiteLLM Router
|
||||
if is_auto_mode(config_id):
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
return AgentConfig.from_auto_mode()
|
||||
|
||||
if config_id < 0:
|
||||
# Check in-memory configs first (includes static YAML + dynamic OpenRouter)
|
||||
from app.config import config as app_config
|
||||
|
||||
for cfg in app_config.GLOBAL_LLM_CONFIGS:
|
||||
if cfg.get("id") == config_id:
|
||||
return AgentConfig.from_yaml_config(cfg)
|
||||
# Fallback to YAML file read for safety
|
||||
yaml_config = load_llm_config_from_yaml(config_id)
|
||||
if yaml_config:
|
||||
return AgentConfig.from_yaml_config(yaml_config)
|
||||
return None
|
||||
else:
|
||||
# Load from database (NewLLMConfig)
|
||||
return await load_new_llm_config_from_db(session, config_id)
|
||||
|
||||
|
||||
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
|
||||
"""
|
||||
Create a ChatLiteLLM instance from a global LLM config dictionary.
|
||||
|
||||
Args:
|
||||
llm_config: LLM configuration dictionary from YAML
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM instance or None on error
|
||||
"""
|
||||
# Build the model string
|
||||
if llm_config.get("custom_provider"):
|
||||
model_string = f"{llm_config['custom_provider']}/{llm_config['model_name']}"
|
||||
else:
|
||||
provider = llm_config.get("provider", "").upper()
|
||||
provider_prefix = PROVIDER_MAP.get(provider, provider.lower())
|
||||
model_string = f"{provider_prefix}/{llm_config['model_name']}"
|
||||
|
||||
# Create ChatLiteLLM instance with streaming enabled
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": llm_config.get("api_key"),
|
||||
"streaming": True, # Enable streaming for real-time token streaming
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if llm_config.get("api_base"):
|
||||
litellm_kwargs["api_base"] = llm_config["api_base"]
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if llm_config.get("litellm_params"):
|
||||
litellm_kwargs.update(llm_config["litellm_params"])
|
||||
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
_attach_model_profile(llm, model_string)
|
||||
# Configure LiteLLM-native prompt caching (cache_control_injection_points
|
||||
# for Anthropic/Bedrock/Vertex/Gemini/Azure-AI/OpenRouter/Databricks/etc.).
|
||||
# ``agent_config=None`` here — the YAML path doesn't have provider intent
|
||||
# in a structured form, so we set only the universal injection points.
|
||||
apply_litellm_prompt_caching(llm)
|
||||
return llm
|
||||
|
||||
|
||||
def create_chat_litellm_from_agent_config(
|
||||
agent_config: AgentConfig,
|
||||
) -> ChatLiteLLM | ChatLiteLLMRouter | None:
|
||||
"""
|
||||
Create a ChatLiteLLM or ChatLiteLLMRouter instance from an AgentConfig.
|
||||
|
||||
For Auto mode configs, returns a ChatLiteLLMRouter that uses LiteLLM Router
|
||||
for automatic load balancing across available providers.
|
||||
|
||||
Args:
|
||||
agent_config: AgentConfig instance
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM or ChatLiteLLMRouter instance, or None on error
|
||||
"""
|
||||
# Handle Auto mode - return ChatLiteLLMRouter
|
||||
if agent_config.is_auto_mode:
|
||||
if not LLMRouterService.is_initialized():
|
||||
print("Error: Auto mode requested but LLM Router not initialized")
|
||||
return None
|
||||
try:
|
||||
router_llm = get_auto_mode_llm()
|
||||
if router_llm is not None:
|
||||
# Universal cache_control_injection_points only — auto-mode
|
||||
# fans out across providers, so OpenAI-only kwargs (e.g.
|
||||
# ``prompt_cache_key``) are left off here. ``drop_params``
|
||||
# would strip them at the provider boundary anyway, but
|
||||
# there's no point setting them when we don't know the
|
||||
# destination.
|
||||
apply_litellm_prompt_caching(router_llm, agent_config=agent_config)
|
||||
return router_llm
|
||||
except Exception as e:
|
||||
print(f"Error creating ChatLiteLLMRouter: {e}")
|
||||
return None
|
||||
|
||||
# Build the model string
|
||||
if agent_config.custom_provider:
|
||||
model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
|
||||
else:
|
||||
provider_prefix = PROVIDER_MAP.get(
|
||||
agent_config.provider, agent_config.provider.lower()
|
||||
)
|
||||
model_string = f"{provider_prefix}/{agent_config.model_name}"
|
||||
|
||||
# Create ChatLiteLLM instance with streaming enabled
|
||||
litellm_kwargs = {
|
||||
"model": model_string,
|
||||
"api_key": agent_config.api_key,
|
||||
"streaming": True, # Enable streaming for real-time token streaming
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if agent_config.api_base:
|
||||
litellm_kwargs["api_base"] = agent_config.api_base
|
||||
|
||||
# Add any additional litellm parameters
|
||||
if agent_config.litellm_params:
|
||||
litellm_kwargs.update(agent_config.litellm_params)
|
||||
|
||||
llm = SanitizedChatLiteLLM(**litellm_kwargs)
|
||||
_attach_model_profile(llm, model_string)
|
||||
# Build-time prompt caching: sets ``cache_control_injection_points`` for
|
||||
# all providers and (for OpenAI/DeepSeek/xAI) ``prompt_cache_retention``.
|
||||
# Per-thread ``prompt_cache_key`` is layered on later in
|
||||
# ``create_surfsense_deep_agent`` once ``thread_id`` is known.
|
||||
apply_litellm_prompt_caching(llm, agent_config=agent_config)
|
||||
return llm
|
||||
|
|
@ -0,0 +1,277 @@
|
|||
"""Resolve @-mention chips to canonical virtual paths and substitute the
|
||||
user-visible ``@title`` tokens with backtick-wrapped paths in the prompt
|
||||
the agent sees.
|
||||
|
||||
The frontend's mention seam is a single discriminated-union list of
|
||||
``{kind: "doc" | "folder", id, title, document_type?}`` chips (see
|
||||
``surfsense_web/atoms/chat/mentioned-documents.atom.ts``). When a turn
|
||||
reaches the backend stream task we have three needs that this module
|
||||
centralises:
|
||||
|
||||
1. Map each chip to its canonical virtual path
|
||||
(``/documents/.../file.xml`` for docs, ``/documents/MyFolder/`` for
|
||||
folders) so the agent sees concrete filesystem locations instead of
|
||||
ambiguous ``@``-titles.
|
||||
2. Substitute ``@title`` tokens in the user-typed text with backtick-
|
||||
wrapped paths so the path becomes part of the ``HumanMessage`` body
|
||||
the LLM consumes — without rewriting the persisted user message
|
||||
text (which keeps ``@title`` so chip rendering on reload is
|
||||
unchanged).
|
||||
3. Surface the resolved id sets (docs + folders) to the priority
|
||||
middleware so it can render ``[USER-MENTIONED]`` priority entries
|
||||
without re-doing path resolution.
|
||||
|
||||
This is intentionally one module — see the architectural note in
|
||||
``mention-paths-and-folders`` plan: previously the doc-resolution lived
|
||||
inline in ``stream_new_chat`` and the folder mention had no resolution
|
||||
at all. Centralising both behind a single ``resolve_mentions`` call
|
||||
turns a leaky multi-field seam into a single deeper interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.db import Document, Folder
|
||||
from app.schemas.new_chat import MentionedDocumentInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedMention:
|
||||
"""Canonical view of a single @-mention chip.
|
||||
|
||||
``virtual_path`` is the path the agent will see (no trailing slash
|
||||
for documents, trailing ``/`` for folders to match the convention
|
||||
used by ``KnowledgeTreeMiddleware``).
|
||||
"""
|
||||
|
||||
kind: str # "doc" | "folder"
|
||||
id: int
|
||||
title: str
|
||||
virtual_path: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedMentionSet:
|
||||
"""Aggregate result of resolving a turn's mention chips.
|
||||
|
||||
``token_to_path`` maps ``@title`` (the literal token the user typed
|
||||
and the editor emitted) to the canonical virtual path for that
|
||||
chip. It is produced longest-token-first so substitution mirrors
|
||||
``parseMentionSegments`` on the frontend (a longer title like
|
||||
``@Project Roadmap`` is never shadowed by a shorter prefix
|
||||
``@Project``).
|
||||
|
||||
``mentioned_document_ids`` is an ordered, deduped list consumed by
|
||||
the priority middleware downstream — see
|
||||
``KnowledgePriorityMiddleware._compute_priority_paths``.
|
||||
"""
|
||||
|
||||
mentions: list[ResolvedMention] = field(default_factory=list)
|
||||
token_to_path: list[tuple[str, str]] = field(default_factory=list)
|
||||
mentioned_document_ids: list[int] = field(default_factory=list)
|
||||
mentioned_folder_ids: list[int] = field(default_factory=list)
|
||||
|
||||
|
||||
def _folder_virtual_path(folder_id: int, folder_paths: dict[int, str]) -> str:
|
||||
"""Return ``/documents/Folder/Sub/`` for a folder id.
|
||||
|
||||
Falls back to the documents root when the folder is missing from
|
||||
the index (deleted or in a different search space). Trailing slash
|
||||
matches ``KnowledgeTreeMiddleware`` (``/documents/MyFolder/``) so
|
||||
the agent's ``ls`` can dispatch on it as a directory.
|
||||
"""
|
||||
base = folder_paths.get(folder_id, DOCUMENTS_ROOT)
|
||||
return f"{base}/" if not base.endswith("/") else base
|
||||
|
||||
|
||||
async def resolve_mentions(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
mentioned_documents: list[MentionedDocumentInfo] | None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
mentioned_folder_ids: list[int] | None = None,
|
||||
) -> ResolvedMentionSet:
|
||||
"""Resolve every @-mention chip on a turn into virtual paths.
|
||||
|
||||
The function takes both the ``mentioned_documents`` discriminated
|
||||
list (chip metadata used for substitution + persistence) and the
|
||||
parallel id arrays (``mentioned_document_ids``,
|
||||
``mentioned_folder_ids``) for two reasons:
|
||||
|
||||
* Legacy clients that haven't migrated to the unified chip list
|
||||
still send the id arrays — we treat the union as authoritative.
|
||||
* The id arrays are the canonical input to
|
||||
``KnowledgePriorityMiddleware`` (via ``SurfSenseContextSchema``);
|
||||
returning the deduped, validated lists lets the route forward
|
||||
them unchanged.
|
||||
|
||||
Resolution is best-effort: a chip whose id no longer exists (e.g.
|
||||
document was deleted between mention and submit) is silently
|
||||
dropped. The agent still sees the user's original text, just
|
||||
without a backtick-path substitution for that chip.
|
||||
"""
|
||||
chip_doc_ids: list[int] = []
|
||||
chip_folder_ids: list[int] = []
|
||||
chip_titles_by_id: dict[tuple[str, int], str] = {}
|
||||
if mentioned_documents:
|
||||
for chip in mentioned_documents:
|
||||
kind = chip.kind
|
||||
if kind == "folder":
|
||||
chip_folder_ids.append(chip.id)
|
||||
elif kind == "doc":
|
||||
chip_doc_ids.append(chip.id)
|
||||
chip_titles_by_id[(kind, chip.id)] = chip.title
|
||||
|
||||
doc_id_pool: list[int] = list(
|
||||
dict.fromkeys(
|
||||
[
|
||||
*(mentioned_document_ids or []),
|
||||
*chip_doc_ids,
|
||||
]
|
||||
)
|
||||
)
|
||||
folder_id_pool: list[int] = list(
|
||||
dict.fromkeys([*(mentioned_folder_ids or []), *chip_folder_ids])
|
||||
)
|
||||
|
||||
if not doc_id_pool and not folder_id_pool:
|
||||
return ResolvedMentionSet()
|
||||
|
||||
index = await build_path_index(session, search_space_id)
|
||||
|
||||
doc_rows: dict[int, Document] = {}
|
||||
if doc_id_pool:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.id.in_(doc_id_pool),
|
||||
)
|
||||
)
|
||||
for row in result.scalars().all():
|
||||
doc_rows[row.id] = row
|
||||
|
||||
folder_rows: dict[int, Folder] = {}
|
||||
if folder_id_pool:
|
||||
result = await session.execute(
|
||||
select(Folder).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.id.in_(folder_id_pool),
|
||||
)
|
||||
)
|
||||
for row in result.scalars().all():
|
||||
folder_rows[row.id] = row
|
||||
|
||||
resolved: list[ResolvedMention] = []
|
||||
accepted_doc_ids: list[int] = []
|
||||
accepted_folder_ids: list[int] = []
|
||||
|
||||
for doc_id in doc_id_pool:
|
||||
row = doc_rows.get(doc_id)
|
||||
if row is None:
|
||||
logger.debug(
|
||||
"mention_resolver: dropping doc id=%s (not found in space=%s)",
|
||||
doc_id,
|
||||
search_space_id,
|
||||
)
|
||||
continue
|
||||
title = chip_titles_by_id.get(("doc", doc_id), str(row.title or ""))
|
||||
path = doc_to_virtual_path(
|
||||
doc_id=row.id,
|
||||
title=str(row.title or "untitled"),
|
||||
folder_id=row.folder_id,
|
||||
index=index,
|
||||
)
|
||||
resolved.append(
|
||||
ResolvedMention(kind="doc", id=row.id, title=title, virtual_path=path)
|
||||
)
|
||||
accepted_doc_ids.append(row.id)
|
||||
|
||||
for folder_id in folder_id_pool:
|
||||
row = folder_rows.get(folder_id)
|
||||
if row is None:
|
||||
logger.debug(
|
||||
"mention_resolver: dropping folder id=%s (not found in space=%s)",
|
||||
folder_id,
|
||||
search_space_id,
|
||||
)
|
||||
continue
|
||||
title = chip_titles_by_id.get(("folder", folder_id), str(row.name or ""))
|
||||
path = _folder_virtual_path(row.id, index.folder_paths)
|
||||
resolved.append(
|
||||
ResolvedMention(kind="folder", id=row.id, title=title, virtual_path=path)
|
||||
)
|
||||
accepted_folder_ids.append(row.id)
|
||||
|
||||
token_to_path: list[tuple[str, str]] = []
|
||||
seen_tokens: set[str] = set()
|
||||
for mention in resolved:
|
||||
if not mention.title:
|
||||
continue
|
||||
token = f"@{mention.title}"
|
||||
if token in seen_tokens:
|
||||
continue
|
||||
seen_tokens.add(token)
|
||||
token_to_path.append((token, mention.virtual_path))
|
||||
token_to_path.sort(key=lambda pair: len(pair[0]), reverse=True)
|
||||
|
||||
return ResolvedMentionSet(
|
||||
mentions=resolved,
|
||||
token_to_path=token_to_path,
|
||||
mentioned_document_ids=accepted_doc_ids,
|
||||
mentioned_folder_ids=accepted_folder_ids,
|
||||
)
|
||||
|
||||
|
||||
def substitute_in_text(text: str, token_to_path: list[tuple[str, str]]) -> str:
|
||||
"""Replace each ``@title`` token with a backtick-wrapped virtual path.
|
||||
|
||||
Mirrors ``parseMentionSegments`` on the frontend: longest token
|
||||
first, single forward pass, no regex (titles can contain regex
|
||||
metacharacters). The substitution is idempotent for already-
|
||||
substituted text because the backtick-wrapped path no longer
|
||||
starts with ``@``.
|
||||
|
||||
Empty / no-op cases short-circuit so callers can pass this through
|
||||
unconditionally without paying for a scan.
|
||||
"""
|
||||
if not text or not token_to_path:
|
||||
return text
|
||||
|
||||
out: list[str] = []
|
||||
i = 0
|
||||
n = len(text)
|
||||
while i < n:
|
||||
matched: tuple[str, str] | None = None
|
||||
for token, path in token_to_path:
|
||||
if text.startswith(token, i):
|
||||
matched = (token, path)
|
||||
break
|
||||
if matched is None:
|
||||
out.append(text[i])
|
||||
i += 1
|
||||
continue
|
||||
token, path = matched
|
||||
out.append(f"`{path}`")
|
||||
i += len(token)
|
||||
return "".join(out)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ResolvedMention",
|
||||
"ResolvedMentionSet",
|
||||
"resolve_mentions",
|
||||
"substitute_in_text",
|
||||
]
|
||||
|
|
@ -0,0 +1,328 @@
|
|||
"""
|
||||
BusyMutexMiddleware — per-thread asyncio lock + cancel token.
|
||||
|
||||
LangChain has no built-in concept of "this thread is already running a
|
||||
turn — refuse the second concurrent request". Without it, a user
|
||||
double-clicking "send" or refreshing the page mid-stream can spawn two
|
||||
turns racing on the same checkpoint, producing duplicated tool calls
|
||||
and mangled state.
|
||||
|
||||
Ported from OpenCode's ``Stream.scoped(AbortController)`` pattern: a
|
||||
single-process, in-memory lock + cooperative cancellation token keyed by
|
||||
``thread_id``. For multi-worker deployments a distributed lock backend
|
||||
(Redis or PostgreSQL advisory locks) is a phase-2 follow-up.
|
||||
|
||||
What this provides:
|
||||
- A ``WeakValueDictionary[str, asyncio.Lock]`` keyed by ``thread_id``;
|
||||
acquiring the lock during ``before_agent`` blocks any concurrent
|
||||
prompt on the same thread until release.
|
||||
- A per-thread ``asyncio.Event`` (``cancel_event``) that long-running
|
||||
tools can poll to abort cooperatively. The event is reset between
|
||||
turns. Tools should check ``runtime.context.cancel_event.is_set()``
|
||||
in tight inner loops.
|
||||
- A typed :class:`~app.agents.multi_agent_chat.shared.errors.BusyError` raised when a
|
||||
second turn arrives while the lock is held.
|
||||
|
||||
Note: SurfSense's ``stream_new_chat`` is the call site that should
|
||||
acquire/release. Wiring this as middleware means the contract is
|
||||
explicit and the lock manager is shared with subagents that compile
|
||||
their own ``create_agent`` runnables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ResponseT,
|
||||
)
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.multi_agent_chat.shared.errors import BusyError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _ThreadLockManager:
|
||||
"""Process-local registry of per-thread asyncio locks + cancel events."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
|
||||
weakref.WeakValueDictionary()
|
||||
)
|
||||
self._cancel_events: dict[str, asyncio.Event] = {}
|
||||
self._cancel_requested_at_ms: dict[str, int] = {}
|
||||
self._cancel_attempt_count: dict[str, int] = {}
|
||||
# Monotonic per-thread epoch used to prevent stale middleware
|
||||
# teardown from releasing a newer turn's lock.
|
||||
self._turn_epoch: dict[str, int] = {}
|
||||
|
||||
def lock_for(self, thread_id: str) -> asyncio.Lock:
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._locks[thread_id] = lock
|
||||
return lock
|
||||
|
||||
def cancel_event(self, thread_id: str) -> asyncio.Event:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is None:
|
||||
event = asyncio.Event()
|
||||
self._cancel_events[thread_id] = event
|
||||
return event
|
||||
|
||||
def request_cancel(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is None:
|
||||
event = asyncio.Event()
|
||||
self._cancel_events[thread_id] = event
|
||||
event.set()
|
||||
now_ms = int(time.time() * 1000)
|
||||
self._cancel_requested_at_ms[thread_id] = now_ms
|
||||
self._cancel_attempt_count[thread_id] = (
|
||||
self._cancel_attempt_count.get(thread_id, 0) + 1
|
||||
)
|
||||
return True
|
||||
|
||||
def is_cancel_requested(self, thread_id: str) -> bool:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
return bool(event and event.is_set())
|
||||
|
||||
def cancel_state(self, thread_id: str) -> tuple[int, int] | None:
|
||||
if not self.is_cancel_requested(thread_id):
|
||||
return None
|
||||
attempts = self._cancel_attempt_count.get(thread_id, 1)
|
||||
requested_at_ms = self._cancel_requested_at_ms.get(thread_id, 0)
|
||||
return attempts, requested_at_ms
|
||||
|
||||
def reset(self, thread_id: str) -> None:
|
||||
event = self._cancel_events.get(thread_id)
|
||||
if event is not None:
|
||||
event.clear()
|
||||
self._cancel_requested_at_ms.pop(thread_id, None)
|
||||
self._cancel_attempt_count.pop(thread_id, None)
|
||||
|
||||
def bump_turn_epoch(self, thread_id: str) -> int:
|
||||
epoch = self._turn_epoch.get(thread_id, 0) + 1
|
||||
self._turn_epoch[thread_id] = epoch
|
||||
return epoch
|
||||
|
||||
def current_turn_epoch(self, thread_id: str) -> int:
|
||||
return self._turn_epoch.get(thread_id, 0)
|
||||
|
||||
def end_turn(self, thread_id: str) -> None:
|
||||
"""Best-effort terminal cleanup for a thread turn.
|
||||
|
||||
This is intentionally idempotent and safe to call from outer stream
|
||||
finally-blocks where middleware teardown might be skipped due to abort
|
||||
or disconnect edge-cases.
|
||||
"""
|
||||
# Invalidate any in-flight middleware holder first. This guarantees a
|
||||
# stale ``aafter_agent`` from an older attempt cannot unlock a newer
|
||||
# retry that already acquired the lock for the same thread.
|
||||
self.bump_turn_epoch(thread_id)
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is not None and lock.locked():
|
||||
lock.release()
|
||||
self.reset(thread_id)
|
||||
|
||||
def release(self, thread_id: str) -> bool:
|
||||
"""Force-release the per-thread lock; safety-net for turns that end before ``__end__``.
|
||||
|
||||
``BusyMutexMiddleware.aafter_agent`` only releases on graph completion, so
|
||||
an ``interrupt()`` pause or an early streaming bail-out would otherwise
|
||||
leak the lock and block the next request with :class:`BusyError`. Returns
|
||||
``True`` when a held lock was released, ``False`` otherwise.
|
||||
"""
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is None or not lock.locked():
|
||||
return False
|
||||
try:
|
||||
lock.release()
|
||||
except RuntimeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# Module-level singleton — process-local but reused across all agent
|
||||
# instances built in this process. Subagents created in nested
|
||||
# ``create_agent`` calls also get this so locks are coherent.
|
||||
manager = _ThreadLockManager()
|
||||
|
||||
|
||||
def get_cancel_event(thread_id: str) -> asyncio.Event:
|
||||
"""Public accessor used by long-running tools to poll cancellation."""
|
||||
return manager.cancel_event(thread_id)
|
||||
|
||||
|
||||
def request_cancel(thread_id: str) -> bool:
|
||||
"""Trip the cancel event for ``thread_id``. Always returns True."""
|
||||
return manager.request_cancel(thread_id)
|
||||
|
||||
|
||||
def is_cancel_requested(thread_id: str) -> bool:
|
||||
"""Return whether ``thread_id`` currently has a pending cancel signal."""
|
||||
return manager.is_cancel_requested(thread_id)
|
||||
|
||||
|
||||
def get_cancel_state(thread_id: str) -> tuple[int, int] | None:
|
||||
"""Return ``(attempt_count, requested_at_ms)`` for pending cancel state."""
|
||||
return manager.cancel_state(thread_id)
|
||||
|
||||
|
||||
def reset_cancel(thread_id: str) -> None:
|
||||
"""Reset the cancel event for ``thread_id`` (called between turns)."""
|
||||
manager.reset(thread_id)
|
||||
|
||||
|
||||
def end_turn(thread_id: str) -> None:
|
||||
"""Force end-of-turn cleanup for lock + cancel state."""
|
||||
manager.end_turn(thread_id)
|
||||
|
||||
|
||||
class BusyMutexMiddleware(AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]):
|
||||
"""Block concurrent prompts on the same thread.
|
||||
|
||||
Acquires the thread's lock in ``abefore_agent`` and releases in
|
||||
``aafter_agent``. If the lock is held, raises :class:`BusyError`
|
||||
so the caller can emit a ``surfsense.busy`` SSE event with the
|
||||
in-flight request id.
|
||||
|
||||
Args:
|
||||
require_thread_id: When True, raise :class:`BusyError` if no
|
||||
``thread_id`` can be resolved from the active
|
||||
``RunnableConfig``. Default is False — we treat a missing
|
||||
thread_id as "this turn has nothing to lock against" and
|
||||
no-op the mutex. Set True only when you trust the call
|
||||
site to always provide ``configurable.thread_id`` (e.g.
|
||||
in production where ``stream_new_chat`` always does).
|
||||
"""
|
||||
|
||||
def __init__(self, *, require_thread_id: bool = False) -> None:
|
||||
super().__init__()
|
||||
self._require_thread_id = require_thread_id
|
||||
self.tools = []
|
||||
# Per-call lock ownership tracked as (lock, epoch). ``aafter_agent``
|
||||
# only releases when its epoch still matches the manager's current
|
||||
# epoch for the thread, preventing stale unlock races.
|
||||
self._held_locks: dict[str, tuple[asyncio.Lock, int]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _thread_id(runtime: Runtime[ContextT]) -> str | None:
|
||||
"""Extract ``thread_id`` from the active LangGraph ``RunnableConfig``.
|
||||
|
||||
``langgraph.runtime.Runtime`` deliberately does NOT expose ``config``.
|
||||
The runnable config (where ``configurable.thread_id`` lives) must be
|
||||
fetched via :func:`langgraph.config.get_config` from inside a node /
|
||||
middleware. We fall back to ``getattr(runtime, "config", None)`` for
|
||||
unit tests / legacy runtimes that synthesize a config-bearing stub.
|
||||
"""
|
||||
|
||||
def _from_dict(cfg: Any) -> str | None:
|
||||
if not isinstance(cfg, dict):
|
||||
return None
|
||||
tid = (cfg.get("configurable") or {}).get("thread_id")
|
||||
return str(tid) if tid is not None else None
|
||||
|
||||
# Preferred path: real LangGraph runtime context.
|
||||
try:
|
||||
tid = _from_dict(get_config())
|
||||
except Exception:
|
||||
tid = None
|
||||
if tid is not None:
|
||||
return tid
|
||||
|
||||
# Fallback for tests and any runtime that surfaces a config dict
|
||||
# directly on the runtime instance.
|
||||
return _from_dict(getattr(runtime, "config", None))
|
||||
|
||||
async def abefore_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[Any],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
del state
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
if self._require_thread_id:
|
||||
raise BusyError("no thread_id configured")
|
||||
logger.debug(
|
||||
"BusyMutexMiddleware: no thread_id resolved from RunnableConfig; "
|
||||
"skipping per-thread lock for this turn."
|
||||
)
|
||||
return None
|
||||
|
||||
lock = manager.lock_for(thread_id)
|
||||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
await lock.acquire()
|
||||
epoch = manager.bump_turn_epoch(thread_id)
|
||||
self._held_locks[thread_id] = (lock, epoch)
|
||||
# Reset the cancel event so this turn starts fresh
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
||||
async def aafter_agent( # type: ignore[override]
|
||||
self,
|
||||
state: AgentState[Any],
|
||||
runtime: Runtime[ContextT],
|
||||
) -> dict[str, Any] | None:
|
||||
del state
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
return None
|
||||
held = self._held_locks.pop(thread_id, None)
|
||||
if held is None:
|
||||
return None
|
||||
lock, held_epoch = held
|
||||
if held_epoch != manager.current_turn_epoch(thread_id):
|
||||
# Stale teardown from an older attempt (e.g. runtime-recovery path
|
||||
# already advanced epoch). Do not touch current lock/cancel state.
|
||||
return None
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
# Always clear cancel event between turns so a stale signal
|
||||
# doesn't leak into the next request.
|
||||
reset_cancel(thread_id)
|
||||
return None
|
||||
|
||||
# Provide sync no-ops because the middleware base class allows them
|
||||
def before_agent( # type: ignore[override]
|
||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
# Sync path: no asyncio.Lock to acquire. Best we can do is reject
|
||||
# if anyone else is in flight.
|
||||
thread_id = self._thread_id(runtime)
|
||||
if thread_id is None:
|
||||
if self._require_thread_id:
|
||||
raise BusyError("no thread_id configured")
|
||||
return None
|
||||
lock = manager.lock_for(thread_id)
|
||||
if lock.locked():
|
||||
raise BusyError(request_id=thread_id)
|
||||
return None
|
||||
|
||||
def after_agent( # type: ignore[override]
|
||||
self, state: AgentState[Any], runtime: Runtime[ContextT]
|
||||
) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BusyMutexMiddleware",
|
||||
"end_turn",
|
||||
"get_cancel_event",
|
||||
"get_cancel_state",
|
||||
"is_cancel_requested",
|
||||
"manager",
|
||||
"request_cancel",
|
||||
"reset_cancel",
|
||||
]
|
||||
|
|
@ -45,7 +45,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.document_xml import (
|
||||
build_document_xml,
|
||||
)
|
||||
from app.agents.shared.path_resolver import (
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import (
|
||||
DOCUMENTS_ROOT,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
|
|
|
|||
|
|
@ -9,13 +9,16 @@ from deepagents.backends.protocol import BackendProtocol
|
|||
from deepagents.backends.state import StateBackend
|
||||
from langgraph.prebuilt.tool_node import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import (
|
||||
FilesystemMode,
|
||||
FilesystemSelection,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
|
||||
KBPostgresBackend,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.multi_root_local_folder import (
|
||||
MultiRootLocalFolderBackend,
|
||||
)
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode, FilesystemSelection
|
||||
|
||||
|
||||
@lru_cache(maxsize=64)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from .middleware import SurfSenseFilesystemMiddleware
|
||||
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ from typing import Any
|
|||
from deepagents import FilesystemMiddleware
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.sandbox import is_sandbox_enabled
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.sandbox import is_sandbox_enabled
|
||||
|
||||
from ..system_prompt import build_system_prompt
|
||||
from ..tools import (
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
|
||||
def is_cloud(mode: FilesystemMode) -> bool:
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import DOCUMENTS_ROOT
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
from ..shared.paths import TEMP_PREFIX, basename
|
||||
from .mode import is_cloud
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.multi_root_local_folder import (
|
||||
MultiRootLocalFolderBackend,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from ..shared.paths import (
|
||||
extract_mount_from_path,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from .cloud import BODY as CLOUD_BODY
|
||||
from .common import HEADER, SANDBOX_ADDENDUM
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_DESCRIPTION = """Changes the current working directory (cwd).
|
||||
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ from langchain_core.messages import ToolMessage
|
|||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import DOCUMENTS_ROOT
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
from ...middleware.async_dispatch import run_async_blocking
|
||||
from ...middleware.path_resolution import resolve_relative
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Performs exact string replacements in files.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_DESCRIPTION = """Executes Python code in an isolated sandbox environment.
|
||||
|
||||
|
|
|
|||
|
|
@ -14,14 +14,14 @@ from typing import TYPE_CHECKING
|
|||
from daytona.common.errors import DaytonaError
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.sandbox import (
|
||||
from app.agents.multi_agent_chat.shared.sandbox import (
|
||||
_evict_sandbox_cache,
|
||||
delete_sandbox,
|
||||
get_or_create_sandbox,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_DESCRIPTION = """Find files matching a glob pattern.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Search for a literal text pattern across files.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Lists files/folders recursively in a single bounded call.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Lists files and directories at the given path.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Creates a directory under `/documents/`.
|
||||
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ from langchain_core.messages import ToolMessage
|
|||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
from langgraph.types import Command
|
||||
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import DOCUMENTS_ROOT
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
from ...middleware.async_dispatch import run_async_blocking
|
||||
from ...middleware.mode import is_cloud
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Moves or renames a file or folder.
|
||||
|
||||
|
|
|
|||
|
|
@ -11,11 +11,11 @@ from langgraph.types import Command
|
|||
from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
|
||||
KBPostgresBackend,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import DOCUMENTS_ROOT
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.reducers import _CLEAR
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_DESCRIPTION = """Prints the current working directory."""
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_DESCRIPTION = """Reads a file from the filesystem.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Deletes a single file under `/documents/`.
|
||||
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ from langgraph.types import Command
|
|||
from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
|
||||
KBPostgresBackend,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import DOCUMENTS_ROOT
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.reducers import _CLEAR
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...middleware import SurfSenseFilesystemMiddleware
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Deletes an empty directory under `/documents/`.
|
||||
|
||||
|
|
|
|||
|
|
@ -16,11 +16,11 @@ from langgraph.types import Command
|
|||
from app.agents.multi_agent_chat.shared.middleware.filesystem.backends.kb_postgres import (
|
||||
KBPostgresBackend,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import DOCUMENTS_ROOT
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.reducers import _CLEAR
|
||||
from app.agents.shared.path_resolver import DOCUMENTS_ROOT
|
||||
|
||||
from ...middleware.path_resolution import current_cwd
|
||||
from ...shared.paths import is_ancestor_of
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
_CLOUD_DESCRIPTION = """Writes a new text file to the workspace.
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
|
||||
def enabled(flags: AgentFeatureFlags, attr: str) -> bool:
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -45,16 +45,16 @@ from app.agents.multi_agent_chat.shared.date_filters import (
|
|||
parse_date_or_datetime,
|
||||
resolve_date_range,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.agents.shared.feature_flags import get_flags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.shared.path_resolver import (
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import get_flags
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.path_resolver import (
|
||||
PathIndex,
|
||||
build_path_index,
|
||||
doc_to_virtual_path,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.state.filesystem_state import (
|
||||
SurfSenseFilesystemState,
|
||||
)
|
||||
from app.db import (
|
||||
NATIVE_TO_LEGACY_DOCTYPE,
|
||||
Chunk,
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ from typing import Any
|
|||
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from app.agents.multi_agent_chat.shared.errors import StreamingError
|
||||
from app.agents.multi_agent_chat.shared.permissions import Rule
|
||||
from app.agents.shared.errors import StreamingError
|
||||
|
||||
|
||||
def build_deny_message(tool_call: dict[str, Any], rule: Rule) -> ToolMessage:
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ from langchain_core.messages import AIMessage, ToolMessage
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agents.multi_agent_chat.shared.errors import CorrectedError, RejectedError
|
||||
from app.agents.multi_agent_chat.shared.permissions import Ruleset
|
||||
from app.agents.shared.errors import CorrectedError, RejectedError
|
||||
from app.services.user_tool_allowlist import TrustedToolSaver
|
||||
|
||||
from ..ask.edit import merge_edited_args
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ from collections.abc import Sequence
|
|||
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.permissions import Rule, Ruleset
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.services.user_tool_allowlist import TrustedToolSaver
|
||||
|
||||
from .core import PermissionMiddleware
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from langchain.agents.middleware import (
|
|||
ToolCallLimitMiddleware,
|
||||
)
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import RetryAfterMiddleware
|
||||
|
||||
from .fallback import build_fallback_mw
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
from .scoped_model_fallback import (
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from langchain.agents.middleware import ModelCallLimitMiddleware
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.middleware import RetryAfterMiddleware
|
||||
|
||||
from ..flags import enabled
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
from langchain.agents.middleware import ToolCallLimitMiddleware
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from ..flags import enabled
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,351 @@
|
|||
"""Canonical virtual-path resolver for SurfSense knowledge-base documents.
|
||||
|
||||
This module is the single source of truth for mapping ``Document`` rows to
|
||||
virtual paths under ``/documents/`` and back. It is used by:
|
||||
|
||||
* :class:`KnowledgeTreeMiddleware` (rendering the workspace tree)
|
||||
* :class:`KnowledgePriorityMiddleware` (computing priority paths)
|
||||
* :class:`KBPostgresBackend` (``als_info`` / ``aread`` / move operations)
|
||||
* :class:`KnowledgeBasePersistenceMiddleware` (resolving moves and creates)
|
||||
|
||||
Centralising the logic ensures that title-collision suffixes, folder paths,
|
||||
and ``unique_identifier_hash`` lookups never drift between renders and
|
||||
commits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType, Folder
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
DOCUMENTS_ROOT = "/documents"
|
||||
"""Root virtual folder for all KB documents."""
|
||||
|
||||
_INVALID_FILENAME_CHARS = re.compile(r"[\\/:*?\"<>|]+")
|
||||
_WHITESPACE_RUN = re.compile(r"\s+")
|
||||
|
||||
|
||||
def safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
||||
"""Convert arbitrary text into a filesystem-safe ``.xml`` filename."""
|
||||
name = _INVALID_FILENAME_CHARS.sub("_", value).strip()
|
||||
name = _WHITESPACE_RUN.sub(" ", name)
|
||||
if not name:
|
||||
name = fallback
|
||||
if len(name) > 180:
|
||||
name = name[:180].rstrip()
|
||||
if not name.lower().endswith(".xml"):
|
||||
name = f"{name}.xml"
|
||||
return name
|
||||
|
||||
|
||||
def safe_folder_segment(value: str, *, fallback: str = "folder") -> str:
|
||||
"""Sanitize a single folder name into a path-safe segment."""
|
||||
name = _INVALID_FILENAME_CHARS.sub("_", value).strip()
|
||||
name = _WHITESPACE_RUN.sub(" ", name)
|
||||
if not name:
|
||||
return fallback
|
||||
if len(name) > 180:
|
||||
name = name[:180].rstrip()
|
||||
return name
|
||||
|
||||
|
||||
def _suffix_with_doc_id(filename: str, doc_id: int | None) -> str:
|
||||
if doc_id is None:
|
||||
return filename
|
||||
if not filename.lower().endswith(".xml"):
|
||||
return f"{filename} ({doc_id}).xml"
|
||||
stem = filename[:-4]
|
||||
return f"{stem} ({doc_id}).xml"
|
||||
|
||||
|
||||
_SUFFIX_PATTERN = re.compile(r"\s\((\d+)\)\.xml$", re.IGNORECASE)
|
||||
|
||||
|
||||
def parse_doc_id_suffix(filename: str) -> tuple[str, int | None]:
|
||||
"""Strip a trailing ``" (<doc_id>).xml"`` suffix; return ``(stem, doc_id)``.
|
||||
|
||||
If no suffix is present, returns ``(stem_without_xml_extension, None)``.
|
||||
"""
|
||||
match = _SUFFIX_PATTERN.search(filename)
|
||||
if match:
|
||||
doc_id = int(match.group(1))
|
||||
stem = filename[: match.start()]
|
||||
return stem, doc_id
|
||||
if filename.lower().endswith(".xml"):
|
||||
return filename[:-4], None
|
||||
return filename, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathIndex:
|
||||
"""In-memory occupancy snapshot used by :func:`doc_to_virtual_path`.
|
||||
|
||||
Built once per call site so collision handling is deterministic and so
|
||||
we don't perform N folder lookups per render.
|
||||
"""
|
||||
|
||||
folder_paths: dict[int, str] = field(default_factory=dict)
|
||||
"""``Folder.id`` -> absolute virtual folder path under ``/documents``."""
|
||||
|
||||
occupants: dict[str, int] = field(default_factory=dict)
|
||||
"""virtual path -> ``Document.id`` already occupying that path (this render)."""
|
||||
|
||||
|
||||
async def _build_folder_paths(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
) -> dict[int, str]:
|
||||
"""Compute ``Folder.id`` -> absolute virtual path under ``/documents``."""
|
||||
result = await session.execute(
|
||||
select(Folder.id, Folder.name, Folder.parent_id).where(
|
||||
Folder.search_space_id == search_space_id
|
||||
)
|
||||
)
|
||||
rows = result.all()
|
||||
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
|
||||
cache: dict[int, str] = {}
|
||||
|
||||
def resolve(folder_id: int) -> str:
|
||||
if folder_id in cache:
|
||||
return cache[folder_id]
|
||||
parts: list[str] = []
|
||||
cursor: int | None = folder_id
|
||||
visited: set[int] = set()
|
||||
while cursor is not None and cursor in by_id and cursor not in visited:
|
||||
visited.add(cursor)
|
||||
entry = by_id[cursor]
|
||||
parts.append(safe_folder_segment(str(entry["name"])))
|
||||
cursor = entry["parent_id"]
|
||||
parts.reverse()
|
||||
path = f"{DOCUMENTS_ROOT}/" + "/".join(parts) if parts else DOCUMENTS_ROOT
|
||||
cache[folder_id] = path
|
||||
return path
|
||||
|
||||
for folder_id in by_id:
|
||||
resolve(folder_id)
|
||||
return cache
|
||||
|
||||
|
||||
async def build_path_index(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
*,
|
||||
populate_occupants: bool = True,
|
||||
) -> PathIndex:
|
||||
"""Build a :class:`PathIndex` for a search space.
|
||||
|
||||
``populate_occupants`` controls whether the occupancy map is pre-seeded
|
||||
from existing ``Document`` rows. Most callers want this so that
|
||||
:func:`doc_to_virtual_path` can detect collisions across the whole space;
|
||||
the persistence middleware sets this to ``False`` when it is iterating to
|
||||
decide where to place fresh documents.
|
||||
"""
|
||||
folder_paths = await _build_folder_paths(session, search_space_id)
|
||||
occupants: dict[str, int] = {}
|
||||
if populate_occupants:
|
||||
rows = await session.execute(
|
||||
select(Document.id, Document.title, Document.folder_id).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
for row in rows.all():
|
||||
base = folder_paths.get(row.folder_id, DOCUMENTS_ROOT)
|
||||
filename = safe_filename(str(row.title or "untitled"))
|
||||
path = f"{base}/{filename}"
|
||||
if path in occupants and occupants[path] != row.id:
|
||||
path = f"{base}/{_suffix_with_doc_id(filename, row.id)}"
|
||||
occupants[path] = row.id
|
||||
return PathIndex(folder_paths=folder_paths, occupants=occupants)
|
||||
|
||||
|
||||
def doc_to_virtual_path(
|
||||
*,
|
||||
doc_id: int | None,
|
||||
title: str,
|
||||
folder_id: int | None,
|
||||
index: PathIndex,
|
||||
) -> str:
|
||||
"""Return the canonical virtual path for a document.
|
||||
|
||||
Mutates ``index.occupants`` so subsequent calls see this assignment and
|
||||
deterministically pick a different suffix for the next colliding doc.
|
||||
"""
|
||||
base = index.folder_paths.get(folder_id, DOCUMENTS_ROOT)
|
||||
filename = safe_filename(str(title or "untitled"))
|
||||
path = f"{base}/{filename}"
|
||||
occupant = index.occupants.get(path)
|
||||
if occupant is not None and occupant != doc_id:
|
||||
path = f"{base}/{_suffix_with_doc_id(filename, doc_id)}"
|
||||
if doc_id is not None:
|
||||
index.occupants[path] = doc_id
|
||||
return path
|
||||
|
||||
|
||||
async def virtual_path_to_doc(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
virtual_path: str,
|
||||
) -> Document | None:
|
||||
"""Resolve a virtual path back to a ``Document`` row.
|
||||
|
||||
Resolution order:
|
||||
1. ``Document.unique_identifier_hash`` lookup (fast path for paths created
|
||||
by SurfSense itself — every NOTE write goes through this hash).
|
||||
2. If the basename carries a ``" (<doc_id>).xml"`` disambiguation suffix,
|
||||
try a direct id lookup constrained to the search space.
|
||||
3. Title-from-basename + folder-resolution lookup as a last resort.
|
||||
"""
|
||||
if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
|
||||
return None
|
||||
|
||||
unique_hash = generate_unique_identifier_hash(
|
||||
DocumentType.NOTE,
|
||||
virtual_path,
|
||||
search_space_id,
|
||||
)
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.unique_identifier_hash == unique_hash,
|
||||
)
|
||||
)
|
||||
document = result.scalar_one_or_none()
|
||||
if document is not None:
|
||||
return document
|
||||
|
||||
rel = virtual_path[len(DOCUMENTS_ROOT) :].lstrip("/")
|
||||
if not rel:
|
||||
return None
|
||||
parts = [p for p in rel.split("/") if p]
|
||||
if not parts:
|
||||
return None
|
||||
basename = parts[-1]
|
||||
folder_parts = parts[:-1]
|
||||
|
||||
stem, suffix_doc_id = parse_doc_id_suffix(basename)
|
||||
if suffix_doc_id is not None:
|
||||
result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.id == suffix_doc_id,
|
||||
)
|
||||
)
|
||||
document = result.scalar_one_or_none()
|
||||
if document is not None:
|
||||
return document
|
||||
|
||||
folder_id = await _resolve_folder_id(
|
||||
session, search_space_id=search_space_id, folder_parts=folder_parts
|
||||
)
|
||||
title_candidates: list[str] = []
|
||||
raw_title = stem
|
||||
title_candidates.append(raw_title)
|
||||
if raw_title.endswith(".xml"):
|
||||
title_candidates.append(raw_title[:-4])
|
||||
|
||||
for candidate in dict.fromkeys(title_candidates):
|
||||
if not candidate:
|
||||
continue
|
||||
query = select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
Document.title == candidate,
|
||||
)
|
||||
if folder_id is None:
|
||||
query = query.where(Document.folder_id.is_(None))
|
||||
else:
|
||||
query = query.where(Document.folder_id == folder_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalars().first()
|
||||
if document is not None:
|
||||
return document
|
||||
|
||||
# Fallback: title-as-string lookup misses when the real DB title contains
|
||||
# characters that ``safe_filename`` lossily replaces (``:``, ``/``, ``*``,
|
||||
# etc.) — common for connector-imported docs (Google Calendar/Drive etc.).
|
||||
# The workspace tree shows the lossy filename, so the agent passes that
|
||||
# filename back here. Scan all documents in the resolved folder and match
|
||||
# by ``safe_filename(title)`` to recover the original document.
|
||||
folder_scan = select(Document).where(
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
if folder_id is None:
|
||||
folder_scan = folder_scan.where(Document.folder_id.is_(None))
|
||||
else:
|
||||
folder_scan = folder_scan.where(Document.folder_id == folder_id)
|
||||
result = await session.execute(folder_scan)
|
||||
for candidate_doc in result.scalars().all():
|
||||
encoded = safe_filename(str(candidate_doc.title or "untitled"))
|
||||
if encoded == basename:
|
||||
return candidate_doc
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_folder_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
folder_parts: list[str],
|
||||
) -> int | None:
|
||||
"""Look up the leaf folder id for a chain of folder names; return ``None`` if missing."""
|
||||
if not folder_parts:
|
||||
return None
|
||||
parent_id: int | None = None
|
||||
for raw in folder_parts:
|
||||
name = safe_folder_segment(raw)
|
||||
query = select(Folder.id).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.name == name,
|
||||
)
|
||||
if parent_id is None:
|
||||
query = query.where(Folder.parent_id.is_(None))
|
||||
else:
|
||||
query = query.where(Folder.parent_id == parent_id)
|
||||
result = await session.execute(query)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
return None
|
||||
parent_id = row[0]
|
||||
return parent_id
|
||||
|
||||
|
||||
def parse_documents_path(virtual_path: str) -> tuple[list[str], str]:
|
||||
"""Parse a ``/documents/...`` path into ``(folder_parts, document_title)``.
|
||||
|
||||
The title has any ``.xml`` extension and trailing ``" (<doc_id>)"``
|
||||
disambiguation suffix stripped.
|
||||
"""
|
||||
if not virtual_path or not virtual_path.startswith(DOCUMENTS_ROOT):
|
||||
return [], ""
|
||||
rel = virtual_path[len(DOCUMENTS_ROOT) :].strip("/")
|
||||
if not rel:
|
||||
return [], ""
|
||||
parts = [p for p in rel.split("/") if p]
|
||||
if not parts:
|
||||
return [], ""
|
||||
folder_parts = parts[:-1]
|
||||
basename = parts[-1]
|
||||
stem, _ = parse_doc_id_suffix(basename)
|
||||
title = stem
|
||||
if title.endswith(".xml"):
|
||||
title = title[:-4]
|
||||
return folder_parts, title
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DOCUMENTS_ROOT",
|
||||
"PathIndex",
|
||||
"build_path_index",
|
||||
"doc_to_virtual_path",
|
||||
"parse_doc_id_suffix",
|
||||
"parse_documents_path",
|
||||
"safe_filename",
|
||||
"safe_folder_segment",
|
||||
"virtual_path_to_doc",
|
||||
]
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
r"""LiteLLM-native prompt caching configuration for SurfSense agents.
|
||||
|
||||
Replaces the legacy ``AnthropicPromptCachingMiddleware`` (which never
|
||||
activated for our LiteLLM-based stack — its ``isinstance(model, ChatAnthropic)``
|
||||
gate always failed) with LiteLLM's universal caching mechanism.
|
||||
|
||||
Coverage:
|
||||
|
||||
- Marker-based providers (need ``cache_control`` injection, which LiteLLM
|
||||
performs automatically when ``cache_control_injection_points`` is set):
|
||||
``anthropic/``, ``bedrock/``, ``vertex_ai/``, ``gemini/``, ``azure_ai/``,
|
||||
``openrouter/`` (Claude/Gemini/MiniMax/GLM/z-ai routes), ``databricks/``
|
||||
(Claude), ``dashscope/`` (Qwen), ``minimax/``, ``zai/`` (GLM).
|
||||
- Auto-cached (LiteLLM strips the marker silently): ``openai/``,
|
||||
``deepseek/``, ``xai/`` — these caches automatically for prompts ≥1024
|
||||
tokens and surface ``prompt_cache_key`` / ``prompt_cache_retention``.
|
||||
|
||||
We inject **two** breakpoints per request:
|
||||
|
||||
- ``index: 0`` — pins the SurfSense system prompt at the head of the
|
||||
request (provider variant, citation rules, tool catalog, KB tree,
|
||||
skills metadata). The langchain agent factory always prepends
|
||||
``request.system_message`` at index 0 (see ``factory.py``
|
||||
``_execute_model_async``), so this targets exactly the main system
|
||||
prompt regardless of how many other ``SystemMessage``\ s the
|
||||
``before_agent`` injectors (priority, tree, memory, file-intent,
|
||||
anonymous-doc) have inserted into ``state["messages"]``. Using
|
||||
``role: system`` here would apply ``cache_control`` to **every**
|
||||
system-role message and trip Anthropic's hard cap of 4 cache
|
||||
breakpoints per request once the conversation accumulates enough
|
||||
injected system messages — which surfaces as the upstream 400
|
||||
``A maximum of 4 blocks with cache_control may be provided. Found N``
|
||||
via OpenRouter→Anthropic.
|
||||
- ``index: -1`` — pins the latest message so multi-turn savings compound:
|
||||
Anthropic-family providers use longest-matching-prefix lookup, so turn
|
||||
N+1 still reads turn N's cache up to the shared prefix.
|
||||
|
||||
For OpenAI-family configs we additionally pass:
|
||||
|
||||
- ``prompt_cache_key=f"surfsense-thread-{thread_id}"`` — routing hint that
|
||||
raises hit rate by sending requests with a shared prefix to the same
|
||||
backend. Supported by ``openai/``, ``deepseek/``, ``xai/``, and
|
||||
``azure/`` (added to LiteLLM's Azure transformer in
|
||||
https://github.com/BerriAI/litellm/pull/20989, Feb 2026; verified
|
||||
against ``AzureOpenAIConfig.get_supported_openai_params`` in our
|
||||
installed litellm 1.83.14 for ``azure/gpt-4o``, ``azure/gpt-4o-mini``,
|
||||
``azure/gpt-5.4``, ``azure/gpt-5.4-mini``).
|
||||
- ``prompt_cache_retention="24h"`` — extends cache TTL beyond the default
|
||||
5-10 min in-memory cache. Set ONLY for OpenAI/DeepSeek/xAI: Azure's
|
||||
server-side support landed in Microsoft's docs on 2026-05-13 but
|
||||
LiteLLM 1.83.14's Azure transformer still omits it from its supported
|
||||
params list, so it gets silently dropped by ``litellm.drop_params``.
|
||||
Azure's default in-memory retention (5-10 min, max 1 h) already
|
||||
bridges intra-conversation turns; revisit when LiteLLM bumps Azure.
|
||||
|
||||
Safety net: ``litellm.drop_params=True`` is set globally in
|
||||
``app.services.llm_service`` at module-load time. Any kwarg the destination
|
||||
provider doesn't recognise is auto-stripped at the provider transformer
|
||||
layer, so an OpenAI→Bedrock auto-mode fallback can't 400 on
|
||||
``prompt_cache_key`` etc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.agents.multi_agent_chat.shared.llm_config import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Two-breakpoint policy: head-of-request + latest message. See module
|
||||
# docstring for rationale. Anthropic caps requests at 4 ``cache_control``
|
||||
# blocks; we use 2 here, leaving headroom for Phase-2 tool caching.
|
||||
#
|
||||
# IMPORTANT: ``index: 0`` (not ``role: system``). The deepagent stack's
|
||||
# ``before_agent`` middlewares (priority, tree, memory, anonymous-doc)
|
||||
# insert ``SystemMessage`` instances into ``state["messages"]`` that
|
||||
# accumulate across turns. With ``role: system`` the LiteLLM hook would
|
||||
# tag *every* one of them with ``cache_control`` and overflow Anthropic's
|
||||
# 4-block limit. ``index: 0`` always targets the langchain-prepended
|
||||
# ``request.system_message``, giving us exactly one stable cache breakpoint.
|
||||
_DEFAULT_INJECTION_POINTS: tuple[dict[str, Any], ...] = (
|
||||
{"location": "message", "index": 0},
|
||||
{"location": "message", "index": -1},
|
||||
)
|
||||
|
||||
# Providers (uppercase ``AgentConfig.provider`` values) that accept the
|
||||
# OpenAI ``prompt_cache_key`` routing hint. Microsoft's Azure OpenAI docs
|
||||
# (2026-05-13) confirm automatic prompt caching applies to every GPT-4o
|
||||
# or newer Azure deployment at ≥1024 tokens with no configuration needed,
|
||||
# and that ``prompt_cache_key`` is combined with the prefix hash to
|
||||
# improve routing affinity and therefore cache hit rate. LiteLLM's Azure
|
||||
# transformer ships ``prompt_cache_key`` in its supported params as of
|
||||
# https://github.com/BerriAI/litellm/pull/20989.
|
||||
#
|
||||
# Strict whitelist — many other providers in ``PROVIDER_MAP`` route
|
||||
# through litellm's ``openai`` prefix without implementing the OpenAI
|
||||
# prompt-cache surface (e.g. MOONSHOT, ZHIPU, MINIMAX), so we can't infer
|
||||
# family from the litellm prefix alone.
|
||||
_PROMPT_CACHE_KEY_PROVIDERS: frozenset[str] = frozenset(
|
||||
{"OPENAI", "DEEPSEEK", "XAI", "AZURE", "AZURE_OPENAI"}
|
||||
)
|
||||
|
||||
# Subset of ``_PROMPT_CACHE_KEY_PROVIDERS`` that also accept
|
||||
# ``prompt_cache_retention="24h"``. Azure is excluded: see module
|
||||
# docstring — LiteLLM 1.83.14's Azure transformer omits the param so
|
||||
# ``drop_params`` silently strips it. Re-add Azure once a future LiteLLM
|
||||
# release wires it into ``AzureOpenAIConfig.get_supported_openai_params``.
|
||||
_PROMPT_CACHE_RETENTION_PROVIDERS: frozenset[str] = frozenset(
|
||||
{"OPENAI", "DEEPSEEK", "XAI"}
|
||||
)
|
||||
|
||||
|
||||
def _is_router_llm(llm: BaseChatModel) -> bool:
|
||||
"""Detect ``ChatLiteLLMRouter`` (auto-mode) without an eager import.
|
||||
|
||||
Importing ``app.services.llm_router_service`` at module-load time would
|
||||
create a cycle via ``llm_config -> prompt_caching -> llm_router_service``.
|
||||
Class-name comparison is sufficient since the class is defined in a
|
||||
single place.
|
||||
"""
|
||||
return type(llm).__name__ == "ChatLiteLLMRouter"
|
||||
|
||||
|
||||
def _provider_supports_prompt_cache_key(agent_config: AgentConfig | None) -> bool:
|
||||
"""Whether the config targets a provider that accepts ``prompt_cache_key``.
|
||||
|
||||
Strict — only returns True for explicitly chosen OPENAI, DEEPSEEK,
|
||||
XAI, AZURE, or AZURE_OPENAI providers. Auto-mode and custom
|
||||
providers return False because we can't statically know the
|
||||
destination and the router fans out across mixed providers.
|
||||
"""
|
||||
if agent_config is None or not agent_config.provider:
|
||||
return False
|
||||
if agent_config.is_auto_mode:
|
||||
return False
|
||||
if agent_config.custom_provider:
|
||||
return False
|
||||
return agent_config.provider.upper() in _PROMPT_CACHE_KEY_PROVIDERS
|
||||
|
||||
|
||||
def _provider_supports_prompt_cache_retention(
|
||||
agent_config: AgentConfig | None,
|
||||
) -> bool:
|
||||
"""Whether the config targets a provider that accepts ``prompt_cache_retention``.
|
||||
|
||||
Tighter than :func:`_provider_supports_prompt_cache_key` — Azure
|
||||
deployments are excluded until LiteLLM ships the param in its Azure
|
||||
transformer (see module docstring).
|
||||
"""
|
||||
if agent_config is None or not agent_config.provider:
|
||||
return False
|
||||
if agent_config.is_auto_mode:
|
||||
return False
|
||||
if agent_config.custom_provider:
|
||||
return False
|
||||
return agent_config.provider.upper() in _PROMPT_CACHE_RETENTION_PROVIDERS
|
||||
|
||||
|
||||
def _get_or_init_model_kwargs(llm: BaseChatModel) -> dict[str, Any] | None:
|
||||
"""Return ``llm.model_kwargs`` as a writable dict, or ``None`` to bail.
|
||||
|
||||
Initialises the field to ``{}`` when present-but-None on a Pydantic v2
|
||||
model. Returns ``None`` if the LLM type doesn't expose a writable
|
||||
``model_kwargs`` attribute (caller should treat as no-op).
|
||||
"""
|
||||
model_kwargs = getattr(llm, "model_kwargs", None)
|
||||
if isinstance(model_kwargs, dict):
|
||||
return model_kwargs
|
||||
try:
|
||||
llm.model_kwargs = {} # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
return None
|
||||
refreshed = getattr(llm, "model_kwargs", None)
|
||||
return refreshed if isinstance(refreshed, dict) else None
|
||||
|
||||
|
||||
def apply_litellm_prompt_caching(
|
||||
llm: BaseChatModel,
|
||||
*,
|
||||
agent_config: AgentConfig | None = None,
|
||||
thread_id: int | None = None,
|
||||
) -> None:
|
||||
"""Configure LiteLLM prompt caching on a ChatLiteLLM/ChatLiteLLMRouter.
|
||||
|
||||
Idempotent — values already present in ``llm.model_kwargs`` (e.g. from
|
||||
``agent_config.litellm_params`` overrides) are preserved. Mutates
|
||||
``llm.model_kwargs`` in place; the kwargs flow to ``litellm.completion``
|
||||
via ``ChatLiteLLM._default_params`` and via ``self.model_kwargs`` merge
|
||||
in our custom ``ChatLiteLLMRouter``.
|
||||
|
||||
Args:
|
||||
llm: ChatLiteLLM, SanitizedChatLiteLLM, or ChatLiteLLMRouter instance.
|
||||
agent_config: Optional ``AgentConfig`` driving provider-specific
|
||||
behaviour. When omitted (or auto-mode), only the universal
|
||||
``cache_control_injection_points`` are set.
|
||||
thread_id: Optional thread id used to construct a per-thread
|
||||
``prompt_cache_key`` for OpenAI-family providers. Caching still
|
||||
works without it (server-side automatic), but the key improves
|
||||
backend routing affinity and therefore hit rate.
|
||||
"""
|
||||
model_kwargs = _get_or_init_model_kwargs(llm)
|
||||
if model_kwargs is None:
|
||||
logger.debug(
|
||||
"apply_litellm_prompt_caching: %s exposes no writable model_kwargs; skipping",
|
||||
type(llm).__name__,
|
||||
)
|
||||
return
|
||||
|
||||
if "cache_control_injection_points" not in model_kwargs:
|
||||
model_kwargs["cache_control_injection_points"] = [
|
||||
dict(point) for point in _DEFAULT_INJECTION_POINTS
|
||||
]
|
||||
|
||||
# OpenAI-style extras only when we statically know the destination
|
||||
# accepts them. Auto-mode router fans out across mixed providers so
|
||||
# we can't safely set destination-specific kwargs there (drop_params
|
||||
# would strip them but it's wasteful to set them in the first
|
||||
# place).
|
||||
if _is_router_llm(llm):
|
||||
return
|
||||
|
||||
if (
|
||||
thread_id is not None
|
||||
and "prompt_cache_key" not in model_kwargs
|
||||
and _provider_supports_prompt_cache_key(agent_config)
|
||||
):
|
||||
model_kwargs["prompt_cache_key"] = f"surfsense-thread-{thread_id}"
|
||||
|
||||
if (
|
||||
"prompt_cache_retention" not in model_kwargs
|
||||
and _provider_supports_prompt_cache_retention(agent_config)
|
||||
):
|
||||
model_kwargs["prompt_cache_retention"] = "24h"
|
||||
|
|
@ -23,7 +23,7 @@ the receipt into the parent's ``receipts`` state via the append reducer.
|
|||
|
||||
The KB write path is the one exception: file-tool calls cannot emit a
|
||||
durable receipt because the actual DB writes happen end-of-turn inside
|
||||
:class:`app.agents.shared.middleware.kb_persistence.KnowledgeBasePersistenceMiddleware`.
|
||||
:class:`app.agents.multi_agent_chat.shared.middleware.kb_persistence.KnowledgeBasePersistenceMiddleware`.
|
||||
KB tools therefore emit a *provisional* receipt with ``status="pending"``;
|
||||
the persistence middleware flips it to ``"success"`` or ``"failed"``
|
||||
before returning control to the parent.
|
||||
|
|
|
|||
401
surfsense_backend/app/agents/multi_agent_chat/shared/sandbox.py
Normal file
401
surfsense_backend/app/agents/multi_agent_chat/shared/sandbox.py
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
"""
|
||||
Daytona sandbox provider for SurfSense deep agent.
|
||||
|
||||
Manages the lifecycle of sandboxed code execution environments.
|
||||
Each conversation thread gets its own isolated sandbox instance
|
||||
via the Daytona cloud API, identified by labels.
|
||||
|
||||
Files created during a session are persisted to local storage before
|
||||
the sandbox is deleted so they remain downloadable after cleanup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from daytona import (
|
||||
CreateSandboxFromSnapshotParams,
|
||||
Daytona,
|
||||
DaytonaConfig,
|
||||
SandboxState,
|
||||
)
|
||||
from daytona.common.errors import DaytonaError
|
||||
from deepagents.backends.protocol import ExecuteResponse
|
||||
from langchain_daytona import DaytonaSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _TimeoutAwareSandbox(DaytonaSandbox):
|
||||
"""DaytonaSandbox subclass that accepts the per-command *timeout*
|
||||
kwarg required by the deepagents middleware.
|
||||
|
||||
The upstream ``langchain-daytona`` ``execute()`` ignores timeout,
|
||||
so deepagents raises *"This sandbox backend does not support
|
||||
per-command timeout overrides"* on every first call. This thin
|
||||
wrapper forwards the parameter to the Daytona SDK.
|
||||
"""
|
||||
|
||||
def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
|
||||
t = timeout if timeout is not None else self._default_timeout
|
||||
result = self._sandbox.process.exec(command, timeout=t)
|
||||
return ExecuteResponse(
|
||||
output=result.result,
|
||||
exit_code=result.exit_code,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
async def aexecute(
|
||||
self, command: str, *, timeout: int | None = None
|
||||
) -> ExecuteResponse: # type: ignore[override]
|
||||
return await asyncio.to_thread(self.execute, command, timeout=timeout)
|
||||
|
||||
def download_file(self, path: str) -> bytes:
|
||||
"""Download a file from the sandbox filesystem."""
|
||||
return self._sandbox.fs.download_file(path)
|
||||
|
||||
|
||||
_daytona_client: Daytona | None = None
|
||||
_client_lock = threading.Lock()
|
||||
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
|
||||
_sandbox_locks: dict[str, asyncio.Lock] = {}
|
||||
_sandbox_locks_mu = asyncio.Lock()
|
||||
_seeded_files: dict[str, dict[str, str]] = {}
|
||||
_SANDBOX_CACHE_MAX_SIZE = 20
|
||||
THREAD_LABEL_KEY = "surfsense_thread"
|
||||
SANDBOX_DOCUMENTS_ROOT = "/home/daytona/documents"
|
||||
|
||||
|
||||
def is_sandbox_enabled() -> bool:
|
||||
return os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE").upper() == "TRUE"
|
||||
|
||||
|
||||
def _get_client() -> Daytona:
|
||||
global _daytona_client
|
||||
with _client_lock:
|
||||
if _daytona_client is None:
|
||||
config = DaytonaConfig(
|
||||
api_key=os.environ.get("DAYTONA_API_KEY", ""),
|
||||
api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
|
||||
target=os.environ.get("DAYTONA_TARGET", "us"),
|
||||
)
|
||||
_daytona_client = Daytona(config)
|
||||
return _daytona_client
|
||||
|
||||
|
||||
def _sandbox_create_params(
|
||||
labels: dict[str, str],
|
||||
) -> CreateSandboxFromSnapshotParams:
|
||||
snapshot_id = os.environ.get("DAYTONA_SNAPSHOT_ID") or None
|
||||
return CreateSandboxFromSnapshotParams(
|
||||
language="python",
|
||||
labels=labels,
|
||||
snapshot=snapshot_id,
|
||||
network_block_all=True,
|
||||
auto_stop_interval=10,
|
||||
auto_delete_interval=60,
|
||||
)
|
||||
|
||||
|
||||
def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
|
||||
"""Find an existing sandbox for *thread_id*, or create a new one.
|
||||
|
||||
Returns a tuple of (sandbox, is_new) where *is_new* is True when a
|
||||
fresh sandbox was created (first time or replacement after failure).
|
||||
"""
|
||||
client = _get_client()
|
||||
labels = {THREAD_LABEL_KEY: thread_id}
|
||||
is_new = False
|
||||
|
||||
try:
|
||||
sandbox = client.find_one(labels=labels)
|
||||
logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)
|
||||
|
||||
if sandbox.state in (
|
||||
SandboxState.STOPPED,
|
||||
SandboxState.STOPPING,
|
||||
SandboxState.ARCHIVED,
|
||||
):
|
||||
logger.info("Starting stopped sandbox %s …", sandbox.id)
|
||||
sandbox.start(timeout=60)
|
||||
logger.info("Sandbox %s is now started", sandbox.id)
|
||||
elif sandbox.state in (
|
||||
SandboxState.ERROR,
|
||||
SandboxState.BUILD_FAILED,
|
||||
SandboxState.DESTROYED,
|
||||
):
|
||||
logger.warning(
|
||||
"Sandbox %s in unrecoverable state %s — creating a new one",
|
||||
sandbox.id,
|
||||
sandbox.state,
|
||||
)
|
||||
try:
|
||||
client.delete(sandbox)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not delete broken sandbox %s", sandbox.id, exc_info=True
|
||||
)
|
||||
sandbox = client.create(_sandbox_create_params(labels))
|
||||
is_new = True
|
||||
logger.info("Created replacement sandbox: %s", sandbox.id)
|
||||
elif sandbox.state != SandboxState.STARTED:
|
||||
sandbox.wait_for_sandbox_start(timeout=60)
|
||||
|
||||
except DaytonaError:
|
||||
logger.info("No existing sandbox for thread %s — creating one", thread_id)
|
||||
sandbox = client.create(_sandbox_create_params(labels))
|
||||
is_new = True
|
||||
logger.info("Created new sandbox: %s", sandbox.id)
|
||||
|
||||
return _TimeoutAwareSandbox(sandbox=sandbox), is_new
|
||||
|
||||
|
||||
async def _get_thread_lock(key: str) -> asyncio.Lock:
|
||||
"""Return a per-thread asyncio lock, creating one if needed."""
|
||||
async with _sandbox_locks_mu:
|
||||
lock = _sandbox_locks.get(key)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_sandbox_locks[key] = lock
|
||||
return lock
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
|
||||
thread_id: int | str,
|
||||
) -> tuple[_TimeoutAwareSandbox, bool]:
|
||||
"""Get or create a sandbox for a conversation thread.
|
||||
|
||||
Uses an in-process cache keyed by thread_id so subsequent messages
|
||||
in the same conversation reuse the sandbox object without an API call.
|
||||
A per-thread async lock prevents duplicate sandbox creation from
|
||||
concurrent requests.
|
||||
|
||||
Returns:
|
||||
Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox
|
||||
was created, signalling that file tracking should be reset.
|
||||
"""
|
||||
key = str(thread_id)
|
||||
lock = await _get_thread_lock(key)
|
||||
|
||||
async with lock:
|
||||
cached = _sandbox_cache.get(key)
|
||||
if cached is not None:
|
||||
logger.info("Reusing cached sandbox for thread %s", key)
|
||||
return cached, False
|
||||
sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
|
||||
_sandbox_cache[key] = sandbox
|
||||
|
||||
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
||||
oldest_key = next(iter(_sandbox_cache))
|
||||
if oldest_key != key:
|
||||
evicted = _sandbox_cache.pop(oldest_key, None)
|
||||
_seeded_files.pop(oldest_key, None)
|
||||
logger.debug("Evicted sandbox cache entry: %s", oldest_key)
|
||||
if evicted is not None:
|
||||
_schedule_sandbox_delete(evicted)
|
||||
|
||||
return sandbox, is_new
|
||||
|
||||
|
||||
def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None:
|
||||
"""Best-effort background deletion of an evicted sandbox."""
|
||||
|
||||
def _delete() -> None:
|
||||
try:
|
||||
client = _get_client()
|
||||
client.delete(sandbox._sandbox)
|
||||
logger.info("Deleted evicted sandbox: %s", sandbox._sandbox.id)
|
||||
except Exception:
|
||||
logger.debug("Could not delete evicted sandbox", exc_info=True)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.run_in_executor(None, _delete)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
async def sync_files_to_sandbox(
|
||||
thread_id: int | str,
|
||||
files: dict[str, dict],
|
||||
sandbox: _TimeoutAwareSandbox,
|
||||
is_new: bool,
|
||||
) -> None:
|
||||
"""Upload new or changed virtual-filesystem files to the sandbox.
|
||||
|
||||
Compares *files* (from ``state["files"]``) against the ``_seeded_files``
|
||||
tracking dict and uploads only what has changed. When *is_new* is True
|
||||
the tracking is reset so every file is re-uploaded.
|
||||
"""
|
||||
key = str(thread_id)
|
||||
if is_new:
|
||||
_seeded_files.pop(key, None)
|
||||
|
||||
tracked = _seeded_files.get(key, {})
|
||||
to_upload: list[tuple[str, bytes]] = []
|
||||
|
||||
for vpath, fdata in files.items():
|
||||
modified_at = fdata.get("modified_at", "")
|
||||
if tracked.get(vpath) == modified_at:
|
||||
continue
|
||||
content = "\n".join(fdata.get("content", []))
|
||||
sandbox_path = f"{SANDBOX_DOCUMENTS_ROOT}{vpath}"
|
||||
to_upload.append((sandbox_path, content.encode("utf-8")))
|
||||
|
||||
if not to_upload:
|
||||
return
|
||||
|
||||
def _upload() -> None:
|
||||
sandbox.upload_files(to_upload)
|
||||
|
||||
await asyncio.to_thread(_upload)
|
||||
|
||||
new_tracked = dict(tracked)
|
||||
for vpath, fdata in files.items():
|
||||
new_tracked[vpath] = fdata.get("modified_at", "")
|
||||
_seeded_files[key] = new_tracked
|
||||
logger.info("Synced %d file(s) to sandbox for thread %s", len(to_upload), key)
|
||||
|
||||
|
||||
def _evict_sandbox_cache(thread_id: int | str) -> None:
|
||||
key = str(thread_id)
|
||||
_sandbox_cache.pop(key, None)
|
||||
_seeded_files.pop(key, None)
|
||||
|
||||
|
||||
async def delete_sandbox(thread_id: int | str) -> None:
|
||||
"""Delete the sandbox for a conversation thread."""
|
||||
_evict_sandbox_cache(thread_id)
|
||||
|
||||
def _delete() -> None:
|
||||
client = _get_client()
|
||||
labels = {THREAD_LABEL_KEY: str(thread_id)}
|
||||
try:
|
||||
sandbox = client.find_one(labels=labels)
|
||||
except DaytonaError:
|
||||
logger.debug(
|
||||
"No sandbox to delete for thread %s (already removed)", thread_id
|
||||
)
|
||||
return
|
||||
try:
|
||||
client.delete(sandbox)
|
||||
logger.info("Sandbox deleted: %s", sandbox.id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to delete sandbox for thread %s",
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_delete)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local file persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_sandbox_files_dir() -> Path:
|
||||
return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))
|
||||
|
||||
|
||||
def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
|
||||
"""Map a sandbox-internal absolute path to a local filesystem path."""
|
||||
relative = sandbox_path.lstrip("/")
|
||||
base = (_get_sandbox_files_dir() / str(thread_id)).resolve()
|
||||
target = (base / relative).resolve()
|
||||
if not target.is_relative_to(base):
|
||||
raise ValueError(f"Path traversal blocked: {sandbox_path}")
|
||||
return target
|
||||
|
||||
|
||||
def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
|
||||
"""Read a previously-persisted sandbox file from local storage.
|
||||
|
||||
Returns the file bytes, or *None* if the file does not exist locally.
|
||||
"""
|
||||
local = _local_path_for(thread_id, sandbox_path)
|
||||
if local.is_file():
|
||||
return local.read_bytes()
|
||||
return None
|
||||
|
||||
|
||||
def delete_local_sandbox_files(thread_id: int | str) -> None:
|
||||
"""Remove all locally-persisted sandbox files for a thread."""
|
||||
thread_dir = _get_sandbox_files_dir() / str(thread_id)
|
||||
if thread_dir.is_dir():
|
||||
shutil.rmtree(thread_dir, ignore_errors=True)
|
||||
logger.info("Deleted local sandbox files for thread %s", thread_id)
|
||||
|
||||
|
||||
async def persist_and_delete_sandbox(
|
||||
thread_id: int | str,
|
||||
sandbox_file_paths: list[str],
|
||||
) -> None:
|
||||
"""Download sandbox files to local storage, then delete the sandbox.
|
||||
|
||||
Each file in *sandbox_file_paths* is downloaded from the Daytona
|
||||
sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/…``.
|
||||
Per-file errors are logged but do **not** prevent the sandbox from
|
||||
being deleted — freeing Daytona storage is the priority.
|
||||
"""
|
||||
_evict_sandbox_cache(thread_id)
|
||||
|
||||
def _persist_and_delete() -> None:
|
||||
client = _get_client()
|
||||
labels = {THREAD_LABEL_KEY: str(thread_id)}
|
||||
|
||||
try:
|
||||
sandbox = client.find_one(labels=labels)
|
||||
except Exception:
|
||||
logger.info(
|
||||
"No sandbox found for thread %s — nothing to persist", thread_id
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure the sandbox is running so we can download files
|
||||
if sandbox.state != SandboxState.STARTED:
|
||||
try:
|
||||
sandbox.start(timeout=60)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not start sandbox %s for file download — deleting anyway",
|
||||
sandbox.id,
|
||||
exc_info=True,
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
client.delete(sandbox)
|
||||
return
|
||||
|
||||
for path in sandbox_file_paths:
|
||||
try:
|
||||
content: bytes = sandbox.fs.download_file(path)
|
||||
local = _local_path_for(thread_id, path)
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_bytes(content)
|
||||
logger.info("Persisted sandbox file %s → %s", path, local)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist sandbox file %s for thread %s",
|
||||
path,
|
||||
thread_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
client.delete(sandbox)
|
||||
logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to delete sandbox %s after persistence",
|
||||
sandbox.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
await asyncio.to_thread(_persist_and_delete)
|
||||
|
|
@ -13,9 +13,9 @@ from deepagents import SubAgent
|
|||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.permissions import Rule, Ruleset
|
||||
from app.agents.multi_agent_chat.subagents.shared.spec import SurfSenseSubagentSpec
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
from .middleware_stack import build_kb_middleware
|
||||
from .prompts import load_description, load_readonly_system_prompt, load_system_prompt
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from typing import Any
|
|||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.shared.middleware.anthropic_cache import (
|
||||
build_anthropic_cache_mw,
|
||||
)
|
||||
|
|
@ -29,8 +31,6 @@ from app.agents.multi_agent_chat.shared.middleware.permissions import (
|
|||
build_permission_mw,
|
||||
)
|
||||
from app.agents.multi_agent_chat.shared.permissions import Ruleset
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
|
||||
def _kb_user_allowlist(
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.agents.multi_agent_chat.shared.filesystem_selection import FilesystemMode
|
||||
from app.agents.multi_agent_chat.subagents.shared.md_file_reader import read_md_file
|
||||
from app.agents.shared.filesystem_selection import FilesystemMode
|
||||
|
||||
|
||||
def load_system_prompt(filesystem_mode: FilesystemMode) -> str:
|
||||
|
|
|
|||
|
|
@ -14,8 +14,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any
|
||||
|
||||
from app.agents.shared.feature_flags import AgentFeatureFlags
|
||||
|
||||
from app.agents.multi_agent_chat.shared.feature_flags import AgentFeatureFlags
|
||||
from app.agents.multi_agent_chat.shared.middleware.permissions import (
|
||||
build_permission_mw,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue