mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
feat: enhance sandbox functionality with threading support and file download capabilities
This commit is contained in:
parent
38b9e8dcc5
commit
b5301fa438
5 changed files with 103 additions and 29 deletions
|
|
@ -9,6 +9,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
|
|
@ -27,6 +28,7 @@ from sqlalchemy import delete, select
|
|||
|
||||
from app.agents.new_chat.sandbox import (
|
||||
_evict_sandbox_cache,
|
||||
delete_sandbox,
|
||||
get_or_create_sandbox,
|
||||
is_sandbox_enabled,
|
||||
)
|
||||
|
|
@ -552,7 +554,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
@staticmethod
def _wrap_as_python(code: str) -> str:
    """Wrap Python code in a shell here-document invocation for the sandbox.

    The here-document delimiter is generated per call instead of being a
    fixed ``PYEOF`` marker. With a fixed marker, any line of *code* equal
    to the marker would terminate the here-document early and the rest of
    the snippet would be interpreted by the shell.

    Args:
        code: Python source to feed to ``python3`` on stdin.

    Returns:
        A shell command consisting of ``python3 << '<delimiter>'``, the
        code on the following lines, and the delimiter on the last line.
    """
    sentinel = f"_PYEOF_{secrets.token_hex(8)}"
    # A collision with the payload is astronomically unlikely, but
    # re-drawing is free and makes the guarantee unconditional.
    while sentinel in code:
        sentinel = f"_PYEOF_{secrets.token_hex(8)}"
    return f"python3 << '{sentinel}'\n{code}\n{sentinel}"
|
||||
|
||||
async def _execute_in_sandbox(
|
||||
self,
|
||||
|
|
@ -572,7 +575,10 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
self._thread_id,
|
||||
first_err,
|
||||
)
|
||||
_evict_sandbox_cache(self._thread_id)
|
||||
try:
|
||||
await delete_sandbox(self._thread_id)
|
||||
except Exception:
|
||||
_evict_sandbox_cache(self._thread_id)
|
||||
try:
|
||||
return await self._try_sandbox_execute(command, runtime, timeout)
|
||||
except Exception:
|
||||
|
|
@ -588,6 +594,13 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
|||
timeout: int | None,
|
||||
) -> str:
|
||||
sandbox, is_new = await get_or_create_sandbox(self._thread_id)
|
||||
# NOTE: sync_files_to_sandbox is intentionally disabled.
|
||||
# The virtual FS contains XML-wrapped KB documents whose paths
|
||||
# would double-nest under SANDBOX_DOCUMENTS_ROOT (e.g.
|
||||
# /home/daytona/documents/documents/Report.xml) and uploading
|
||||
# all KB docs on the first execute_code call adds significant
|
||||
# latency. Re-enable once path mapping is fixed and upload is
|
||||
# limited to user-created scratch files.
|
||||
# files = runtime.state.get("files") or {}
|
||||
# await sync_files_to_sandbox(self._thread_id, files, sandbox, is_new)
|
||||
result = await sandbox.aexecute(command, timeout=timeout)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import contextlib
|
|||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
from daytona import (
|
||||
|
|
@ -55,9 +56,16 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
|
|||
) -> ExecuteResponse: # type: ignore[override]
|
||||
return await asyncio.to_thread(self.execute, command, timeout=timeout)
|
||||
|
||||
def download_file(self, path: str) -> bytes:
    """Fetch the contents of *path* from the sandbox filesystem.

    Thin delegation to the wrapped sandbox's ``fs.download_file``; anything
    that call raises propagates unchanged.

    Args:
        path: Location of the file inside the sandbox.

    Returns:
        Whatever the underlying ``fs.download_file`` returns (annotated
        as ``bytes``).
    """
    remote_fs = self._sandbox.fs
    return remote_fs.download_file(path)
