mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-07-04 22:02:16 +02:00
harden sandbox: security params, file sync, path traversal fix
This commit is contained in:
parent
7ea840dbb2
commit
9e8ea1fd1c
1 changed file with 89 additions and 22 deletions
|
|
@ -42,7 +42,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
|
def execute(self, command: str, *, timeout: int | None = None) -> ExecuteResponse:
|
||||||
t = timeout if timeout is not None else self._timeout
|
t = timeout if timeout is not None else self._default_timeout
|
||||||
result = self._sandbox.process.exec(command, timeout=t)
|
result = self._sandbox.process.exec(command, timeout=t)
|
||||||
return ExecuteResponse(
|
return ExecuteResponse(
|
||||||
output=result.result,
|
output=result.result,
|
||||||
|
|
@ -58,8 +58,10 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
|
||||||
|
|
||||||
_daytona_client: Daytona | None = None
|
_daytona_client: Daytona | None = None
|
||||||
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
|
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
|
||||||
|
_seeded_files: dict[str, dict[str, str]] = {}
|
||||||
_SANDBOX_CACHE_MAX_SIZE = 20
|
_SANDBOX_CACHE_MAX_SIZE = 20
|
||||||
THREAD_LABEL_KEY = "surfsense_thread"
|
THREAD_LABEL_KEY = "surfsense_thread"
|
||||||
|
SANDBOX_DOCUMENTS_ROOT = "/home/daytona/documents"
|
||||||
|
|
||||||
|
|
||||||
def is_sandbox_enabled() -> bool:
|
def is_sandbox_enabled() -> bool:
|
||||||
|
|
@ -78,14 +80,29 @@ def _get_client() -> Daytona:
|
||||||
return _daytona_client
|
return _daytona_client
|
||||||
|
|
||||||
|
|
||||||
def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
|
def _sandbox_create_params(
    labels: dict[str, str],
) -> CreateSandboxFromSnapshotParams:
    """Build the creation parameters shared by every sandbox this module makes.

    Args:
        labels: Labels to attach to the sandbox (used to find it again later).

    Returns:
        Parameters for a Python sandbox. The snapshot is taken from the
        ``DAYTONA_SNAPSHOT_ID`` environment variable; an unset or empty
        variable is passed through as ``None``.
    """
    snapshot_id = os.environ.get("DAYTONA_SNAPSHOT_ID")
    if not snapshot_id:
        # Treat an empty string exactly like a missing variable.
        snapshot_id = None

    # NOTE(review): the option names suggest outbound networking is disabled
    # and that idle sandboxes are stopped/deleted automatically; the interval
    # units are presumably minutes — confirm against the Daytona SDK docs.
    options = {
        "language": "python",
        "labels": labels,
        "snapshot": snapshot_id,
        "network_block_all": True,
        "auto_stop_interval": 10,
        "auto_delete_interval": 60,
    }
    return CreateSandboxFromSnapshotParams(**options)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
|
||||||
"""Find an existing sandbox for *thread_id*, or create a new one.
|
"""Find an existing sandbox for *thread_id*, or create a new one.
|
||||||
|
|
||||||
If an existing sandbox is found but is stopped/archived, it will be
|
Returns a tuple of (sandbox, is_new) where *is_new* is True when a
|
||||||
restarted automatically before returning.
|
fresh sandbox was created (first time or replacement after failure).
|
||||||
"""
|
"""
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
labels = {THREAD_LABEL_KEY: thread_id}
|
labels = {THREAD_LABEL_KEY: thread_id}
|
||||||
|
is_new = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sandbox = client.find_one(labels=labels)
|
sandbox = client.find_one(labels=labels)
|
||||||
|
|
@ -109,41 +126,39 @@ def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
|
||||||
sandbox.id,
|
sandbox.id,
|
||||||
sandbox.state,
|
sandbox.state,
|
||||||
)
|
)
|
||||||
sandbox = client.create(
|
sandbox = client.create(_sandbox_create_params(labels))
|
||||||
CreateSandboxFromSnapshotParams(language="python", labels=labels)
|
is_new = True
|
||||||
)
|
|
||||||
logger.info("Created replacement sandbox: %s", sandbox.id)
|
logger.info("Created replacement sandbox: %s", sandbox.id)
|
||||||
elif sandbox.state != SandboxState.STARTED:
|
elif sandbox.state != SandboxState.STARTED:
|
||||||
sandbox.wait_for_sandbox_start(timeout=60)
|
sandbox.wait_for_sandbox_start(timeout=60)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.info("No existing sandbox for thread %s — creating one", thread_id)
|
logger.info("No existing sandbox for thread %s — creating one", thread_id)
|
||||||
sandbox = client.create(
|
sandbox = client.create(_sandbox_create_params(labels))
|
||||||
CreateSandboxFromSnapshotParams(language="python", labels=labels)
|
is_new = True
|
||||||
)
|
|
||||||
logger.info("Created new sandbox: %s", sandbox.id)
|
logger.info("Created new sandbox: %s", sandbox.id)
|
||||||
|
|
||||||
return _TimeoutAwareSandbox(sandbox=sandbox)
|
return _TimeoutAwareSandbox(sandbox=sandbox), is_new
|
||||||
|
|
||||||
|
|
||||||
async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
|
async def get_or_create_sandbox(
|
||||||
|
thread_id: int | str,
|
||||||
|
) -> tuple[_TimeoutAwareSandbox, bool]:
|
||||||
"""Get or create a sandbox for a conversation thread.
|
"""Get or create a sandbox for a conversation thread.
|
||||||
|
|
||||||
Uses an in-process cache keyed by thread_id so subsequent messages
|
Uses an in-process cache keyed by thread_id so subsequent messages
|
||||||
in the same conversation reuse the sandbox object without an API call.
|
in the same conversation reuse the sandbox object without an API call.
|
||||||
|
|
||||||
Args:
|
|
||||||
thread_id: The conversation thread identifier.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DaytonaSandbox connected to the sandbox.
|
Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox
|
||||||
|
was created, signalling that file tracking should be reset.
|
||||||
"""
|
"""
|
||||||
key = str(thread_id)
|
key = str(thread_id)
|
||||||
cached = _sandbox_cache.get(key)
|
cached = _sandbox_cache.get(key)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
logger.info("Reusing cached sandbox for thread %s", key)
|
logger.info("Reusing cached sandbox for thread %s", key)
|
||||||
return cached
|
return cached, False
|
||||||
sandbox = await asyncio.to_thread(_find_or_create, key)
|
sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
|
||||||
_sandbox_cache[key] = sandbox
|
_sandbox_cache[key] = sandbox
|
||||||
|
|
||||||
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
||||||
|
|
@ -151,12 +166,60 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
|
||||||
_sandbox_cache.pop(oldest_key, None)
|
_sandbox_cache.pop(oldest_key, None)
|
||||||
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
|
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
|
||||||
|
|
||||||
return sandbox
|
return sandbox, is_new
|
||||||
|
|
||||||
|
|
||||||
|
async def sync_files_to_sandbox(
    thread_id: int | str,
    files: dict[str, dict],
    sandbox: _TimeoutAwareSandbox,
    is_new: bool,
) -> None:
    """Upload new or changed virtual-filesystem files to the sandbox.

    Compares *files* (from ``state["files"]``) against the ``_seeded_files``
    tracking dict and uploads only entries whose ``modified_at`` value differs
    from the one recorded at the previous sync.

    Args:
        thread_id: Conversation thread identifier; its string form is the
            tracking key.
        files: Mapping of virtual path -> file data. Each value is read for
            ``"modified_at"`` and ``"content"`` (a list of lines, or a single
            string).
        sandbox: Sandbox wrapper whose ``upload_files`` receives a list of
            ``(absolute_sandbox_path, bytes)`` tuples.
        is_new: When True the tracking for this thread is reset first, so
            every file is uploaded again.

    Virtual paths that do not resolve to a location strictly inside
    ``SANDBOX_DOCUMENTS_ROOT`` (for example ones containing ``..``) are
    skipped with a warning instead of being written elsewhere in the sandbox.
    """
    # Local import: sandbox paths are POSIX regardless of the host OS.
    import posixpath

    key = str(thread_id)
    if is_new:
        _seeded_files.pop(key, None)

    tracked = _seeded_files.get(key, {})
    root = posixpath.normpath(SANDBOX_DOCUMENTS_ROOT)
    to_upload: list[tuple[str, bytes]] = []
    uploaded_stamps: dict[str, str] = {}

    for vpath, fdata in files.items():
        modified_at = fdata.get("modified_at", "")
        if tracked.get(vpath) == modified_at:
            continue

        # Join and normalise rather than concatenating raw strings, so a
        # missing leading slash or a ".." segment cannot produce a path
        # outside the documents root.
        sandbox_path = posixpath.normpath(posixpath.join(root, vpath.lstrip("/")))
        if not sandbox_path.startswith(f"{root}/"):
            logger.warning(
                "Skipping file with unsafe sandbox path for thread %s: %r",
                key,
                vpath,
            )
            continue

        raw_content = fdata.get("content", [])
        # Joining a plain string would put a newline between every character,
        # so only join when the content is a sequence of lines.
        if isinstance(raw_content, str):
            content = raw_content
        else:
            content = "\n".join(raw_content)

        to_upload.append((sandbox_path, content.encode("utf-8")))
        uploaded_stamps[vpath] = modified_at

    if not to_upload:
        return

    # upload_files is a blocking call; keep it off the event loop.
    await asyncio.to_thread(sandbox.upload_files, to_upload)

    # Record stamps only after the upload succeeded, so a failed upload is
    # retried on the next sync. Unchanged files already hold the right stamp.
    _seeded_files[key] = {**tracked, **uploaded_stamps}
    logger.info("Synced %d file(s) to sandbox for thread %s", len(to_upload), key)
|
||||||
|
|
||||||
|
|
||||||
|
def _evict_sandbox_cache(thread_id: int | str) -> None:
    """Forget everything cached in-process for *thread_id*.

    Removes the thread's entry from both ``_sandbox_cache`` and
    ``_seeded_files``; a missing entry in either is not an error.
    """
    cache_key = str(thread_id)
    for registry in (_sandbox_cache, _seeded_files):
        registry.pop(cache_key, None)
|
||||||
|
|
||||||
|
|
||||||
async def delete_sandbox(thread_id: int | str) -> None:
|
async def delete_sandbox(thread_id: int | str) -> None:
|
||||||
"""Delete the sandbox for a conversation thread."""
|
"""Delete the sandbox for a conversation thread."""
|
||||||
_sandbox_cache.pop(str(thread_id), None)
|
_evict_sandbox_cache(thread_id)
|
||||||
|
|
||||||
def _delete() -> None:
|
def _delete() -> None:
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
@ -193,7 +256,11 @@ def _get_sandbox_files_dir() -> Path:
|
||||||
def _local_path_for(thread_id: int | str, sandbox_path: str) -> Path:
    """Map a sandbox-internal absolute path to a local filesystem path.

    The sandbox path is re-rooted under the per-thread directory inside the
    sandbox files dir, and both paths are resolved before comparison.

    Raises:
        ValueError: If the resolved target is not the per-thread directory
            itself or something beneath it.
    """
    trimmed = sandbox_path.lstrip("/")
    thread_root = (_get_sandbox_files_dir() / str(thread_id)).resolve()
    resolved = (thread_root / trimmed).resolve()

    # Same containment rule as Path.is_relative_to: the target must be the
    # thread root or have it as an ancestor.
    inside = resolved == thread_root or thread_root in resolved.parents
    if not inside:
        raise ValueError(f"Path traversal blocked: {sandbox_path}")
    return resolved
|
||||||
|
|
||||||
|
|
||||||
def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
|
def get_local_sandbox_file(thread_id: int | str, sandbox_path: str) -> bytes | None:
|
||||||
|
|
@ -226,7 +293,7 @@ async def persist_and_delete_sandbox(
|
||||||
Per-file errors are logged but do **not** prevent the sandbox from
|
Per-file errors are logged but do **not** prevent the sandbox from
|
||||||
being deleted — freeing Daytona storage is the priority.
|
being deleted — freeing Daytona storage is the priority.
|
||||||
"""
|
"""
|
||||||
_sandbox_cache.pop(str(thread_id), None)
|
_evict_sandbox_cache(thread_id)
|
||||||
|
|
||||||
def _persist_and_delete() -> None:
|
def _persist_and_delete() -> None:
|
||||||
client = _get_client()
|
client = _get_client()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue