diff --git a/surfsense_backend/app/agents/new_chat/sandbox.py b/surfsense_backend/app/agents/new_chat/sandbox.py index 8b634993b..79947de2b 100644 --- a/surfsense_backend/app/agents/new_chat/sandbox.py +++ b/surfsense_backend/app/agents/new_chat/sandbox.py @@ -42,7 +42,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox): """ def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse: - t = timeout if timeout is not None else self._timeout + t = timeout if timeout is not None else self._default_timeout result = self._sandbox.process.exec(command, timeout=t) return ExecuteResponse( output=result.result, @@ -58,8 +58,10 @@ class _TimeoutAwareSandbox(DaytonaSandbox): _daytona_client: Daytona | None = None _sandbox_cache: dict[str, _TimeoutAwareSandbox] = {} +_seeded_files: dict[str, dict[str, str]] = {} _SANDBOX_CACHE_MAX_SIZE = 20 THREAD_LABEL_KEY = "surfsense_thread" +SANDBOX_DOCUMENTS_ROOT = "/home/daytona/documents" def is_sandbox_enabled() -> bool: @@ -78,14 +80,29 @@ def _get_client() -> Daytona: return _daytona_client -def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox: +def _sandbox_create_params( + labels: dict[str, str], +) -> CreateSandboxFromSnapshotParams: + snapshot_id = os.environ.get("DAYTONA_SNAPSHOT_ID") or None + return CreateSandboxFromSnapshotParams( + language="python", + labels=labels, + snapshot=snapshot_id, + network_block_all=True, + auto_stop_interval=10, + auto_delete_interval=60, + ) + + +def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]: """Find an existing sandbox for *thread_id*, or create a new one. - If an existing sandbox is found but is stopped/archived, it will be - restarted automatically before returning. + Returns a tuple of (sandbox, is_new) where *is_new* is True when a + fresh sandbox was created (first time or replacement after failure). """ client = _get_client() labels = {THREAD_LABEL_KEY: thread_id} + is_new = False try: sandbox = client.find_one(labels=labels) @@ -109,41 +126,39 @@ def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox: sandbox.id, sandbox.state, ) - sandbox = client.create( - CreateSandboxFromSnapshotParams(language="python", labels=labels) - ) + sandbox = client.create(_sandbox_create_params(labels)) + is_new = True logger.info("Created replacement sandbox: %s", sandbox.id) elif sandbox.state != SandboxState.STARTED: sandbox.wait_for_sandbox_start(timeout=60) except Exception: logger.info("No existing sandbox for thread %s — creating one", thread_id) - sandbox = client.create( - CreateSandboxFromSnapshotParams(language="python", labels=labels) - ) + sandbox = client.create(_sandbox_create_params(labels)) + is_new = True logger.info("Created new sandbox: %s", sandbox.id) - return _TimeoutAwareSandbox(sandbox=sandbox) + return _TimeoutAwareSandbox(sandbox=sandbox), is_new -async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox: +async def get_or_create_sandbox( + thread_id: int | str, +) -> tuple[_TimeoutAwareSandbox, bool]: """Get or create a sandbox for a conversation thread. Uses an in-process cache keyed by thread_id so subsequent messages in the same conversation reuse the sandbox object without an API call. - Args: - thread_id: The conversation thread identifier. - Returns: - DaytonaSandbox connected to the sandbox. + Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox + was created, signalling that file tracking should be reset. """ key = str(thread_id) cached = _sandbox_cache.get(key) if cached is not None: logger.info("Reusing cached sandbox for thread %s", key) - return cached - sandbox = await asyncio.to_thread(_find_or_create, key) + return cached, False + sandbox, is_new = await asyncio.to_thread(_find_or_create, key) _sandbox_cache[key] = sandbox if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE: @@ -151,12 +166,60 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox: _sandbox_cache.pop(oldest_key, None) logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key) - return sandbox + return sandbox, is_new + + +async def sync_files_to_sandbox( + thread_id: int | str, + files: dict[str, dict], + sandbox: _TimeoutAwareSandbox, + is_new: bool, +) -> None: + """Upload new or changed virtual-filesystem files to the sandbox. + + Compares *files* (from ``state["files"]``) against the ``_seeded_files`` + tracking dict and uploads only what has changed. When *is_new* is True + the tracking is reset so every file is re-uploaded. + """ + key = str(thread_id) + if is_new: + _seeded_files.pop(key, None) + + tracked = _seeded_files.get(key, {}) + to_upload: list[tuple[str, bytes]] = [] + + for vpath, fdata in files.items(): + modified_at = fdata.get("modified_at", "") + if tracked.get(vpath) == modified_at: + continue + content = "\n".join(fdata.get("content", [])) + sandbox_path = f"{SANDBOX_DOCUMENTS_ROOT}{vpath}" + to_upload.append((sandbox_path, content.encode("utf-8"))) + + if not to_upload: + return + + def _upload() -> None: + sandbox.upload_files(to_upload) + + await asyncio.to_thread(_upload) + + new_tracked = dict(tracked) + for vpath, fdata in files.items(): + new_tracked[vpath] = fdata.get("modified_at", "") + _seeded_files[key] = new_tracked + logger.info("Synced %d file(s) to sandbox for thread %s", len(to_upload), key) + + +def _evict_sandbox_cache(thread_id: int | str) -> None: + key = str(thread_id) + _sandbox_cache.pop(key, None) + _seeded_files.pop(key, None) async def delete_sandbox(thread_id: int | str) -> None: """Delete the sandbox for a conversation thread.""" - _sandbox_cache.pop(str(thread_id), None) + _evict_sandbox_cache(thread_id) def _delete() -> None: client = _get_client() @@ -193,7 +256,11 @@ def _get_sandbox_files_dir() -> Path: def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path: """Map a sandbox-internal absolute path to a local filesystem path.""" relative = sandbox_path.lstrip("/") - return _get_sandbox_files_dir() / str(thread_id) / relative + base = (_get_sandbox_files_dir() / str(thread_id)).resolve() + target = (base / relative).resolve() + if not target.is_relative_to(base): + raise ValueError(f"Path traversal blocked: {sandbox_path}") + return target def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None: @@ -226,7 +293,7 @@ async def persist_and_delete_sandbox( Per-file errors are logged but do **not** prevent the sandbox from being deleted — freeing Daytona storage is the priority. """ - _sandbox_cache.pop(str(thread_id), None) + _evict_sandbox_cache(thread_id) def _persist_and_delete() -> None: client = _get_client()