|
||||
|
||||
|
||||
_daytona_client: Daytona | None = None
|
||||
_client_lock = threading.Lock()
|
||||
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
|
||||
_sandbox_locks: dict[str, asyncio.Lock] = {}
|
||||
_sandbox_locks_mu = asyncio.Lock()
|
||||
_seeded_files: dict[str, dict[str, str]] = {}
|
||||
_SANDBOX_CACHE_MAX_SIZE = 20
|
||||
THREAD_LABEL_KEY = "surfsense_thread"
|
||||
|
|
@ -70,14 +78,15 @@ def is_sandbox_enabled() -> bool:
|
|||
|
||||
def _get_client() -> Daytona:
    """Return the shared Daytona client, constructing it on first use.

    The check and the construction both happen while holding
    ``_client_lock``. This function is invoked from executor threads
    (e.g. the background delete in ``_schedule_sandbox_delete``), so an
    unlocked check-then-create could build more than one client.

    Configuration is read from the environment at construction time:
    ``DAYTONA_API_KEY`` (default ``""``), ``DAYTONA_API_URL`` (default
    ``https://app.daytona.io/api``) and ``DAYTONA_TARGET`` (default
    ``us``).

    Returns:
        The module-level ``Daytona`` client instance.
    """
    global _daytona_client
    with _client_lock:
        if _daytona_client is None:
            config = DaytonaConfig(
                api_key=os.environ.get("DAYTONA_API_KEY", ""),
                api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
                target=os.environ.get("DAYTONA_TARGET", "us"),
            )
            _daytona_client = Daytona(config)
        return _daytona_client
|
||||
|
||||
|
||||
def _sandbox_create_params(
|
||||
|
|
@ -136,7 +145,7 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
|
|||
elif sandbox.state != SandboxState.STARTED:
|
||||
sandbox.wait_for_sandbox_start(timeout=60)
|
||||
|
||||
except Exception:
|
||||
except DaytonaError:
|
||||
logger.info("No existing sandbox for thread %s — creating one", thread_id)
|
||||
sandbox = client.create(_sandbox_create_params(labels))
|
||||
is_new = True
|
||||
|
|
@ -145,6 +154,16 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
|
|||
return _TimeoutAwareSandbox(sandbox=sandbox), is_new
|
||||
|
||||
|
||||
async def _get_thread_lock(key: str) -> asyncio.Lock:
    """Look up the asyncio lock registered for *key*, adding one on a miss.

    Both the lookup and the insert into ``_sandbox_locks`` are performed
    while holding ``_sandbox_locks_mu``.

    Args:
        key: Registry key (callers pass the stringified thread id).

    Returns:
        The ``asyncio.Lock`` stored under *key*.
    """
    async with _sandbox_locks_mu:
        existing = _sandbox_locks.get(key)
        if existing is not None:
            return existing
        fresh = asyncio.Lock()
        _sandbox_locks[key] = fresh
        return fresh
|
||||
|
||||
|
||||
async def get_or_create_sandbox(
|
||||
thread_id: int | str,
|
||||
) -> tuple[_TimeoutAwareSandbox, bool]:
|
||||
|
|
@ -152,25 +171,51 @@ async def get_or_create_sandbox(
|
|||
|
||||
Uses an in-process cache keyed by thread_id so subsequent messages
|
||||
in the same conversation reuse the sandbox object without an API call.
|
||||
A per-thread async lock prevents duplicate sandbox creation from
|
||||
concurrent requests.
|
||||
|
||||
Returns:
|
||||
Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox
|
||||
was created, signalling that file tracking should be reset.
|
||||
"""
|
||||
key = str(thread_id)
|
||||
cached = _sandbox_cache.get(key)
|
||||
if cached is not None:
|
||||
logger.info("Reusing cached sandbox for thread %s", key)
|
||||
return cached, False
|
||||
sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
|
||||
_sandbox_cache[key] = sandbox
|
||||
lock = await _get_thread_lock(key)
|
||||
|
||||
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
||||
oldest_key = next(iter(_sandbox_cache))
|
||||
_sandbox_cache.pop(oldest_key, None)
|
||||
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
|
||||
async with lock:
|
||||
cached = _sandbox_cache.get(key)
|
||||
if cached is not None:
|
||||
logger.info("Reusing cached sandbox for thread %s", key)
|
||||
return cached, False
|
||||
sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
|
||||
_sandbox_cache[key] = sandbox
|
||||
|
||||
return sandbox, is_new
|
||||
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
|
||||
oldest_key = next(iter(_sandbox_cache))
|
||||
if oldest_key != key:
|
||||
evicted = _sandbox_cache.pop(oldest_key, None)
|
||||
_seeded_files.pop(oldest_key, None)
|
||||
logger.debug("Evicted sandbox cache entry: %s", oldest_key)
|
||||
if evicted is not None:
|
||||
_schedule_sandbox_delete(evicted)
|
||||
|
||||
return sandbox, is_new
|
||||
|
||||
|
||||
def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None:
    """Best-effort background deletion of an evicted sandbox.

    The delete call is handed to the running event loop's default
    executor so the caller is not blocked on it. This function never
    raises for a failed or unschedulable deletion; both cases are logged
    at debug level so a skipped cleanup leaves a trace.

    Args:
        sandbox: The evicted wrapper whose underlying ``_sandbox`` should
            be deleted through the Daytona client.
    """

    def _delete() -> None:
        # Runs in an executor thread; swallow everything because nothing
        # awaits the returned future and cleanup is only best-effort.
        try:
            client = _get_client()
            client.delete(sandbox._sandbox)
            logger.info("Deleted evicted sandbox: %s", sandbox._sandbox.id)
        except Exception:
            logger.debug("Could not delete evicted sandbox", exc_info=True)

    try:
        loop = asyncio.get_running_loop()
        loop.run_in_executor(None, _delete)
    except RuntimeError:
        # No running loop (or the executor refused the job): keep the
        # best-effort contract, but record that the deletion was skipped.
        logger.debug(
            "Could not schedule deletion of evicted sandbox", exc_info=True
        )
|
||||
|
||||
|
||||
async def sync_files_to_sandbox(
|
||||
|
|
|
|||
|
|
@ -86,9 +86,8 @@ async def download_sandbox_file(
|
|||
|
||||
# Fall back to live sandbox download
|
||||
try:
|
||||
sandbox = await get_or_create_sandbox(thread_id)
|
||||
raw_sandbox = sandbox._sandbox
|
||||
content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
|
||||
sandbox, _ = await get_or_create_sandbox(thread_id)
|
||||
content: bytes = await asyncio.to_thread(sandbox.download_file, path)
|
||||
except Exception as exc:
|
||||
logger.warning("Sandbox file download failed for %s: %s", path, exc)
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ class StreamResult:
|
|||
accumulated_text: str = ""
|
||||
is_interrupted: bool = False
|
||||
interrupt_value: dict[str, Any] | None = None
|
||||
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
|
||||
sandbox_files: list[str] = field(default_factory=list)
|
||||
agent_called_update_memory: bool = False
|
||||
|
||||
|
||||
|
|
@ -440,7 +440,7 @@ async def _stream_agent_events(
|
|||
status="in_progress",
|
||||
items=last_active_step_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
elif tool_name in ("execute", "execute_code"):
|
||||
cmd = (
|
||||
tool_input.get("command", "")
|
||||
if isinstance(tool_input, dict)
|
||||
|
|
@ -738,7 +738,7 @@ async def _stream_agent_events(
|
|||
status="completed",
|
||||
items=completed_items,
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
elif tool_name in ("execute", "execute_code"):
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -985,7 +985,7 @@ async def _stream_agent_events(
|
|||
if isinstance(tool_output, dict)
|
||||
else {"result": tool_output},
|
||||
)
|
||||
elif tool_name == "execute":
|
||||
elif tool_name in ("execute", "execute_code"):
|
||||
raw_text = (
|
||||
tool_output.get("result", "")
|
||||
if isinstance(tool_output, dict)
|
||||
|
|
@ -1617,6 +1617,21 @@ async def stream_new_chat(
|
|||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
# Persist any sandbox-produced files to local storage so they
|
||||
# remain downloadable after the Daytona sandbox auto-deletes.
|
||||
if stream_result and stream_result.sandbox_files:
|
||||
with contextlib.suppress(Exception):
|
||||
from app.agents.new_chat.sandbox import (
|
||||
is_sandbox_enabled,
|
||||
persist_and_delete_sandbox,
|
||||
)
|
||||
|
||||
if is_sandbox_enabled():
|
||||
with anyio.CancelScope(shield=True):
|
||||
await persist_and_delete_sandbox(
|
||||
chat_id, stream_result.sandbox_files
|
||||
)
|
||||
|
||||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
agent = llm = connector_service = None
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ import { useComments } from "@/hooks/use-comments";
|
|||
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||
import { useElectronAPI } from "@/hooks/use-platform";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { openSafeNavigationHref, resolveSafeNavigationHref } from "@/components/tool-ui/shared/media";
|
||||
|
||||
// Captured once at module load — survives client-side navigations that strip the query param.
|
||||
const IS_QUICK_ASSIST_WINDOW =
|
||||
|
|
@ -384,6 +385,7 @@ const AssistantMessageInner: FC = () => {
|
|||
generate_image: GenerateImageToolUI,
|
||||
update_memory: UpdateMemoryToolUI,
|
||||
execute: SandboxExecuteToolUI,
|
||||
execute_code: SandboxExecuteToolUI,
|
||||
create_notion_page: CreateNotionPageToolUI,
|
||||
update_notion_page: UpdateNotionPageToolUI,
|
||||
delete_notion_page: DeleteNotionPageToolUI,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue