SurfSense/surfsense_backend/app/agents/new_chat/sandbox.py
2026-04-14 01:43:30 -07:00

401 lines
13 KiB
Python

"""
Daytona sandbox provider for SurfSense deep agent.
Manages the lifecycle of sandboxed code execution environments.
Each conversation thread gets its own isolated sandbox instance
via the Daytona cloud API, identified by labels.
Files created during a session are persisted to local storage before
the sandbox is deleted so they remain downloadable after cleanup.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
import shutil
import threading
from pathlib import Path
from daytona import (
CreateSandboxFromSnapshotParams,
Daytona,
DaytonaConfig,
SandboxState,
)
from daytona.common.errors import DaytonaError
from deepagents.backends.protocol import ExecuteResponse
from langchain_daytona import DaytonaSandbox
logger = logging.getLogger(__name__)
class _TimeoutAwareSandbox(DaytonaSandbox):
    """``DaytonaSandbox`` variant whose ``execute`` takes a per-command timeout.

    The deepagents middleware passes a *timeout* keyword with each command.
    The upstream ``langchain-daytona`` ``execute()`` ignores timeout, which
    makes deepagents raise *"This sandbox backend does not support
    per-command timeout overrides"* on every first call. This subclass
    hands the value on to the Daytona SDK instead.
    """

    def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
        """Run *command* in the sandbox and wrap its outcome.

        Args:
            command: Command line handed to the sandbox process API.
            timeout: Per-command timeout; ``self._default_timeout`` is used
                when this is ``None``.

        Returns:
            An ``ExecuteResponse`` with the command output and exit code.
            ``truncated`` is always reported as ``False``.
        """
        if timeout is None:
            timeout = self._default_timeout
        outcome = self._sandbox.process.exec(command, timeout=timeout)
        return ExecuteResponse(
            output=outcome.result,
            exit_code=outcome.exit_code,
            truncated=False,
        )

    async def aexecute(
        self, command: str, *, timeout: int | None = None
    ) -> ExecuteResponse:  # type: ignore[override]
        """Async counterpart of :meth:`execute`, run in a worker thread."""
        return await asyncio.to_thread(self.execute, command, timeout=timeout)

    def download_file(self, path: str) -> bytes:
        """Return the bytes of the file at *path* in the sandbox filesystem."""
        sandbox_fs = self._sandbox.fs
        return sandbox_fs.download_file(path)
# Shared Daytona client, built on first use by ``_get_client`` while holding
# ``_client_lock``.
_daytona_client: Daytona | None = None
_client_lock = threading.Lock()
# In-process cache of sandbox wrappers, keyed by ``str(thread_id)``.
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
# One asyncio lock per thread key; ``_sandbox_locks_mu`` is held while this
# dict is read or extended.
_sandbox_locks: dict[str, asyncio.Lock] = {}
_sandbox_locks_mu = asyncio.Lock()
# Upload tracking: thread key -> {virtual path: ``modified_at`` value that was
# last uploaded to that thread's sandbox}.
_seeded_files: dict[str, dict[str, str]] = {}
# When the cache holds more entries than this, the first-inserted entry is
# evicted and its sandbox is scheduled for deletion.
_SANDBOX_CACHE_MAX_SIZE = 20
# Sandbox label key whose value is the thread id; used to find a thread's sandbox.
THREAD_LABEL_KEY = "surfsense_thread"
# Directory inside the sandbox that virtual-filesystem paths are appended to
# when files are uploaded.
SANDBOX_DOCUMENTS_ROOT = "/home/daytona/documents"
def is_sandbox_enabled() -> bool:
    """Report whether ``DAYTONA_SANDBOX_ENABLED`` is set to ``TRUE``.

    The comparison is case-insensitive; an unset variable counts as disabled.
    """
    flag = os.getenv("DAYTONA_SANDBOX_ENABLED", "FALSE")
    return flag.upper() == "TRUE"
def _get_client() -> Daytona:
    """Return the module-wide Daytona client, building it on the first call.

    Configuration is read from ``DAYTONA_API_KEY``, ``DAYTONA_API_URL`` and
    ``DAYTONA_TARGET``. The check-and-create step runs under ``_client_lock``
    so concurrent callers end up sharing one client.
    """
    global _daytona_client
    with _client_lock:
        if _daytona_client is not None:
            return _daytona_client
        env = os.environ
        _daytona_client = Daytona(
            DaytonaConfig(
                api_key=env.get("DAYTONA_API_KEY", ""),
                api_url=env.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
                target=env.get("DAYTONA_TARGET", "us"),
            )
        )
        return _daytona_client
def _sandbox_create_params(
    labels: dict[str, str],
) -> CreateSandboxFromSnapshotParams:
    """Build the creation parameters for a new sandbox carrying *labels*.

    The snapshot comes from ``DAYTONA_SNAPSHOT_ID``; an unset or empty value
    is passed on as ``None``.
    """
    snapshot_id = os.environ.get("DAYTONA_SNAPSHOT_ID")
    if not snapshot_id:
        snapshot_id = None
    # NOTE(review): the two intervals are presumably minutes — confirm
    # against the Daytona SDK documentation.
    options = {
        "language": "python",
        "labels": labels,
        "snapshot": snapshot_id,
        "network_block_all": True,
        "auto_stop_interval": 10,
        "auto_delete_interval": 60,
    }
    return CreateSandboxFromSnapshotParams(**options)
def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
    """Find the sandbox labelled with *thread_id*, or create one.

    An existing sandbox is started (or waited for) when it is not yet
    running. A sandbox in an unrecoverable state, or one that cannot be
    brought up, is deleted on a best-effort basis and replaced, so that at
    most one sandbox keeps carrying the thread's label.

    Args:
        thread_id: Conversation thread identifier used as the label value.

    Returns:
        ``(sandbox, is_new)``. *is_new* is True when a fresh sandbox was
        created (first time, or as a replacement after a failure).

    Raises:
        DaytonaError: If creating a sandbox fails.
    """
    client = _get_client()
    labels = {THREAD_LABEL_KEY: thread_id}

    # Only the lookup is guarded here. A failure while starting an existing
    # sandbox must not be reported as "no existing sandbox".
    try:
        sandbox = client.find_one(labels=labels)
    except DaytonaError:
        logger.info("No existing sandbox for thread %s — creating one", thread_id)
        sandbox = client.create(_sandbox_create_params(labels))
        logger.info("Created new sandbox: %s", sandbox.id)
        return _TimeoutAwareSandbox(sandbox=sandbox), True

    logger.info("Found existing sandbox %s (state=%s)", sandbox.id, sandbox.state)

    resumable_states = (
        SandboxState.STOPPED,
        SandboxState.STOPPING,
        SandboxState.ARCHIVED,
    )
    unrecoverable_states = (
        SandboxState.ERROR,
        SandboxState.BUILD_FAILED,
        SandboxState.DESTROYED,
    )

    needs_replacement = False
    if sandbox.state in unrecoverable_states:
        logger.warning(
            "Sandbox %s in unrecoverable state %s — creating a new one",
            sandbox.id,
            sandbox.state,
        )
        needs_replacement = True
    else:
        try:
            if sandbox.state in resumable_states:
                logger.info("Starting stopped sandbox %s", sandbox.id)
                sandbox.start(timeout=60)
                logger.info("Sandbox %s is now started", sandbox.id)
            elif sandbox.state != SandboxState.STARTED:
                sandbox.wait_for_sandbox_start(timeout=60)
        except DaytonaError:
            # Fall back to a fresh sandbox, but go through the replacement
            # path so the unusable one does not linger under the same label.
            logger.warning(
                "Sandbox %s could not be started — creating a new one",
                sandbox.id,
                exc_info=True,
            )
            needs_replacement = True

    if not needs_replacement:
        return _TimeoutAwareSandbox(sandbox=sandbox), False

    try:
        client.delete(sandbox)
    except Exception:
        logger.debug("Could not delete broken sandbox %s", sandbox.id, exc_info=True)
    sandbox = client.create(_sandbox_create_params(labels))
    logger.info("Created replacement sandbox: %s", sandbox.id)
    return _TimeoutAwareSandbox(sandbox=sandbox), True
async def _get_thread_lock(key: str) -> asyncio.Lock:
    """Fetch the asyncio lock registered for *key*, registering one if absent.

    The registry is only touched while ``_sandbox_locks_mu`` is held.
    """
    async with _sandbox_locks_mu:
        existing = _sandbox_locks.get(key)
        if existing is not None:
            return existing
        fresh = asyncio.Lock()
        _sandbox_locks[key] = fresh
        return fresh
async def get_or_create_sandbox(
    thread_id: int | str,
) -> tuple[_TimeoutAwareSandbox, bool]:
    """Return the sandbox for a conversation thread, creating it on demand.

    Sandboxes are cached in-process under ``str(thread_id)``, so later
    messages in the same conversation get the cached object without an API
    call. The lookup and creation run under a per-thread asyncio lock so
    concurrent requests for one thread do not create duplicates.

    When the cache grows past ``_SANDBOX_CACHE_MAX_SIZE`` the first-inserted
    entry of another thread is dropped, together with its file tracking,
    and its sandbox is scheduled for background deletion.

    Returns:
        Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox
        was created, signalling that file tracking should be reset. A cache
        hit always reports ``False``.
    """
    key = str(thread_id)
    thread_lock = await _get_thread_lock(key)
    async with thread_lock:
        hit = _sandbox_cache.get(key)
        if hit is not None:
            logger.info("Reusing cached sandbox for thread %s", key)
            return hit, False

        # The Daytona calls block, so they run in a worker thread.
        sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
        _sandbox_cache[key] = sandbox

        if len(_sandbox_cache) <= _SANDBOX_CACHE_MAX_SIZE:
            return sandbox, is_new

        # dicts iterate in insertion order, so the first key is the oldest.
        oldest_key = next(iter(_sandbox_cache))
        if oldest_key == key:
            return sandbox, is_new

        evicted = _sandbox_cache.pop(oldest_key, None)
        _seeded_files.pop(oldest_key, None)
        logger.debug("Evicted sandbox cache entry: %s", oldest_key)
        if evicted is not None:
            _schedule_sandbox_delete(evicted)
        return sandbox, is_new
def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None:
    """Best-effort background deletion of an evicted sandbox.

    The deletion is handed to the running event loop's default executor.
    If it cannot be scheduled there (for example, no loop is running in the
    calling thread), it runs on a daemon thread instead of being skipped.
    This function never waits for the deletion and does not raise when
    scheduling fails.

    Args:
        sandbox: Wrapper around the Daytona sandbox to delete.
    """

    def _delete() -> None:
        try:
            client = _get_client()
            client.delete(sandbox._sandbox)
            logger.info("Deleted evicted sandbox: %s", sandbox._sandbox.id)
        except Exception:
            logger.debug("Could not delete evicted sandbox", exc_info=True)

    try:
        asyncio.get_running_loop().run_in_executor(None, _delete)
        return
    except RuntimeError:
        logger.debug(
            "Could not schedule sandbox deletion on the event loop; "
            "using a background thread",
            exc_info=True,
        )

    try:
        threading.Thread(
            target=_delete, name="sandbox-evict-delete", daemon=True
        ).start()
    except RuntimeError:
        # Thread creation can fail during interpreter shutdown. Deletion is
        # best-effort, so record it and move on.
        logger.debug("Could not start sandbox deletion thread", exc_info=True)
async def sync_files_to_sandbox(
    thread_id: int | str,
    files: dict[str, dict],
    sandbox: _TimeoutAwareSandbox,
    is_new: bool,
) -> None:
    """Upload new or changed virtual-filesystem files to the sandbox.

    Compares *files* (from ``state["files"]``) against the ``_seeded_files``
    tracking dict and uploads only entries whose ``modified_at`` differs
    from the tracked value. When *is_new* is True the tracking is reset so
    every file is re-uploaded.

    Args:
        thread_id: Conversation thread whose tracking entry is used.
        files: Mapping of virtual path to file data. Each value is read for
            ``"modified_at"`` and ``"content"``; content may be a sequence
            of lines or a single string.
        sandbox: Sandbox that receives the uploads.
        is_new: True when *sandbox* was freshly created.
    """
    key = str(thread_id)
    if is_new:
        _seeded_files.pop(key, None)
    tracked = _seeded_files.get(key, {})

    to_upload: list[tuple[str, bytes]] = []
    for vpath, fdata in files.items():
        if tracked.get(vpath) == fdata.get("modified_at", ""):
            continue
        raw_content = fdata.get("content", [])
        # Joining a plain string would put a newline between every
        # character, so only join when the content is a sequence of lines.
        if isinstance(raw_content, str):
            text = raw_content
        else:
            text = "\n".join(raw_content)
        # Keep a separator between the root and the virtual path even when
        # the path has no leading slash.
        suffix = vpath if vpath.startswith("/") else f"/{vpath}"
        to_upload.append((f"{SANDBOX_DOCUMENTS_ROOT}{suffix}", text.encode("utf-8")))

    if not to_upload:
        return

    # The upload is a blocking SDK call, so it runs in a worker thread.
    await asyncio.to_thread(sandbox.upload_files, to_upload)

    # Record the state of every file passed in only after the upload
    # succeeded, so a failed upload is retried on the next call.
    new_tracked = dict(tracked)
    new_tracked.update(
        (vpath, fdata.get("modified_at", "")) for vpath, fdata in files.items()
    )
    _seeded_files[key] = new_tracked
    logger.info("Synced %d file(s) to sandbox for thread %s", len(to_upload), key)
def _evict_sandbox_cache(thread_id: int | str) -> None:
    """Drop the cached sandbox and the file tracking kept for *thread_id*."""
    key = str(thread_id)
    for registry in (_sandbox_cache, _seeded_files):
        registry.pop(key, None)
async def delete_sandbox(thread_id: int | str) -> None:
    """Delete the sandbox for a conversation thread.

    The in-process cache entry is dropped first. The sandbox is then looked
    up by its thread label and deleted in a worker thread. A missing
    sandbox or a failed deletion is logged, not raised.
    """
    _evict_sandbox_cache(thread_id)

    def _remove_remote() -> None:
        client = _get_client()
        try:
            found = client.find_one(labels={THREAD_LABEL_KEY: str(thread_id)})
        except DaytonaError:
            logger.debug(
                "No sandbox to delete for thread %s (already removed)", thread_id
            )
        else:
            try:
                client.delete(found)
                logger.info("Sandbox deleted: %s", found.id)
            except Exception:
                logger.warning(
                    "Failed to delete sandbox for thread %s",
                    thread_id,
                    exc_info=True,
                )

    await asyncio.to_thread(_remove_remote)
# ---------------------------------------------------------------------------
# Local file persistence
# ---------------------------------------------------------------------------
def _get_sandbox_files_dir() -> Path:
    """Return the local directory used for persisted sandbox files.

    Taken from ``SANDBOX_FILES_DIR``; falls back to the relative path
    ``sandbox_files`` when the variable is unset.
    """
    configured = os.getenv("SANDBOX_FILES_DIR", "sandbox_files")
    return Path(configured)
def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
    """Map a sandbox-internal absolute path to a local filesystem path.

    The result lives under ``{SANDBOX_FILES_DIR}/{thread_id}/``. Both the
    thread directory and the final path are resolved and checked for
    containment, so neither *thread_id* nor *sandbox_path* can point
    outside the area reserved for that thread.

    Args:
        thread_id: Conversation thread that owns the file.
        sandbox_path: Path of the file inside the sandbox; leading slashes
            are dropped before it is joined to the thread directory.

    Returns:
        The resolved local path for the file.

    Raises:
        ValueError: If *thread_id* does not name a directory strictly below
            the storage root, or *sandbox_path* resolves outside the
            thread directory.
    """
    root = _get_sandbox_files_dir().resolve()
    base = (root / str(thread_id)).resolve()
    # An empty id, "..", or an absolute path would make ``base`` the storage
    # root itself or a directory outside it.
    if base == root or not base.is_relative_to(root):
        raise ValueError(f"Invalid thread id for sandbox files: {thread_id!r}")

    relative = sandbox_path.lstrip("/")
    target = (base / relative).resolve()
    if not target.is_relative_to(base):
        raise ValueError(f"Path traversal blocked: {sandbox_path}")
    return target
def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
    """Load a sandbox file that was earlier saved to local storage.

    Returns:
        The file's bytes, or ``None`` when no regular file exists at the
        mapped local path.
    """
    candidate = _local_path_for(thread_id, sandbox_path)
    return candidate.read_bytes() if candidate.is_file() else None
def delete_local_sandbox_files(thread_id: int | str) -> None:
    """Remove all locally-persisted sandbox files for a thread.

    Only a directory strictly below the storage root is ever removed. A
    *thread_id* that resolves to the root itself or outside it (an empty
    string or ``".."``, for instance) is refused and logged. Errors raised
    while removing the tree are ignored.
    """
    root = _get_sandbox_files_dir()
    thread_dir = root / str(thread_id)

    resolved_root = root.resolve()
    resolved_dir = thread_dir.resolve()
    if resolved_dir == resolved_root or not resolved_dir.is_relative_to(resolved_root):
        # Deleting here would wipe other threads' files or unrelated data.
        logger.warning(
            "Refusing to delete local sandbox files for invalid thread id %r",
            thread_id,
        )
        return

    if thread_dir.is_dir():
        shutil.rmtree(thread_dir, ignore_errors=True)
        logger.info("Deleted local sandbox files for thread %s", thread_id)
async def persist_and_delete_sandbox(
    thread_id: int | str,
    sandbox_file_paths: list[str],
) -> None:
    """Download sandbox files to local storage, then delete the sandbox.

    Each file in *sandbox_file_paths* is downloaded from the Daytona
    sandbox and saved under ``{SANDBOX_FILES_DIR}/{thread_id}/…``.
    Per-file errors are logged but do **not** prevent the sandbox from
    being deleted — freeing Daytona storage is the priority.

    Args:
        thread_id: Conversation thread whose sandbox is persisted and removed.
        sandbox_file_paths: Paths inside the sandbox to save locally.
    """
    _evict_sandbox_cache(thread_id)

    def _persist_and_delete() -> None:
        client = _get_client()
        labels = {THREAD_LABEL_KEY: str(thread_id)}
        try:
            sandbox = client.find_one(labels=labels)
        except Exception:
            logger.info(
                "No sandbox found for thread %s — nothing to persist", thread_id
            )
            return

        # Ensure the sandbox is running so we can download files.
        if sandbox.state != SandboxState.STARTED:
            try:
                sandbox.start(timeout=60)
            except Exception:
                logger.warning(
                    "Could not start sandbox %s for file download — deleting anyway",
                    sandbox.id,
                    exc_info=True,
                )
                with contextlib.suppress(Exception):
                    client.delete(sandbox)
                return

        for path in sandbox_file_paths:
            try:
                # Work out the destination first so a rejected path costs
                # no download.
                local = _local_path_for(thread_id, path)
                content: bytes = sandbox.fs.download_file(path)
                local.parent.mkdir(parents=True, exist_ok=True)
                local.write_bytes(content)
                logger.info("Persisted sandbox file %s -> %s", path, local)
            except Exception:
                logger.warning(
                    "Failed to persist sandbox file %s for thread %s",
                    path,
                    thread_id,
                    exc_info=True,
                )

        try:
            client.delete(sandbox)
            logger.info("Sandbox deleted after file persistence: %s", sandbox.id)
        except Exception:
            logger.warning(
                "Failed to delete sandbox %s after persistence",
                sandbox.id,
                exc_info=True,
            )

    # Everything above is blocking SDK and filesystem work, so it runs in a
    # worker thread.
    await asyncio.to_thread(_persist_and_delete)