SurfSense/surfsense_backend/app/agents/new_chat/sandbox.py

402 lines
13 KiB
Python
Raw Normal View History

2026-02-24 16:36:11 -08:00
"""
Daytona sandbox provider for SurfSense deep agent.
2026-02-24 16:36:11 -08:00
Manages the lifecycle of sandboxed code execution environments.
Each conversation thread gets its own isolated sandbox instance
via the Daytona cloud API, identified by labels.
Files created during a session are persisted to local storage before
the sandbox is deleted so they remain downloadable after cleanup.
2026-02-24 16:36:11 -08:00
"""
from __future__ import annotations
import asyncio
import contextlib
2026-02-24 16:36:11 -08:00
import logging
import os
import shutil
import threading
from pathlib import Path
2026-02-24 16:36:11 -08:00
2026-02-25 01:50:28 -08:00
from daytona import (
CreateSandboxFromSnapshotParams,
Daytona,
DaytonaConfig,
SandboxState,
)
from daytona.common.errors import DaytonaError
from deepagents.backends.protocol import ExecuteResponse
from langchain_daytona import DaytonaSandbox
# Module logger for sandbox lifecycle events (lookup, start, create, delete, file sync/persist).
logger = logging.getLogger(__name__)
class _TimeoutAwareSandbox(DaytonaSandbox):
    """DaytonaSandbox variant that honours deepagents' per-command timeout.

    The deepagents middleware calls ``execute`` with a ``timeout`` keyword.
    Upstream ``langchain-daytona`` ignores that parameter, so deepagents
    raises *"This sandbox backend does not support per-command timeout
    overrides"* on every first call. This subclass passes the timeout
    straight through to the Daytona SDK's ``process.exec``.
    """

    def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
        """Run *command* in the sandbox and wrap the SDK result.

        When *timeout* is ``None`` the instance's ``_default_timeout`` is used
        (presumably initialised by the ``DaytonaSandbox`` base class — not
        defined here). The response is never flagged as truncated.
        """
        effective_timeout = self._default_timeout if timeout is None else timeout
        sdk_result = self._sandbox.process.exec(command, timeout=effective_timeout)
        return ExecuteResponse(
            output=sdk_result.result,
            exit_code=sdk_result.exit_code,
            truncated=False,
        )

    async def aexecute(  # type: ignore[override]
        self, command: str, *, timeout: int | None = None
    ) -> ExecuteResponse:
        """Async variant of :meth:`execute`, run on a worker thread."""
        return await asyncio.to_thread(self.execute, command, timeout=timeout)

    def download_file(self, path: str) -> bytes:
        """Fetch the raw bytes of *path* from the sandbox filesystem."""
        return self._sandbox.fs.download_file(path)
# Shared Daytona client, built lazily by _get_client() while holding _client_lock.
_daytona_client: Daytona | None = None
_client_lock = threading.Lock()
# In-process cache: thread id (as str) -> wrapped sandbox. Its size is capped
# at _SANDBOX_CACHE_MAX_SIZE by get_or_create_sandbox().
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
# Per-thread asyncio locks that serialise sandbox lookup/creation.
# NOTE(review): no code visible here ever removes entries from this dict.
_sandbox_locks: dict[str, asyncio.Lock] = {}
# Guards creation of new entries in _sandbox_locks.
_sandbox_locks_mu = asyncio.Lock()
# thread id -> {virtual file path: "modified_at" recorded at last sync}; lets
# sync_files_to_sandbox() skip files that have not changed.
_seeded_files: dict[str, dict[str, str]] = {}
# Maximum number of sandboxes kept in _sandbox_cache.
_SANDBOX_CACHE_MAX_SIZE = 20
# Daytona label key whose value is the conversation thread id; used to find a thread's sandbox.
THREAD_LABEL_KEY = "surfsense_thread"
# Directory inside the sandbox that virtual-filesystem files are uploaded under.
SANDBOX_DOCUMENTS_ROOT = "/home/daytona/documents"
def is_sandbox_enabled() -> bool:
    """Report whether the Daytona sandbox feature is switched on.

    Reads ``DAYTONA_SANDBOX_ENABLED``. Only a case-insensitive ``"true"``
    enables it; an unset variable counts as disabled.
    """
    flag = os.environ.get("DAYTONA_SANDBOX_ENABLED", "FALSE")
    return flag.upper() == "TRUE"
def _get_client() -> Daytona:
    """Return the shared Daytona client, constructing it on first use.

    Construction happens under ``_client_lock`` so worker threads (this is
    reached from code run via ``asyncio.to_thread``) end up sharing one
    instance. Settings come from ``DAYTONA_API_KEY``, ``DAYTONA_API_URL``
    and ``DAYTONA_TARGET``.
    """
    global _daytona_client
    with _client_lock:
        if _daytona_client is not None:
            return _daytona_client
        env = os.environ
        _daytona_client = Daytona(
            DaytonaConfig(
                api_key=env.get("DAYTONA_API_KEY", ""),
                api_url=env.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
                target=env.get("DAYTONA_TARGET", "us"),
            )
        )
        return _daytona_client
def _sandbox_create_params(
    labels: dict[str, str],
) -> CreateSandboxFromSnapshotParams:
    """Build the creation parameters shared by every SurfSense sandbox.

    ``DAYTONA_SNAPSHOT_ID`` selects the snapshot; an unset or empty value is
    passed through as ``None``. The sandbox is created for Python, has
    ``network_block_all`` enabled, and carries *labels* so it can be looked
    up again by thread.
    """
    settings = dict(
        language="python",
        labels=labels,
        snapshot=os.environ.get("DAYTONA_SNAPSHOT_ID") or None,
        network_block_all=True,
        # Idle auto-stop and post-stop auto-delete intervals; presumably in
        # minutes per the Daytona SDK — TODO confirm.
        auto_stop_interval=10,
        auto_delete_interval=60,
    )
    return CreateSandboxFromSnapshotParams(**settings)
def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
    """Find an existing sandbox for *thread_id*, or create a new one.

    Blocking (synchronous Daytona SDK calls); callers run it via
    ``asyncio.to_thread``.

    The sandbox is looked up by the ``THREAD_LABEL_KEY`` label. A found
    sandbox is handled by state:

    - stopped, stopping or archived: it is started;
    - any other state except started: the code waits for it to start;
    - error, build-failed or destroyed: it is replaced.

    A sandbox that fails to start is also replaced. Replacement is a
    best-effort delete followed by a fresh create.

    Returns:
        A tuple of (sandbox, is_new), where *is_new* is True when a fresh
        sandbox was created (first time or replacement after failure).

    Errors raised by ``client.create`` propagate to the caller.
    """
    client = _get_client()
    labels = {THREAD_LABEL_KEY: thread_id}

    # Only the lookup decides whether a sandbox exists. A DaytonaError from
    # start()/wait_for_sandbox_start()/create() must not land here.
    # Otherwise a second sandbox with the same label would be created (or a
    # failed create() blindly retried).
    try:
        sandbox = client.find_one(labels=labels)
    except DaytonaError:
        logger.info("No existing sandbox for thread %s — creating one", thread_id)
        sandbox = client.create(_sandbox_create_params(labels))
        logger.info("Created new sandbox: %s", sandbox.id)
        return _TimeoutAwareSandbox(sandbox=sandbox), True

    logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)

    needs_replacement = False
    if sandbox.state in (
        SandboxState.ERROR,
        SandboxState.BUILD_FAILED,
        SandboxState.DESTROYED,
    ):
        logger.warning(
            "Sandbox %s in unrecoverable state %s — creating a new one",
            sandbox.id,
            sandbox.state,
        )
        needs_replacement = True
    else:
        try:
            if sandbox.state in (
                SandboxState.STOPPED,
                SandboxState.STOPPING,
                SandboxState.ARCHIVED,
            ):
                logger.info("Starting stopped sandbox %s", sandbox.id)
                sandbox.start(timeout=60)
                logger.info("Sandbox %s is now started", sandbox.id)
            elif sandbox.state != SandboxState.STARTED:
                sandbox.wait_for_sandbox_start(timeout=60)
        except DaytonaError:
            logger.warning(
                "Could not start sandbox %s — creating a new one",
                sandbox.id,
                exc_info=True,
            )
            needs_replacement = True

    if not needs_replacement:
        return _TimeoutAwareSandbox(sandbox=sandbox), False

    # Remove the broken sandbox first so only one sandbox carries this
    # thread's label; a failed delete is tolerated (best effort).
    try:
        client.delete(sandbox)
    except Exception:
        logger.debug(
            "Could not delete broken sandbox %s", sandbox.id, exc_info=True
        )
    sandbox = client.create(_sandbox_create_params(labels))
    logger.info("Created replacement sandbox: %s", sandbox.id)
    return _TimeoutAwareSandbox(sandbox=sandbox), True
async def _get_thread_lock(key: str) -> asyncio.Lock:
    """Return the asyncio lock dedicated to *key*, creating it on first request.

    Access to the lock registry is serialised by ``_sandbox_locks_mu``.
    """
    async with _sandbox_locks_mu:
        try:
            return _sandbox_locks[key]
        except KeyError:
            fresh_lock = asyncio.Lock()
            _sandbox_locks[key] = fresh_lock
            return fresh_lock
async def get_or_create_sandbox(
    thread_id: int | str,
) -> tuple[_TimeoutAwareSandbox, bool]:
    """Get or create a sandbox for a conversation thread.

    Uses an in-process cache keyed by thread_id, so subsequent messages in
    the same conversation reuse the sandbox object without an API call.
    A per-thread async lock prevents duplicate sandbox creation from
    concurrent requests.

    The cache holds at most ``_SANDBOX_CACHE_MAX_SIZE`` entries, kept in
    least-recently-used order: a cache hit moves the thread to the back. On
    overflow, the thread idle for longest is evicted. Eviction also drops
    its seeded-file tracking and schedules best-effort deletion of its
    Daytona sandbox.

    Args:
        thread_id: Conversation thread identifier; normalised to ``str``.

    Returns:
        A tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox
        was created, signalling that file tracking should be reset.
    """
    key = str(thread_id)
    lock = await _get_thread_lock(key)
    async with lock:
        cached = _sandbox_cache.pop(key, None)
        if cached is not None:
            # Re-insert so this thread becomes the most recently used entry.
            # Dicts keep insertion order, so the first key is always the
            # least recently used one when eviction runs below. Without this,
            # an actively used sandbox could be evicted and remotely deleted.
            _sandbox_cache[key] = cached
            # NOTE(review): the cached sandbox's remote state is not
            # re-checked; confirm behaviour once Daytona auto-stops it.
            logger.info("Reusing cached sandbox for thread %s", key)
            return cached, False

        sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
        _sandbox_cache[key] = sandbox

        if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
            oldest_key = next(iter(_sandbox_cache))
            if oldest_key != key:
                evicted = _sandbox_cache.pop(oldest_key, None)
                _seeded_files.pop(oldest_key, None)
                logger.debug("Evicted sandbox cache entry: %s", oldest_key)
                if evicted is not None:
                    _schedule_sandbox_delete(evicted)
        return sandbox, is_new
def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None:
    """Best-effort background deletion of an evicted sandbox.

    The blocking Daytona ``delete`` call is handed to the running event
    loop's default executor. If there is no running loop, or its executor
    refuses work, a daemon thread is used instead, so the remote sandbox is
    not silently left behind. Never raises; failures are logged.
    """

    def _delete() -> None:
        try:
            client = _get_client()
            client.delete(sandbox._sandbox)
            logger.info("Deleted evicted sandbox: %s", sandbox._sandbox.id)
        except Exception:
            # A failed delete leaves a remote sandbox behind. Report it at
            # warning level, like delete_sandbox() does.
            logger.warning("Could not delete evicted sandbox", exc_info=True)

    try:
        loop = asyncio.get_running_loop()
        loop.run_in_executor(None, _delete)
        return
    except RuntimeError:
        # No running loop, or its executor has been shut down.
        pass

    try:
        threading.Thread(
            target=_delete, name="sandbox-evict-delete", daemon=True
        ).start()
    except RuntimeError:
        # e.g. new threads cannot be started during interpreter shutdown.
        logger.warning(
            "Could not schedule deletion of evicted sandbox %s",
            sandbox._sandbox.id,
        )
async def sync_files_to_sandbox(
    thread_id: int | str,
    files: dict[str, dict],
    sandbox: _TimeoutAwareSandbox,
    is_new: bool,
) -> None:
    """Upload new or changed virtual-filesystem files to the sandbox.

    Compares *files* (from ``state["files"]``) against the ``_seeded_files``
    tracking dict and uploads only what has changed. When *is_new* is True,
    the tracking is reset so every file is re-uploaded.

    Args:
        thread_id: Conversation thread identifier; normalised to ``str``.
        files: Mapping of virtual path to file record. ``"modified_at"`` is
            the change marker. ``"content"`` is joined with newlines when it
            is a list of lines, and used as-is when it is a ``str``.
        sandbox: Target sandbox; files are written under
            ``SANDBOX_DOCUMENTS_ROOT``.
        is_new: True when *sandbox* was freshly created.

    Files whose upload response reports an error are left untracked, so
    the next call retries them instead of treating them as synced.
    """
    key = str(thread_id)
    if is_new:
        _seeded_files.pop(key, None)
    tracked = _seeded_files.get(key, {})

    to_upload: list[tuple[str, bytes]] = []
    vpath_by_target: dict[str, str] = {}
    for vpath, fdata in files.items():
        if tracked.get(vpath) == fdata.get("modified_at", ""):
            continue
        raw = fdata.get("content") or []
        content = raw if isinstance(raw, str) else "\n".join(raw)
        # Join with exactly one "/" whether or not vpath is absolute.
        sandbox_path = f"{SANDBOX_DOCUMENTS_ROOT}/{vpath.lstrip('/')}"
        to_upload.append((sandbox_path, content.encode("utf-8")))
        vpath_by_target[sandbox_path] = vpath

    if not to_upload:
        return

    responses = await asyncio.to_thread(sandbox.upload_files, to_upload)

    # Per-file failures may be reported in the responses rather than raised;
    # collect the virtual paths of those so they are not marked as synced.
    failed = {
        vpath_by_target.get(getattr(resp, "path", None))
        for resp in responses or ()
        if getattr(resp, "error", None)
    }
    failed.discard(None)

    new_tracked = dict(tracked)
    for vpath, fdata in files.items():
        if vpath in failed:
            new_tracked.pop(vpath, None)
        else:
            new_tracked[vpath] = fdata.get("modified_at", "")
    _seeded_files[key] = new_tracked

    if failed:
        logger.warning(
            "Failed to sync %d file(s) to sandbox for thread %s: %s",
            len(failed),
            key,
            sorted(failed),
        )
    logger.info(
        "Synced %d file(s) to sandbox for thread %s", len(to_upload) - len(failed), key
    )
def _evict_sandbox_cache(thread_id: int | str) -> None:
    """Forget the cached sandbox and seeded-file tracking for *thread_id*."""
    key = str(thread_id)
    for registry in (_sandbox_cache, _seeded_files):
        registry.pop(key, None)
async def delete_sandbox(thread_id: int | str) -> None:
    """Delete the Daytona sandbox that belongs to a conversation thread.

    In-process cache entries for the thread are dropped first. The blocking
    lookup and delete then run on a worker thread. A lookup ``DaytonaError``
    is treated as "already removed"; a failed delete is logged, not raised.
    """
    _evict_sandbox_cache(thread_id)

    def _remove_remote() -> None:
        client = _get_client()
        try:
            target = client.find_one(labels={THREAD_LABEL_KEY: str(thread_id)})
        except DaytonaError:
            logger.debug(
                "No sandbox to delete for thread %s (already removed)", thread_id
            )
            return

        try:
            client.delete(target)
            logger.info("Sandbox deleted: %s", target.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox for thread %s",
                thread_id,
                exc_info=True,
            )

    await asyncio.to_thread(_remove_remote)
# ---------------------------------------------------------------------------
# Local file persistence
# ---------------------------------------------------------------------------
def _get_sandbox_files_dir() -> Path:
return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))
def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
"""Map a sandbox-internal absolute path to a local filesystem path."""
relative = sandbox_path.lstrip("/")
base = (_get_sandbox_files_dir() / str(thread_id)).resolve()
target = (base / relative).resolve()
if not target.is_relative_to(base):
raise ValueError(f"Path traversal blocked: {sandbox_path}")
return target
def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
"""Read a previously-persisted sandbox file from local storage.
Returns the file bytes, or *None* if the file does not exist locally.
"""
local = _local_path_for(thread_id, sandbox_path)
if local.is_file():
return local.read_bytes()
return None
def delete_local_sandbox_files(thread_id: int | str) -> None:
"""Remove all locally-persisted sandbox files for a thread."""
thread_dir = _get_sandbox_files_dir() / str(thread_id)
if thread_dir.is_dir():
shutil.rmtree(thread_dir, ignore_errors=True)
logger.info("Deleted local sandbox files for thread %s", thread_id)
async def persist_and_delete_sandbox(
    thread_id: int | str,
    sandbox_file_paths: list[str],
) -> None:
    """Download sandbox files to local storage, then delete the sandbox.

    Each file in *sandbox_file_paths* is downloaded from the Daytona
    sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/``.
    Per-file errors are logged but do **not** prevent the sandbox from
    being deleted — freeing Daytona storage is the priority.

    Sandbox state is handled like ``_find_or_create``:

    - stopped, stopping or archived: the sandbox is started;
    - error, build-failed or destroyed: it is deleted without persisting;
    - any other non-started state: the code waits for it to come up instead
      of issuing a second ``start()``.

    If the sandbox cannot be brought up, it is deleted without persisting.
    The blocking work runs on a worker thread.
    """
    _evict_sandbox_cache(thread_id)

    def _persist_and_delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}
        try:
            sandbox = client.find_one(labels=labels)
        except Exception:
            logger.info(
                "No sandbox found for thread %s — nothing to persist", thread_id
            )
            return

        if sandbox.state in (
            SandboxState.ERROR,
            SandboxState.BUILD_FAILED,
            SandboxState.DESTROYED,
        ):
            logger.warning(
                "Sandbox %s in unrecoverable state %s — deleting without persisting files",
                sandbox.id,
                sandbox.state,
            )
            with contextlib.suppress(Exception):
                client.delete(sandbox)
            return

        # Ensure the sandbox is running so we can download files.
        try:
            if sandbox.state in (
                SandboxState.STOPPED,
                SandboxState.STOPPING,
                SandboxState.ARCHIVED,
            ):
                sandbox.start(timeout=60)
            elif sandbox.state != SandboxState.STARTED:
                # Already transitioning (e.g. starting). A second start()
                # could fail and needlessly skip persistence.
                sandbox.wait_for_sandbox_start(timeout=60)
        except Exception:
            logger.warning(
                "Could not start sandbox %s for file download — deleting anyway",
                sandbox.id,
                exc_info=True,
            )
            with contextlib.suppress(Exception):
                client.delete(sandbox)
            return

        for path in sandbox_file_paths:
            try:
                # Validate the destination before spending a download on it.
                local = _local_path_for(thread_id, path)
                content: bytes = sandbox.fs.download_file(path)
                local.parent.mkdir(parents=True, exist_ok=True)
                local.write_bytes(content)
                logger.info("Persisted sandbox file %s → %s", path, local)
            except Exception:
                logger.warning(
                    "Failed to persist sandbox file %s for thread %s",
                    path,
                    thread_id,
                    exc_info=True,
                )

        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox %s after persistence",
                sandbox.id,
                exc_info=True,
            )

    await asyncio.to_thread(_persist_and_delete)