Merge remote-tracking branch 'upstream/dev' into feat/token-calculation

Anish Sarkar 2026-04-14 15:49:39 +05:30
commit 8fd7664f8f
70 changed files with 3348 additions and 971 deletions

View file

@ -0,0 +1,44 @@
"""124_add_ai_file_sort_enabled
Revision ID: 124
Revises: 123
Create Date: 2026-04-14
Adds ai_file_sort_enabled boolean column to searchspaces.
Defaults to False so AI file sorting is opt-in per search space.
"""
from __future__ import annotations
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "124"
down_revision: str | None = "123"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
conn = op.get_bind()
existing_columns = [
col["name"] for col in sa.inspect(conn).get_columns("searchspaces")
]
if "ai_file_sort_enabled" not in existing_columns:
op.add_column(
"searchspaces",
sa.Column(
"ai_file_sort_enabled",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("searchspaces", "ai_file_sort_enabled")

View file

@ -93,7 +93,8 @@ class DedupHITLToolCallsMiddleware(AgentMiddleware): # type: ignore[type-arg]
@staticmethod
def _dedup(
state: AgentState, dedup_keys: dict[str, str] # type: ignore[type-arg]
state: AgentState,
dedup_keys: dict[str, str], # type: ignore[type-arg]
) -> dict[str, Any] | None:
messages = state.get("messages")
if not messages:

View file

@ -9,6 +9,7 @@ from __future__ import annotations
import asyncio
import logging
import re
import secrets
from datetime import UTC, datetime
from typing import Annotated, Any
@ -27,6 +28,7 @@ from sqlalchemy import delete, select
from app.agents.new_chat.sandbox import (
_evict_sandbox_cache,
delete_sandbox,
get_or_create_sandbox,
is_sandbox_enabled,
)
@ -552,7 +554,8 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
@staticmethod
def _wrap_as_python(code: str) -> str:
"""Wrap Python code in a shell invocation for the sandbox."""
return f"python3 << 'PYEOF'\n{code}\nPYEOF"
sentinel = f"_PYEOF_{secrets.token_hex(8)}"
return f"python3 << '{sentinel}'\n{code}\n{sentinel}"
async def _execute_in_sandbox(
self,
@ -572,7 +575,10 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
self._thread_id,
first_err,
)
_evict_sandbox_cache(self._thread_id)
try:
await delete_sandbox(self._thread_id)
except Exception:
_evict_sandbox_cache(self._thread_id)
try:
return await self._try_sandbox_execute(command, runtime, timeout)
except Exception:
@ -587,7 +593,14 @@ class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
runtime: ToolRuntime[None, FilesystemState],
timeout: int | None,
) -> str:
sandbox, is_new = await get_or_create_sandbox(self._thread_id)
sandbox, _is_new = await get_or_create_sandbox(self._thread_id)
# NOTE: sync_files_to_sandbox is intentionally disabled.
# The virtual FS contains XML-wrapped KB documents whose paths
# would double-nest under SANDBOX_DOCUMENTS_ROOT (e.g.
# /home/daytona/documents/documents/Report.xml), and uploading
# all KB docs on the first execute_code call adds significant
# latency. Re-enable once path mapping is fixed and upload is
# limited to user-created scratch files.
# files = runtime.state.get("files") or {}
# await sync_files_to_sandbox(self._thread_id, files, sandbox, is_new)
result = await sandbox.aexecute(command, timeout=timeout)
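For context on the _wrap_as_python change above: a fixed heredoc delimiter breaks as soon as the wrapped code contains that delimiter on a line of its own, which is why the marker is now randomised per call. A minimal illustrative sketch of the failure mode (not part of this diff):

# With the old fixed delimiter, the embedded 'PYEOF' line terminates the heredoc
# early and everything after it is interpreted by the shell instead of python3.
payload = "print('before')\nPYEOF\necho leaked-to-shell"
old_cmd = f"python3 << 'PYEOF'\n{payload}\nPYEOF"

# With a per-call random sentinel, the embedded string is just data inside the script.
import secrets
sentinel = f"_PYEOF_{secrets.token_hex(8)}"
new_cmd = f"python3 << '{sentinel}'\n{payload}\n{sentinel}"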

View file

@ -58,6 +58,14 @@ class KBSearchPlan(BaseModel):
default=None,
description="Optional ISO end date or datetime for KB search filtering.",
)
is_recency_query: bool = Field(
default=False,
description=(
"True when the user's intent is primarily about recency or temporal "
"ordering (e.g. 'latest', 'newest', 'most recent', 'last uploaded') "
"rather than topical relevance."
),
)
def _extract_text_from_message(message: BaseMessage) -> str:
@ -245,7 +253,7 @@ def _build_kb_planner_prompt(
return (
"You optimize internal knowledge-base search inputs for document retrieval.\n"
"Return JSON only with this exact shape:\n"
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null"}\n\n'
'{"optimized_query":"string","start_date":"ISO string or null","end_date":"ISO string or null","is_recency_query":bool}\n\n'
"Rules:\n"
"- Preserve the user's intent.\n"
"- Rewrite the query to improve retrieval using concrete entities, acronyms, projects, tools, people, and document-specific terms when helpful.\n"
@ -253,6 +261,11 @@ def _build_kb_planner_prompt(
"- Only use date filters when the latest user request or recent dialogue clearly implies a time range.\n"
"- If you use date filters, prefer returning both bounds.\n"
"- If no date filter is useful, return null for both dates.\n"
'- Set "is_recency_query" to true ONLY when the user\'s primary intent is about '
"recency or temporal ordering rather than topical relevance. Examples: "
'"latest file", "newest upload", "most recent document", "what did I save last", '
'"show me files from today", "last thing I added". '
"When true, results will be sorted by date instead of relevance.\n"
"- Do not include markdown, prose, or explanations.\n\n"
f"Today's UTC date: {today}\n\n"
f"Recent conversation:\n{recent_conversation or '(none)'}\n\n"
@ -506,6 +519,135 @@ def _resolve_search_types(
return list(expanded) if expanded else None
_RECENCY_MAX_CHUNKS_PER_DOC = 5
async def browse_recent_documents(
*,
search_space_id: int,
document_type: list[str] | None = None,
top_k: int = 10,
start_date: datetime | None = None,
end_date: datetime | None = None,
) -> list[dict[str, Any]]:
"""Return documents ordered by recency (newest first), no relevance ranking.
Used when the user's intent is temporal ("latest file", "most recent upload")
and hybrid search would produce poor results because the query has no
meaningful topical signal.
"""
from sqlalchemy import func, select
from app.db import DocumentType
async with shielded_async_session() as session:
base_conditions = [
Document.search_space_id == search_space_id,
func.coalesce(Document.status["state"].astext, "ready") != "deleting",
]
if document_type is not None:
import contextlib
doc_type_enums = []
for dt in document_type:
if isinstance(dt, str):
with contextlib.suppress(KeyError):
doc_type_enums.append(DocumentType[dt])
else:
doc_type_enums.append(dt)
if doc_type_enums:
if len(doc_type_enums) == 1:
base_conditions.append(Document.document_type == doc_type_enums[0])
else:
base_conditions.append(Document.document_type.in_(doc_type_enums))
if start_date is not None:
base_conditions.append(Document.updated_at >= start_date)
if end_date is not None:
base_conditions.append(Document.updated_at <= end_date)
doc_query = (
select(Document)
.where(*base_conditions)
.order_by(Document.updated_at.desc())
.limit(top_k)
)
result = await session.execute(doc_query)
documents = result.scalars().unique().all()
if not documents:
return []
doc_ids = [d.id for d in documents]
numbered = (
select(
Chunk.id.label("chunk_id"),
Chunk.document_id,
Chunk.content,
func.row_number()
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
.label("rn"),
)
.where(Chunk.document_id.in_(doc_ids))
.subquery("numbered")
)
chunk_query = (
select(numbered.c.chunk_id, numbered.c.document_id, numbered.c.content)
.where(numbered.c.rn <= _RECENCY_MAX_CHUNKS_PER_DOC)
.order_by(numbered.c.document_id, numbered.c.chunk_id)
)
chunk_result = await session.execute(chunk_query)
fetched_chunks = chunk_result.all()
doc_chunks: dict[int, list[dict[str, Any]]] = {d.id: [] for d in documents}
for row in fetched_chunks:
if row.document_id in doc_chunks:
doc_chunks[row.document_id].append(
{"chunk_id": row.chunk_id, "content": row.content}
)
results: list[dict[str, Any]] = []
for doc in documents:
chunks_list = doc_chunks.get(doc.id, [])
metadata = doc.document_metadata or {}
results.append(
{
"document_id": doc.id,
"content": "\n\n".join(
c["content"] for c in chunks_list if c.get("content")
),
"score": 0.0,
"chunks": chunks_list,
"matched_chunk_ids": [],
"document": {
"id": doc.id,
"title": doc.title,
"document_type": (
doc.document_type.value
if getattr(doc, "document_type", None)
else None
),
"metadata": metadata,
},
"source": (
doc.document_type.value
if getattr(doc, "document_type", None)
else None
),
}
)
logger.info(
"browse_recent_documents: %d docs returned for space=%d",
len(results),
search_space_id,
)
return results
async def search_knowledge_base(
*,
query: str,
@ -704,10 +846,13 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
*,
messages: Sequence[BaseMessage],
user_text: str,
) -> tuple[str, datetime | None, datetime | None]:
"""Rewrite the KB query and infer optional date filters with the LLM."""
) -> tuple[str, datetime | None, datetime | None, bool]:
"""Rewrite the KB query and infer optional date filters with the LLM.
Returns (optimized_query, start_date, end_date, is_recency_query).
"""
if self.llm is None:
return user_text, None, None
return user_text, None, None, False
recent_conversation = _render_recent_conversation(
messages,
@ -734,15 +879,18 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
plan.start_date,
plan.end_date,
)
is_recency = plan.is_recency_query
_perf_log.info(
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r start=%s end=%s",
"[kb_fs_middleware] planner in %.3fs query=%r optimized=%r "
"start=%s end=%s recency=%s",
loop.time() - t0,
user_text[:80],
optimized_query[:120],
start_date.isoformat() if start_date else None,
end_date.isoformat() if end_date else None,
is_recency,
)
return optimized_query, start_date, end_date
return optimized_query, start_date, end_date, is_recency
except (json.JSONDecodeError, ValidationError, ValueError) as exc:
logger.warning(
"KB planner returned invalid output, using raw query: %s", exc
@ -750,7 +898,7 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
except Exception as exc: # pragma: no cover - defensive fallback
logger.warning("KB planner failed, using raw query: %s", exc)
return user_text, None, None
return user_text, None, None, False
def before_agent( # type: ignore[override]
self,
@ -789,7 +937,12 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
t0 = _perf_log and asyncio.get_event_loop().time()
existing_files = state.get("files")
planned_query, start_date, end_date = await self._plan_search_inputs(
(
planned_query,
start_date,
end_date,
is_recency,
) = await self._plan_search_inputs(
messages=messages,
user_text=user_text,
)
@ -805,16 +958,28 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
# messages within the same agent instance.
self.mentioned_document_ids = []
# --- 2. Run KB hybrid search ---
search_results = await search_knowledge_base(
query=planned_query,
search_space_id=self.search_space_id,
available_connectors=self.available_connectors,
available_document_types=self.available_document_types,
top_k=self.top_k,
start_date=start_date,
end_date=end_date,
)
# --- 2. Run KB search (recency browse or hybrid) ---
if is_recency:
doc_types = _resolve_search_types(
self.available_connectors, self.available_document_types
)
search_results = await browse_recent_documents(
search_space_id=self.search_space_id,
document_type=doc_types,
top_k=self.top_k,
start_date=start_date,
end_date=end_date,
)
else:
search_results = await search_knowledge_base(
query=planned_query,
search_space_id=self.search_space_id,
available_connectors=self.available_connectors,
available_document_types=self.available_document_types,
top_k=self.top_k,
start_date=start_date,
end_date=end_date,
)
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
seen_doc_ids: set[int] = set()
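For reference, the branch above is driven entirely by the planner's JSON contract. Two illustrative payloads (hypothetical queries, not taken from this diff) and the path each one takes:

# Recency intent: is_recency_query=True routes to browse_recent_documents,
# which orders by Document.updated_at descending instead of relevance.
recency_plan = {
    "optimized_query": "latest uploaded file",
    "start_date": None,
    "end_date": None,
    "is_recency_query": True,
}

# Topical intent: is_recency_query=False (the schema default) keeps the
# existing search_knowledge_base hybrid-ranking path.
topical_plan = {
    "optimized_query": "Q3 revenue report analysis",
    "start_date": None,
    "end_date": None,
    "is_recency_query": False,
}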

View file

@ -16,6 +16,7 @@ import contextlib
import logging
import os
import shutil
import threading
from pathlib import Path
from daytona import (
@ -55,9 +56,16 @@ class _TimeoutAwareSandbox(DaytonaSandbox):
) -> ExecuteResponse: # type: ignore[override]
return await asyncio.to_thread(self.execute, command, timeout=timeout)
def download_file(self, path: str) -> bytes:
"""Download a file from the sandbox filesystem."""
return self._sandbox.fs.download_file(path)
_daytona_client: Daytona | None = None
_client_lock = threading.Lock()
_sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
_sandbox_locks: dict[str, asyncio.Lock] = {}
_sandbox_locks_mu = asyncio.Lock()
_seeded_files: dict[str, dict[str, str]] = {}
_SANDBOX_CACHE_MAX_SIZE = 20
THREAD_LABEL_KEY = "surfsense_thread"
@ -70,14 +78,15 @@ def is_sandbox_enabled() -> bool:
def _get_client() -> Daytona:
global _daytona_client
if _daytona_client is None:
config = DaytonaConfig(
api_key=os.environ.get("DAYTONA_API_KEY", ""),
api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
target=os.environ.get("DAYTONA_TARGET", "us"),
)
_daytona_client = Daytona(config)
return _daytona_client
with _client_lock:
if _daytona_client is None:
config = DaytonaConfig(
api_key=os.environ.get("DAYTONA_API_KEY", ""),
api_url=os.environ.get("DAYTONA_API_URL", "https://app.daytona.io/api"),
target=os.environ.get("DAYTONA_TARGET", "us"),
)
_daytona_client = Daytona(config)
return _daytona_client
def _sandbox_create_params(
@ -129,14 +138,16 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
try:
client.delete(sandbox)
except Exception:
logger.debug("Could not delete broken sandbox %s", sandbox.id, exc_info=True)
logger.debug(
"Could not delete broken sandbox %s", sandbox.id, exc_info=True
)
sandbox = client.create(_sandbox_create_params(labels))
is_new = True
logger.info("Created replacement sandbox: %s", sandbox.id)
elif sandbox.state != SandboxState.STARTED:
sandbox.wait_for_sandbox_start(timeout=60)
except Exception:
except DaytonaError:
logger.info("No existing sandbox for thread %s — creating one", thread_id)
sandbox = client.create(_sandbox_create_params(labels))
is_new = True
@ -145,6 +156,16 @@ def _find_or_create(thread_id: str) -> tuple[_TimeoutAwareSandbox, bool]:
return _TimeoutAwareSandbox(sandbox=sandbox), is_new
async def _get_thread_lock(key: str) -> asyncio.Lock:
"""Return a per-thread asyncio lock, creating one if needed."""
async with _sandbox_locks_mu:
lock = _sandbox_locks.get(key)
if lock is None:
lock = asyncio.Lock()
_sandbox_locks[key] = lock
return lock
async def get_or_create_sandbox(
thread_id: int | str,
) -> tuple[_TimeoutAwareSandbox, bool]:
@ -152,25 +173,52 @@ async def get_or_create_sandbox(
Uses an in-process cache keyed by thread_id so subsequent messages
in the same conversation reuse the sandbox object without an API call.
A per-thread async lock prevents duplicate sandbox creation from
concurrent requests.
Returns:
Tuple of (sandbox, is_new). *is_new* is True when a fresh sandbox
was created, signalling that file tracking should be reset.
"""
key = str(thread_id)
cached = _sandbox_cache.get(key)
if cached is not None:
logger.info("Reusing cached sandbox for thread %s", key)
return cached, False
sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
_sandbox_cache[key] = sandbox
lock = await _get_thread_lock(key)
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
oldest_key = next(iter(_sandbox_cache))
_sandbox_cache.pop(oldest_key, None)
logger.debug("Evicted oldest sandbox cache entry: %s", oldest_key)
async with lock:
cached = _sandbox_cache.get(key)
if cached is not None:
logger.info("Reusing cached sandbox for thread %s", key)
return cached, False
sandbox, is_new = await asyncio.to_thread(_find_or_create, key)
_sandbox_cache[key] = sandbox
return sandbox, is_new
if len(_sandbox_cache) > _SANDBOX_CACHE_MAX_SIZE:
oldest_key = next(iter(_sandbox_cache))
if oldest_key != key:
evicted = _sandbox_cache.pop(oldest_key, None)
_seeded_files.pop(oldest_key, None)
logger.debug("Evicted sandbox cache entry: %s", oldest_key)
if evicted is not None:
_schedule_sandbox_delete(evicted)
return sandbox, is_new
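A small illustrative sketch (assumes a configured Daytona environment; not part of this diff) of the guarantee the cache plus per-thread lock provide: concurrent callers for the same thread share one sandbox, and at most one of them observes is_new=True.

import asyncio

async def _demo(thread_id: str) -> None:
    # The per-thread lock serialises creation; the loser of the race finds
    # the cached sandbox and gets is_new=False.
    (sb1, new1), (sb2, new2) = await asyncio.gather(
        get_or_create_sandbox(thread_id),
        get_or_create_sandbox(thread_id),
    )
    assert sb1 is sb2
    assert not (new1 and new2)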
def _schedule_sandbox_delete(sandbox: _TimeoutAwareSandbox) -> None:
"""Best-effort background deletion of an evicted sandbox."""
def _delete() -> None:
try:
client = _get_client()
client.delete(sandbox._sandbox)
logger.info("Deleted evicted sandbox: %s", sandbox._sandbox.id)
except Exception:
logger.debug("Could not delete evicted sandbox", exc_info=True)
try:
loop = asyncio.get_running_loop()
loop.run_in_executor(None, _delete)
except RuntimeError:
pass
async def sync_files_to_sandbox(

View file

@ -2,10 +2,10 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService

View file

@ -2,10 +2,10 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService

View file

@ -2,10 +2,10 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService

View file

@ -5,10 +5,10 @@ from pathlib import Path
from typing import Any, Literal
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
from app.db import SearchSourceConnector, SearchSourceConnectorType

View file

@ -2,11 +2,11 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy import String, and_, cast, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
from app.db import (
Document,

View file

@ -6,9 +6,9 @@ from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)

View file

@ -6,9 +6,9 @@ from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)

View file

@ -4,9 +4,9 @@ from datetime import datetime
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)

View file

@ -6,9 +6,9 @@ from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)

View file

@ -150,7 +150,9 @@ def create_update_calendar_event_tool(
final_new_end_datetime = result.params.get(
"new_end_datetime", new_end_datetime
)
final_new_description = result.params.get("new_description", new_description)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_location = result.params.get("new_location", new_location)
final_new_attendees = result.params.get("new_attendees", new_attendees)

View file

@ -58,7 +58,9 @@ def _parse_decision(approval: Any) -> tuple[str, dict[str, Any]]:
raise ValueError("No approval decision received")
decision = decisions[0]
decision_type: str = decision.get("type") or decision.get("decision_type") or "approve"
decision_type: str = (
decision.get("type") or decision.get("decision_type") or "approve"
)
edited_params: dict[str, Any] = {}
edited_action = decision.get("edited_action")

View file

@ -3,10 +3,10 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService

View file

@ -3,10 +3,10 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService

View file

@ -3,10 +3,10 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.jira_history import JiraHistoryConnector
from app.services.jira import JiraToolMetadataService

View file

@ -2,9 +2,9 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.services.linear import LinearToolMetadataService

View file

@ -2,9 +2,9 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.services.linear import LinearToolMetadataService

View file

@ -2,9 +2,9 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.linear_connector import LinearAPIError, LinearConnector
from app.services.linear import LinearKBSyncService, LinearToolMetadataService
@ -157,9 +157,13 @@ def create_update_linear_issue_tool(
final_issue_id = result.params.get("issue_id", issue_id)
final_document_id = result.params.get("document_id", document_id)
final_new_title = result.params.get("new_title", new_title)
final_new_description = result.params.get("new_description", new_description)
final_new_description = result.params.get(
"new_description", new_description
)
final_new_state_id = result.params.get("new_state_id", new_state_id)
final_new_assignee_id = result.params.get("new_assignee_id", new_assignee_id)
final_new_assignee_id = result.params.get(
"new_assignee_id", new_assignee_id
)
final_new_priority = result.params.get("new_priority", new_priority)
final_new_label_ids: list[str] | None = result.params.get(
"new_label_ids", new_label_ids

View file

@ -2,9 +2,9 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.services.notion import NotionToolMetadataService

View file

@ -2,9 +2,9 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.services.notion.tool_metadata_service import NotionToolMetadataService

View file

@ -2,9 +2,9 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.services.notion import NotionToolMetadataService

View file

@ -5,10 +5,10 @@ from pathlib import Path
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient
from app.db import SearchSourceConnector, SearchSourceConnectorType

View file

@ -2,11 +2,11 @@ import logging
from typing import Any
from langchain_core.tools import tool
from app.agents.new_chat.tools.hitl import request_approval
from sqlalchemy import String, and_, cast, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient
from app.db import (
Document,

View file

@ -1383,6 +1383,10 @@ class SearchSpace(BaseModel, TimestampMixin):
Integer, nullable=True, default=0
) # For vision/screenshot analysis, defaults to Auto mode
ai_file_sort_enabled = Column(
Boolean, nullable=False, default=False, server_default="false"
)
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)

View file

@ -422,6 +422,8 @@ class IndexingPipelineService:
)
log_index_success(ctx, chunk_count=len(chunks))
await self._enqueue_ai_sort_if_enabled(document)
except RETRYABLE_LLM_ERRORS as e:
log_retryable_llm_error(ctx, e)
await rollback_and_persist_failure(
@ -457,6 +459,29 @@ class IndexingPipelineService:
return document
async def _enqueue_ai_sort_if_enabled(self, document: Document) -> None:
"""Fire-and-forget: enqueue incremental AI sort if the search space has it enabled."""
try:
from app.db import SearchSpace
result = await self.session.execute(
select(SearchSpace.ai_file_sort_enabled).where(
SearchSpace.id == document.search_space_id
)
)
enabled = result.scalar()
if not enabled:
return
from app.tasks.celery_tasks.document_tasks import ai_sort_document_task
user_id = str(document.created_by_id) if document.created_by_id else ""
ai_sort_document_task.delay(document.search_space_id, user_id, document.id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to enqueue AI sort for document %s", document.id, exc_info=True
)
async def index_batch_parallel(
self,
connector_docs: list[ConnectorDocument],

View file

@ -20,7 +20,9 @@ router = APIRouter()
@router.get("/search-spaces/{search_space_id}/export")
async def export_knowledge_base(
search_space_id: int,
folder_id: int | None = Query(None, description="Export only this folder's subtree"),
folder_id: int | None = Query(
None, description="Export only this folder's subtree"
),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):

View file

@ -86,9 +86,8 @@ async def download_sandbox_file(
# Fall back to live sandbox download
try:
sandbox = await get_or_create_sandbox(thread_id)
raw_sandbox = sandbox._sandbox
content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
sandbox, _ = await get_or_create_sandbox(thread_id)
content: bytes = await asyncio.to_thread(sandbox.download_file, path)
except Exception as exc:
logger.warning("Sandbox file download failed for %s: %s", path, exc)
raise HTTPException(

View file

@ -216,6 +216,7 @@ async def read_search_spaces(
user_id=space.user_id,
citations_enabled=space.citations_enabled,
qna_custom_instructions=space.qna_custom_instructions,
ai_file_sort_enabled=space.ai_file_sort_enabled,
member_count=member_count,
is_owner=is_owner,
)
@ -384,6 +385,42 @@ async def edit_team_memory(
return db_search_space
@router.post("/searchspaces/{search_space_id}/ai-sort")
async def trigger_ai_sort(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Trigger a full AI file sort for all documents in the search space."""
try:
await check_permission(
session,
user,
search_space_id,
Permission.SETTINGS_UPDATE.value,
"You don't have permission to trigger AI sort on this search space",
)
result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == search_space_id)
)
db_search_space = result.scalars().first()
if not db_search_space:
raise HTTPException(status_code=404, detail="Search space not found")
from app.tasks.celery_tasks.document_tasks import ai_sort_search_space_task
ai_sort_search_space_task.delay(search_space_id, str(user.id))
return {"message": "AI sort started"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to trigger AI sort: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to trigger AI sort: {e!s}"
) from e
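A hedged usage sketch for the new endpoint (the base URL, router prefix, and auth header are placeholders; adjust to the actual deployment):

import httpx

# Hypothetical values: search space 42, bearer token from your session.
resp = httpx.post(
    "http://localhost:8000/searchspaces/42/ai-sort",
    headers={"Authorization": "Bearer <token>"},
)
print(resp.status_code, resp.json())  # expected: 200 {"message": "AI sort started"}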
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
async def delete_search_space(
search_space_id: int,

View file

@ -22,6 +22,7 @@ class SearchSpaceUpdate(BaseModel):
citations_enabled: bool | None = None
qna_custom_instructions: str | None = None
shared_memory_md: str | None = None
ai_file_sort_enabled: bool | None = None
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
@ -31,6 +32,7 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
citations_enabled: bool
qna_custom_instructions: str | None = None
shared_memory_md: str | None = None
ai_file_sort_enabled: bool = False
model_config = ConfigDict(from_attributes=True)
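A minimal sketch (illustrative; assumes pydantic v2, as implied by ConfigDict) of how the opt-in flag travels through a partial update:

# Only fields the client actually sets are applied to the search space.
payload = SearchSpaceUpdate(ai_file_sort_enabled=True)
payload.model_dump(exclude_unset=True)  # -> {"ai_file_sort_enabled": True}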

View file

@ -0,0 +1,329 @@
"""AI File Sort Service: builds connector-type/date/category/subcategory folder paths."""
from __future__ import annotations
import json
import logging
import re
from datetime import UTC, datetime
from langchain_core.messages import HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.db import (
Chunk,
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
)
from app.services.folder_service import ensure_folder_hierarchy_with_depth_validation
logger = logging.getLogger(__name__)
_DOCTYPE_TO_CONNECTOR_LABEL: dict[str, str] = {
DocumentType.EXTENSION: "Browser Extension",
DocumentType.CRAWLED_URL: "Web Crawl",
DocumentType.FILE: "File Upload",
DocumentType.SLACK_CONNECTOR: "Slack",
DocumentType.TEAMS_CONNECTOR: "Teams",
DocumentType.ONEDRIVE_FILE: "OneDrive",
DocumentType.NOTION_CONNECTOR: "Notion",
DocumentType.YOUTUBE_VIDEO: "YouTube",
DocumentType.GITHUB_CONNECTOR: "GitHub",
DocumentType.LINEAR_CONNECTOR: "Linear",
DocumentType.DISCORD_CONNECTOR: "Discord",
DocumentType.JIRA_CONNECTOR: "Jira",
DocumentType.CONFLUENCE_CONNECTOR: "Confluence",
DocumentType.CLICKUP_CONNECTOR: "ClickUp",
DocumentType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar",
DocumentType.GOOGLE_GMAIL_CONNECTOR: "Gmail",
DocumentType.GOOGLE_DRIVE_FILE: "Google Drive",
DocumentType.AIRTABLE_CONNECTOR: "Airtable",
DocumentType.LUMA_CONNECTOR: "Luma",
DocumentType.ELASTICSEARCH_CONNECTOR: "Elasticsearch",
DocumentType.BOOKSTACK_CONNECTOR: "BookStack",
DocumentType.CIRCLEBACK: "Circleback",
DocumentType.OBSIDIAN_CONNECTOR: "Obsidian",
DocumentType.NOTE: "Notes",
DocumentType.DROPBOX_FILE: "Dropbox",
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Google Drive (Composio)",
DocumentType.COMPOSIO_GMAIL_CONNECTOR: "Gmail (Composio)",
DocumentType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Google Calendar (Composio)",
DocumentType.LOCAL_FOLDER_FILE: "Local Folder",
}
_CONNECTOR_TYPE_LABEL: dict[str, str] = {
SearchSourceConnectorType.SERPER_API: "Serper Search",
SearchSourceConnectorType.TAVILY_API: "Tavily Search",
SearchSourceConnectorType.SEARXNG_API: "SearXNG Search",
SearchSourceConnectorType.LINKUP_API: "Linkup Search",
SearchSourceConnectorType.BAIDU_SEARCH_API: "Baidu Search",
SearchSourceConnectorType.SLACK_CONNECTOR: "Slack",
SearchSourceConnectorType.TEAMS_CONNECTOR: "Teams",
SearchSourceConnectorType.ONEDRIVE_CONNECTOR: "OneDrive",
SearchSourceConnectorType.NOTION_CONNECTOR: "Notion",
SearchSourceConnectorType.GITHUB_CONNECTOR: "GitHub",
SearchSourceConnectorType.LINEAR_CONNECTOR: "Linear",
SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord",
SearchSourceConnectorType.JIRA_CONNECTOR: "Jira",
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence",
SearchSourceConnectorType.CLICKUP_CONNECTOR: "ClickUp",
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar",
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "Gmail",
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: "Google Drive",
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable",
SearchSourceConnectorType.LUMA_CONNECTOR: "Luma",
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "Elasticsearch",
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "Web Crawl",
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "BookStack",
SearchSourceConnectorType.CIRCLEBACK_CONNECTOR: "Circleback",
SearchSourceConnectorType.OBSIDIAN_CONNECTOR: "Obsidian",
SearchSourceConnectorType.MCP_CONNECTOR: "MCP",
SearchSourceConnectorType.DROPBOX_CONNECTOR: "Dropbox",
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: "Google Drive (Composio)",
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: "Gmail (Composio)",
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: "Google Calendar (Composio)",
}
_MAX_CONTENT_CHARS = 4000
_MAX_CHUNKS_FOR_CONTEXT = 5
_CATEGORY_PROMPT = (
"Based on the document information below, classify it into a broad category "
"and a more specific subcategory.\n\n"
"Rules:\n"
"- category: 1-2 word broad theme (e.g. Science, Finance, Engineering, Communication, Media)\n"
"- subcategory: 1-2 word specific topic within the category "
"(e.g. Physics, Tax Reports, Backend, Team Updates)\n"
"- Use nouns only. Do not include generic terms like 'General' or 'Miscellaneous'.\n\n"
"Title: {title}\n\n"
"Content: {summary}\n\n"
'Respond with ONLY a JSON object: {{"category": "...", "subcategory": "..."}}'
)
_SAFE_NAME_RE = re.compile(r"[^a-zA-Z0-9 _\-()]")
_FALLBACK_CATEGORY = "Uncategorized"
_FALLBACK_SUBCATEGORY = "General"
def resolve_root_folder_label(
document: Document, connector: SearchSourceConnector | None
) -> str:
if connector is not None:
return _CONNECTOR_TYPE_LABEL.get(
connector.connector_type, str(connector.connector_type)
)
return _DOCTYPE_TO_CONNECTOR_LABEL.get(
document.document_type, str(document.document_type)
)
def resolve_date_folder(document: Document) -> str:
ts = document.updated_at or document.created_at
if ts is None:
ts = datetime.now(UTC)
return ts.strftime("%Y-%m-%d")
def sanitize_category_folder_name(
value: str | None, fallback: str = _FALLBACK_CATEGORY
) -> str:
if not value or not value.strip():
return fallback
cleaned = _SAFE_NAME_RE.sub("", value.strip())
cleaned = " ".join(cleaned.split())
if not cleaned:
return fallback
return cleaned[:50]
async def _resolve_document_text(
session: AsyncSession,
document: Document,
) -> str:
"""Build the best available text representation for taxonomy generation.
Prefers ``document.content``; falls back to joining the first few chunks
when content is empty or too short to be useful.
"""
text = (document.content or "").strip()
if len(text) >= 100:
return text[:_MAX_CONTENT_CHARS]
stmt = (
select(Chunk.content)
.where(Chunk.document_id == document.id)
.order_by(Chunk.id)
.limit(_MAX_CHUNKS_FOR_CONTEXT)
)
result = await session.execute(stmt)
chunk_texts = [row[0] for row in result.all() if row[0]]
if chunk_texts:
combined = "\n\n".join(chunk_texts)
return combined[:_MAX_CONTENT_CHARS]
return text[:_MAX_CONTENT_CHARS]
def _get_cached_taxonomy(document: Document) -> tuple[str, str] | None:
"""Return (category, subcategory) from document metadata cache, or None."""
meta = document.document_metadata
if not isinstance(meta, dict):
return None
cat = meta.get("ai_sort_category")
subcat = meta.get("ai_sort_subcategory")
if cat and subcat and isinstance(cat, str) and isinstance(subcat, str):
return cat, subcat
return None
def _set_cached_taxonomy(document: Document, category: str, subcategory: str) -> None:
"""Persist the AI taxonomy on document metadata for deterministic re-sorts."""
meta = dict(document.document_metadata or {})
meta["ai_sort_category"] = category
meta["ai_sort_subcategory"] = subcategory
document.document_metadata = meta
async def generate_ai_taxonomy(
title: str,
summary_or_content: str,
llm,
) -> tuple[str, str]:
"""Return (category, subcategory) using a single structured LLM call."""
text = (summary_or_content or "").strip()
if not text:
return _FALLBACK_CATEGORY, _FALLBACK_SUBCATEGORY
if len(text) > _MAX_CONTENT_CHARS:
text = text[:_MAX_CONTENT_CHARS]
prompt = _CATEGORY_PROMPT.format(title=title or "Untitled", summary=text)
try:
result = await llm.ainvoke([HumanMessage(content=prompt)])
raw = result.content.strip()
if raw.startswith("```"):
raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
parsed = json.loads(raw)
category = sanitize_category_folder_name(
parsed.get("category"), _FALLBACK_CATEGORY
)
subcategory = sanitize_category_folder_name(
parsed.get("subcategory"), _FALLBACK_SUBCATEGORY
)
return category, subcategory
except Exception:
logger.warning("AI taxonomy generation failed, using fallback", exc_info=True)
return _FALLBACK_CATEGORY, _FALLBACK_SUBCATEGORY
def _build_path_segments(
root_label: str,
date_label: str,
category: str,
subcategory: str,
) -> list[dict]:
return [
{"name": root_label, "metadata": {"ai_sort": True, "ai_sort_level": 1}},
{"name": date_label, "metadata": {"ai_sort": True, "ai_sort_level": 2}},
{"name": category, "metadata": {"ai_sort": True, "ai_sort_level": 3}},
{"name": subcategory, "metadata": {"ai_sort": True, "ai_sort_level": 4}},
]
async def _resolve_taxonomy(
session: AsyncSession,
document: Document,
llm,
) -> tuple[str, str]:
"""Return (category, subcategory), reusing cached values when available."""
cached = _get_cached_taxonomy(document)
if cached is not None:
return cached
content_text = await _resolve_document_text(session, document)
category, subcategory = await generate_ai_taxonomy(
document.title, content_text, llm
)
_set_cached_taxonomy(document, category, subcategory)
return category, subcategory
async def ai_sort_document(
session: AsyncSession,
document: Document,
llm,
) -> Document:
"""Sort a single document into the 4-level AI folder hierarchy."""
connector: SearchSourceConnector | None = None
if document.connector_id is not None:
connector = await session.get(SearchSourceConnector, document.connector_id)
root_label = resolve_root_folder_label(document, connector)
date_label = resolve_date_folder(document)
category, subcategory = await _resolve_taxonomy(session, document, llm)
segments = _build_path_segments(root_label, date_label, category, subcategory)
leaf_folder = await ensure_folder_hierarchy_with_depth_validation(
session,
document.search_space_id,
segments,
)
document.folder_id = leaf_folder.id
await session.flush()
return document
async def ai_sort_all_documents(
session: AsyncSession,
search_space_id: int,
llm,
) -> tuple[int, int]:
"""Sort all documents in a search space. Returns (sorted_count, failed_count)."""
stmt = (
select(Document)
.where(Document.search_space_id == search_space_id)
.options(selectinload(Document.connector))
)
result = await session.execute(stmt)
documents = list(result.scalars().all())
sorted_count = 0
failed_count = 0
for doc in documents:
try:
connector = doc.connector
root_label = resolve_root_folder_label(doc, connector)
date_label = resolve_date_folder(doc)
category, subcategory = await _resolve_taxonomy(session, doc, llm)
segments = _build_path_segments(
root_label, date_label, category, subcategory
)
leaf_folder = await ensure_folder_hierarchy_with_depth_validation(
session,
search_space_id,
segments,
)
doc.folder_id = leaf_folder.id
sorted_count += 1
except Exception:
logger.error("Failed to AI-sort document %s", doc.id, exc_info=True)
failed_count += 1
await session.commit()
logger.info(
"AI sort complete for search_space=%d: sorted=%d, failed=%d",
search_space_id,
sorted_count,
failed_count,
)
return sorted_count, failed_count
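To make the four-level layout concrete, an illustrative walk-through (hypothetical document values, not from this diff): a Google Drive connector document last updated on 2025-03-15 that the LLM classifies as Science / Physics ends up under the folder chain Google Drive / 2025-03-15 / Science / Physics.

root_label = "Google Drive"    # from resolve_root_folder_label via _CONNECTOR_TYPE_LABEL
date_label = "2025-03-15"      # from resolve_date_folder (updated_at, else created_at, else today)
category, subcategory = "Science", "Physics"  # from _resolve_taxonomy, cached in document_metadata
segments = _build_path_segments(root_label, date_label, category, subcategory)
# ai_sort_document then ensures this chain exists and points document.folder_id at the leaf.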

View file

@ -142,6 +142,58 @@ async def generate_folder_position(
return generate_key_between(last_position, None)
async def ensure_folder_hierarchy_with_depth_validation(
session: AsyncSession,
search_space_id: int,
path_segments: list[dict],
) -> Folder:
"""Create or return a nested folder chain, validating depth at each step.
Each item in ``path_segments`` is a dict with:
- ``name`` (str): folder display name
- ``metadata`` (dict | None): optional ``folder_metadata`` JSONB payload
Returns the deepest (leaf) Folder in the chain.
"""
parent_id: int | None = None
current_folder: Folder | None = None
for segment in path_segments:
name = segment["name"]
metadata = segment.get("metadata")
stmt = select(Folder).where(
Folder.search_space_id == search_space_id,
Folder.name == name,
Folder.parent_id == parent_id
if parent_id is not None
else Folder.parent_id.is_(None),
)
result = await session.execute(stmt)
folder = result.scalar_one_or_none()
if folder is None:
await validate_folder_depth(session, parent_id, subtree_depth=0)
position = await generate_folder_position(
session, search_space_id, parent_id
)
folder = Folder(
name=name,
search_space_id=search_space_id,
parent_id=parent_id,
position=position,
folder_metadata=metadata,
)
session.add(folder)
await session.flush()
current_folder = folder
parent_id = folder.id
assert current_folder is not None, "path_segments must not be empty"
return current_folder
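An illustrative segment list (hypothetical names; session and search_space_id assumed in scope) matching the shape ai_sort_document passes in; calling the helper twice with the same segments reuses the existing folders rather than duplicating them:

segments = [
    {"name": "File Upload", "metadata": {"ai_sort": True, "ai_sort_level": 1}},
    {"name": "2026-04-14", "metadata": {"ai_sort": True, "ai_sort_level": 2}},
    {"name": "Finance", "metadata": {"ai_sort": True, "ai_sort_level": 3}},
    {"name": "Tax Reports", "metadata": {"ai_sort": True, "ai_sort_level": 4}},
]
leaf = await ensure_folder_hierarchy_with_depth_validation(session, search_space_id, segments)
# document.folder_id = leaf.id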
async def get_folder_subtree_ids(session: AsyncSession, folder_id: int) -> list[int]:
"""Return all folder IDs in the subtree rooted at folder_id (inclusive)."""
result = await session.execute(

View file

@ -4,6 +4,7 @@ import asyncio
import contextlib
import logging
import os
import time
from uuid import UUID
from app.celery_app import celery_app
@ -1551,3 +1552,121 @@ async def _index_uploaded_folder_files_async(
heartbeat_task.cancel()
if notification_id is not None:
_stop_heartbeat(notification_id)
# ===== AI File Sort tasks =====
AI_SORT_LOCK_TTL_SECONDS = 600 # 10 minutes
_ai_sort_redis = None
def _get_ai_sort_redis():
import redis
global _ai_sort_redis
if _ai_sort_redis is None:
_ai_sort_redis = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
return _ai_sort_redis
def _ai_sort_lock_key(search_space_id: int) -> str:
return f"ai_sort:search_space:{search_space_id}:lock"
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
"""Full AI sort for all documents in a search space."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
finally:
loop.close()
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
r = _get_ai_sort_redis()
lock_key = _ai_sort_lock_key(search_space_id)
if not r.set(lock_key, "running", nx=True, ex=AI_SORT_LOCK_TTL_SECONDS):
logger.info(
"AI sort already running for search_space=%d, skipping",
search_space_id,
)
return
t_start = time.perf_counter()
try:
from app.services.ai_file_sort_service import ai_sort_all_documents
from app.services.llm_service import get_document_summary_llm
async with get_celery_session_maker()() as session:
llm = await get_document_summary_llm(
session, search_space_id, disable_streaming=True
)
if llm is None:
logger.warning(
"No LLM configured for search_space=%d, skipping AI sort",
search_space_id,
)
return
sorted_count, failed_count = await ai_sort_all_documents(
session, search_space_id, llm
)
elapsed = time.perf_counter() - t_start
logger.info(
"AI sort search_space=%d done in %.1fs: sorted=%d failed=%d",
search_space_id,
elapsed,
sorted_count,
failed_count,
)
finally:
r.delete(lock_key)
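A brief illustration (assumes a reachable Redis; values are placeholders) of the lock semantics above: SET with nx=True only succeeds when the key is absent, so a second full sort enqueued while one is running becomes a no-op until the lock is released or its 10-minute TTL expires.

import redis

r = redis.Redis(decode_responses=True)           # placeholder connection for the sketch
key = "ai_sort:search_space:42:lock"
first = r.set(key, "running", nx=True, ex=600)   # True: lock acquired
second = r.set(key, "running", nx=True, ex=600)  # None: already held, this run skips
r.delete(key)                                    # mirrored by the task's finally block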
@celery_app.task(
name="ai_sort_document", bind=True, max_retries=2, default_retry_delay=10
)
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
"""Incremental AI sort for a single document after indexing."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_ai_sort_document_async(search_space_id, user_id, document_id)
)
finally:
loop.close()
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
from app.db import Document
from app.services.ai_file_sort_service import ai_sort_document
from app.services.llm_service import get_document_summary_llm
async with get_celery_session_maker()() as session:
document = await session.get(Document, document_id)
if document is None:
logger.warning("Document %d not found, skipping AI sort", document_id)
return
llm = await get_document_summary_llm(
session, search_space_id, disable_streaming=True
)
if llm is None:
logger.warning(
"No LLM for search_space=%d, skipping AI sort of doc=%d",
search_space_id,
document_id,
)
return
await ai_sort_document(session, document, llm)
await session.commit()
logger.info(
"AI sorted document=%d into search_space=%d",
document_id,
search_space_id,
)

View file

@ -61,6 +61,7 @@ from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
_background_tasks: set[asyncio.Task] = set()
_perf_log = get_perf_logger()
@ -142,7 +143,7 @@ class StreamResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
sandbox_files: list[str] = field(default_factory=list)
agent_called_update_memory: bool = False
@ -440,7 +441,7 @@ async def _stream_agent_events(
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "execute":
elif tool_name in ("execute", "execute_code"):
cmd = (
tool_input.get("command", "")
if isinstance(tool_input, dict)
@ -738,7 +739,7 @@ async def _stream_agent_events(
status="completed",
items=completed_items,
)
elif tool_name == "execute":
elif tool_name in ("execute", "execute_code"):
raw_text = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
@ -985,7 +986,7 @@ async def _stream_agent_events(
if isinstance(tool_output, dict)
else {"result": tool_output},
)
elif tool_name == "execute":
elif tool_name in ("execute", "execute_code"):
raw_text = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
@ -1598,7 +1599,7 @@ async def stream_new_chat(
# Shared threads write to team memory; private threads write to user memory.
if not stream_result.agent_called_update_memory:
if visibility == ChatVisibility.SEARCH_SPACE:
asyncio.create_task(
task = asyncio.create_task(
extract_and_save_team_memory(
user_message=user_query,
search_space_id=search_space_id,
@ -1606,14 +1607,18 @@ async def stream_new_chat(
author_display_name=current_user_display_name,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
elif user_id:
asyncio.create_task(
task = asyncio.create_task(
extract_and_save_memory(
user_message=user_query,
user_id=user_id,
llm=llm,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Finish the step and message
yield streaming_service.format_finish_step()
@ -1663,6 +1668,21 @@ async def stream_new_chat(
with contextlib.suppress(Exception):
await session.close()
# Persist any sandbox-produced files to local storage so they
# remain downloadable after the Daytona sandbox auto-deletes.
if stream_result and stream_result.sandbox_files:
with contextlib.suppress(Exception):
from app.agents.new_chat.sandbox import (
is_sandbox_enabled,
persist_and_delete_sandbox,
)
if is_sandbox_enabled():
with anyio.CancelScope(shield=True):
await persist_and_delete_sandbox(
chat_id, stream_result.sandbox_files
)
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = None

View file

@ -961,6 +961,7 @@ async def index_google_drive_files(
vision_llm = None
if connector_enable_vision_llm:
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
@ -1168,6 +1169,7 @@ async def index_google_drive_single_file(
vision_llm = None
if connector_enable_vision_llm:
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
@ -1306,6 +1308,7 @@ async def index_google_drive_selected_files(
vision_llm = None
if connector_enable_vision_llm:
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials

View file

@ -1360,7 +1360,9 @@ async def index_uploaded_files(
try:
content, content_hash = await _compute_file_content_hash(
temp_path, filename, search_space_id,
temp_path,
filename,
search_space_id,
vision_llm=vision_llm_instance,
)
except Exception as e:

View file

@ -656,6 +656,7 @@ async def index_onedrive_files(
vision_llm = None
if connector_enable_vision_llm:
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
onedrive_client = OneDriveClient(session, connector_id)

View file

@ -21,7 +21,11 @@ from pathlib import Path
from dotenv import load_dotenv
_here = Path(__file__).parent
for candidate in [_here / "../surfsense_backend/.env", _here / ".env", _here / "../.env"]:
for candidate in [
_here / "../surfsense_backend/.env",
_here / ".env",
_here / "../.env",
]:
if candidate.exists():
load_dotenv(candidate)
break
@ -57,7 +61,10 @@ def main() -> None:
api_key = os.environ.get("DAYTONA_API_KEY")
if not api_key:
print("ERROR: DAYTONA_API_KEY is not set.", file=sys.stderr)
print("Add it to surfsense_backend/.env or export it in your shell.", file=sys.stderr)
print(
"Add it to surfsense_backend/.env or export it in your shell.",
file=sys.stderr,
)
sys.exit(1)
daytona = Daytona()
@ -67,7 +74,7 @@ def main() -> None:
print(f"Deleting existing snapshot '{SNAPSHOT_NAME}'")
daytona.snapshot.delete(existing)
print(f"Deleted '{SNAPSHOT_NAME}'. Waiting for removal to propagate …")
for attempt in range(30):
for _attempt in range(30):
time.sleep(2)
try:
daytona.snapshot.get(SNAPSHOT_NAME)
@ -75,7 +82,9 @@ def main() -> None:
print(f"Confirmed '{SNAPSHOT_NAME}' is gone.\n")
break
else:
print(f"WARNING: '{SNAPSHOT_NAME}' may still exist after 60s. Proceeding anyway.\n")
print(
f"WARNING: '{SNAPSHOT_NAME}' may still exist after 60s. Proceeding anyway.\n"
)
except Exception:
pass

View file

@ -431,7 +431,9 @@ async def test_llamacloud_heif_accepted_only_with_azure_di(tmp_path, mocker):
mocker.patch("app.config.config.AZURE_DI_ENDPOINT", None, create=True)
mocker.patch("app.config.config.AZURE_DI_KEY", None, create=True)
with pytest.raises(EtlUnsupportedFileError, match="document parser does not support this format"):
with pytest.raises(
EtlUnsupportedFileError, match="document parser does not support this format"
):
await EtlPipelineService().extract(
EtlRequest(file_path=str(heif_file), filename="photo.heif")
)

View file

@ -6,6 +6,7 @@ import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.new_chat.middleware.knowledge_search import (
KBSearchPlan,
KnowledgeBaseSearchMiddleware,
_build_document_xml,
_normalize_optional_date_range,
@ -366,3 +367,146 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
assert captured["query"] == "deel founders guide summary"
assert captured["start_date"] is None
assert captured["end_date"] is None
async def test_middleware_routes_to_recency_browse_when_flagged(
self,
monkeypatch,
):
"""When the planner sets is_recency_query=true, browse_recent_documents
is called instead of search_knowledge_base."""
browse_captured: dict = {}
search_called = False
async def fake_browse_recent_documents(**kwargs):
browse_captured.update(kwargs)
return []
async def fake_search_knowledge_base(**kwargs):
nonlocal search_called
search_called = True
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}, {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
fake_browse_recent_documents,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
llm = FakeLLM(
json.dumps(
{
"optimized_query": "latest uploaded file",
"start_date": None,
"end_date": None,
"is_recency_query": True,
}
)
)
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42)
result = await middleware.abefore_agent(
{"messages": [HumanMessage(content="what's my latest file?")]},
runtime=None,
)
assert result is not None
assert browse_captured["search_space_id"] == 42
assert not search_called
async def test_middleware_uses_hybrid_search_when_not_recency(
self,
monkeypatch,
):
"""When is_recency_query is false (default), hybrid search is used."""
search_captured: dict = {}
browse_called = False
async def fake_browse_recent_documents(**kwargs):
nonlocal browse_called
browse_called = True
return []
async def fake_search_knowledge_base(**kwargs):
search_captured.update(kwargs)
return []
async def fake_build_scoped_filesystem(**kwargs):
return {}, {}
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
fake_browse_recent_documents,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
fake_search_knowledge_base,
)
monkeypatch.setattr(
"app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
fake_build_scoped_filesystem,
)
llm = FakeLLM(
json.dumps(
{
"optimized_query": "quarterly revenue report analysis",
"start_date": None,
"end_date": None,
"is_recency_query": False,
}
)
)
middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42)
await middleware.abefore_agent(
{"messages": [HumanMessage(content="find the quarterly revenue report")]},
runtime=None,
)
assert search_captured["query"] == "quarterly revenue report analysis"
assert not browse_called
# ── KBSearchPlan schema ────────────────────────────────────────────────
class TestKBSearchPlanSchema:
def test_is_recency_query_defaults_to_false(self):
plan = KBSearchPlan(optimized_query="test query")
assert plan.is_recency_query is False
def test_is_recency_query_parses_true(self):
plan = _parse_kb_search_plan_response(
json.dumps(
{
"optimized_query": "latest uploaded file",
"start_date": None,
"end_date": None,
"is_recency_query": True,
}
)
)
assert plan.is_recency_query is True
assert plan.optimized_query == "latest uploaded file"
def test_missing_is_recency_query_defaults_to_false(self):
plan = _parse_kb_search_plan_response(
json.dumps(
{
"optimized_query": "meeting notes",
"start_date": None,
"end_date": None,
}
)
)
assert plan.is_recency_query is False

View file

@ -0,0 +1,275 @@
"""Unit tests for AI file sort service: folder label resolution, date extraction, category sanitization."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
pytestmark = pytest.mark.unit
# ── resolve_root_folder_label ──
def _make_document(document_type: str, connector_id=None):
doc = MagicMock()
doc.document_type = document_type
doc.connector_id = connector_id
return doc
def _make_connector(connector_type: str):
conn = MagicMock()
conn.connector_type = connector_type
return conn
def test_root_label_uses_connector_type_when_available():
from app.services.ai_file_sort_service import resolve_root_folder_label
doc = _make_document("FILE", connector_id=1)
conn = _make_connector("GOOGLE_DRIVE_CONNECTOR")
assert resolve_root_folder_label(doc, conn) == "Google Drive"
def test_root_label_falls_back_to_document_type():
from app.services.ai_file_sort_service import resolve_root_folder_label
doc = _make_document("SLACK_CONNECTOR")
assert resolve_root_folder_label(doc, None) == "Slack"
def test_root_label_unknown_doctype_returns_raw_value():
from app.services.ai_file_sort_service import resolve_root_folder_label
doc = _make_document("UNKNOWN_TYPE")
assert resolve_root_folder_label(doc, None) == "UNKNOWN_TYPE"
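# Illustrative-only sketch of the rule the three tests above describe: prefer
# the connector's type, fall back to the document type, and pass unknown
# values through unchanged. The label table and helper name are assumptions
# for the sake of the example; the real mapping lives in ai_file_sort_service.
_SKETCH_ROOT_LABELS = {
    "GOOGLE_DRIVE_CONNECTOR": "Google Drive",
    "SLACK_CONNECTOR": "Slack",
}


def _sketch_resolve_root_folder_label(document, connector) -> str:
    raw = connector.connector_type if connector is not None else document.document_type
    return _SKETCH_ROOT_LABELS.get(raw, raw)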
# ── resolve_date_folder ──
def test_date_folder_from_updated_at():
from app.services.ai_file_sort_service import resolve_date_folder
doc = MagicMock()
doc.updated_at = datetime(2025, 3, 15, 10, 30, 0, tzinfo=UTC)
doc.created_at = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC)
assert resolve_date_folder(doc) == "2025-03-15"
def test_date_folder_falls_back_to_created_at():
from app.services.ai_file_sort_service import resolve_date_folder
doc = MagicMock()
doc.updated_at = None
doc.created_at = datetime(2024, 12, 25, 23, 59, 0, tzinfo=UTC)
assert resolve_date_folder(doc) == "2024-12-25"
def test_date_folder_both_none_uses_today():
from app.services.ai_file_sort_service import resolve_date_folder
doc = MagicMock()
doc.updated_at = None
doc.created_at = None
result = resolve_date_folder(doc)
today = datetime.now(UTC).strftime("%Y-%m-%d")
assert result == today
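# Minimal sketch, assuming the precedence the three tests above assert:
# updated_at first, then created_at, then today's UTC date. Helper name is
# illustrative; it reuses the datetime/UTC imports at the top of this file.
def _sketch_resolve_date_folder(document) -> str:
    stamp = document.updated_at or document.created_at or datetime.now(UTC)
    return stamp.strftime("%Y-%m-%d")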
# ── sanitize_category_folder_name ──
def test_sanitize_normal_value():
from app.services.ai_file_sort_service import sanitize_category_folder_name
assert sanitize_category_folder_name("Machine Learning") == "Machine Learning"
def test_sanitize_strips_special_chars():
from app.services.ai_file_sort_service import sanitize_category_folder_name
assert sanitize_category_folder_name("Tax/Reports!") == "TaxReports"
def test_sanitize_empty_returns_fallback():
from app.services.ai_file_sort_service import sanitize_category_folder_name
assert sanitize_category_folder_name("") == "Uncategorized"
assert sanitize_category_folder_name(None) == "Uncategorized"
def test_sanitize_truncates_long_names():
from app.services.ai_file_sort_service import sanitize_category_folder_name
long_name = "A" * 100
result = sanitize_category_folder_name(long_name)
assert len(result) <= 50
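# Hedged sketch of the sanitization contract the four tests above exercise:
# empty input falls back to "Uncategorized", disallowed characters are
# dropped, and the result is capped at 50 characters. The character class and
# the exact cap are assumptions inferred from the assertions.
import re


def _sketch_sanitize_category_folder_name(value) -> str:
    if not value:
        return "Uncategorized"
    cleaned = re.sub(r"[^A-Za-z0-9 _-]", "", value).strip()
    return cleaned[:50] if cleaned else "Uncategorized"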
# ── generate_ai_taxonomy ──
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_parses_json():
from app.services.ai_file_sort_service import generate_ai_taxonomy
mock_llm = AsyncMock()
mock_result = MagicMock()
mock_result.content = '{"category": "Science", "subcategory": "Physics"}'
mock_llm.ainvoke.return_value = mock_result
cat, sub = await generate_ai_taxonomy(
"Physics Paper", "Some science document about physics", mock_llm
)
assert cat == "Science"
assert sub == "Physics"
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_handles_markdown_code_block():
from app.services.ai_file_sort_service import generate_ai_taxonomy
mock_llm = AsyncMock()
mock_result = MagicMock()
mock_result.content = (
'```json\n{"category": "Finance", "subcategory": "Tax Reports"}\n```'
)
mock_llm.ainvoke.return_value = mock_result
cat, sub = await generate_ai_taxonomy("Tax Doc", "A tax report document", mock_llm)
assert cat == "Finance"
assert sub == "Tax Reports"
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_includes_title_in_prompt():
from app.services.ai_file_sort_service import generate_ai_taxonomy
mock_llm = AsyncMock()
mock_result = MagicMock()
mock_result.content = '{"category": "Engineering", "subcategory": "Backend"}'
mock_llm.ainvoke.return_value = mock_result
await generate_ai_taxonomy("API Design Guide", "content about REST APIs", mock_llm)
prompt_text = mock_llm.ainvoke.call_args[0][0][0].content
assert "API Design Guide" in prompt_text
assert "content about REST APIs" in prompt_text
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_error():
from app.services.ai_file_sort_service import generate_ai_taxonomy
mock_llm = AsyncMock()
mock_llm.ainvoke.side_effect = RuntimeError("LLM down")
cat, sub = await generate_ai_taxonomy("Title", "some content", mock_llm)
assert cat == "Uncategorized"
assert sub == "General"
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_empty_content():
from app.services.ai_file_sort_service import generate_ai_taxonomy
mock_llm = AsyncMock()
cat, sub = await generate_ai_taxonomy("Title", "", mock_llm)
assert cat == "Uncategorized"
assert sub == "General"
mock_llm.ainvoke.assert_not_called()
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_invalid_json():
from app.services.ai_file_sort_service import generate_ai_taxonomy
mock_llm = AsyncMock()
mock_result = MagicMock()
mock_result.content = "not valid json at all"
mock_llm.ainvoke.return_value = mock_result
cat, sub = await generate_ai_taxonomy("Title", "some content", mock_llm)
assert cat == "Uncategorized"
assert sub == "General"
# ── taxonomy caching ──
def test_get_cached_taxonomy_returns_none_when_no_metadata():
from app.services.ai_file_sort_service import _get_cached_taxonomy
doc = MagicMock()
doc.document_metadata = None
assert _get_cached_taxonomy(doc) is None
def test_get_cached_taxonomy_returns_none_when_keys_missing():
from app.services.ai_file_sort_service import _get_cached_taxonomy
doc = MagicMock()
doc.document_metadata = {"some_other_key": "value"}
assert _get_cached_taxonomy(doc) is None
def test_get_cached_taxonomy_returns_cached_values():
from app.services.ai_file_sort_service import _get_cached_taxonomy
doc = MagicMock()
doc.document_metadata = {
"ai_sort_category": "Finance",
"ai_sort_subcategory": "Tax Reports",
}
assert _get_cached_taxonomy(doc) == ("Finance", "Tax Reports")
def test_set_cached_taxonomy_persists_on_metadata():
from app.services.ai_file_sort_service import _set_cached_taxonomy
doc = MagicMock()
doc.document_metadata = {"existing_key": "keep_me"}
_set_cached_taxonomy(doc, "Science", "Physics")
assert doc.document_metadata["ai_sort_category"] == "Science"
assert doc.document_metadata["ai_sort_subcategory"] == "Physics"
assert doc.document_metadata["existing_key"] == "keep_me"
def test_set_cached_taxonomy_creates_metadata_when_none():
from app.services.ai_file_sort_service import _set_cached_taxonomy
doc = MagicMock()
doc.document_metadata = None
_set_cached_taxonomy(doc, "Engineering", "Backend")
assert doc.document_metadata == {
"ai_sort_category": "Engineering",
"ai_sort_subcategory": "Backend",
}
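# Minimal sketch of the metadata cache round-trip asserted above; the two
# metadata keys come straight from the assertions, while the helper names are
# placeholders for the private functions under test.
def _sketch_get_cached_taxonomy(document):
    meta = document.document_metadata or {}
    cat, sub = meta.get("ai_sort_category"), meta.get("ai_sort_subcategory")
    return (cat, sub) if cat and sub else None


def _sketch_set_cached_taxonomy(document, category, subcategory):
    meta = dict(document.document_metadata or {})
    meta.update(ai_sort_category=category, ai_sort_subcategory=subcategory)
    document.document_metadata = meta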
# ── _build_path_segments ──
def test_build_path_segments_structure():
from app.services.ai_file_sort_service import _build_path_segments
segments = _build_path_segments("Google Drive", "2025-03-15", "Science", "Physics")
assert len(segments) == 4
assert segments[0] == {
"name": "Google Drive",
"metadata": {"ai_sort": True, "ai_sort_level": 1},
}
assert segments[1] == {
"name": "2025-03-15",
"metadata": {"ai_sort": True, "ai_sort_level": 2},
}
assert segments[2] == {
"name": "Science",
"metadata": {"ai_sort": True, "ai_sort_level": 3},
}
assert segments[3] == {
"name": "Physics",
"metadata": {"ai_sort": True, "ai_sort_level": 4},
}
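# Hedged sketch of the four-level path assembly the structure test checks:
# root label, date folder, category, subcategory, each tagged with ai_sort
# metadata and a 1-based level. Name is illustrative only.
def _sketch_build_path_segments(root, date_folder, category, subcategory):
    return [
        {"name": name, "metadata": {"ai_sort": True, "ai_sort_level": level}}
        for level, name in enumerate((root, date_folder, category, subcategory), start=1)
    ]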

View file

@@ -0,0 +1,43 @@
"""Unit tests for AI sort task Redis deduplication lock."""
from unittest.mock import MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
def test_lock_key_format():
from app.tasks.celery_tasks.document_tasks import _ai_sort_lock_key
key = _ai_sort_lock_key(42)
assert key == "ai_sort:search_space:42:lock"
def test_lock_prevents_duplicate_run():
"""When the Redis lock already exists, the task should skip execution."""
mock_redis = MagicMock()
mock_redis.set.return_value = False # Lock already held
with (
patch(
"app.tasks.celery_tasks.document_tasks._get_ai_sort_redis",
return_value=mock_redis,
),
patch(
"app.tasks.celery_tasks.document_tasks.get_celery_session_maker"
) as mock_session_maker,
):
import asyncio
from app.tasks.celery_tasks.document_tasks import _ai_sort_search_space_async
loop = asyncio.new_event_loop()
try:
loop.run_until_complete(_ai_sort_search_space_async(1, "user-123"))
finally:
loop.close()
# Session maker should never be called since lock was not acquired
mock_session_maker.assert_not_called()
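# Hedged sketch of the guard the two tests above rely on: a SET NX on the
# per-search-space key, where a falsy result means another run already holds
# the lock and the task should return before opening a DB session. The TTL
# value and helper names are assumptions, not the task's actual code.
def _sketch_ai_sort_lock_key(search_space_id: int) -> str:
    return f"ai_sort:search_space:{search_space_id}:lock"


def _sketch_try_acquire_ai_sort_lock(redis_client, search_space_id: int) -> bool:
    acquired = redis_client.set(
        _sketch_ai_sort_lock_key(search_space_id), "1", nx=True, ex=600
    )
    return bool(acquired)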

View file

@@ -0,0 +1,87 @@
"""Unit tests for ensure_folder_hierarchy_with_depth_validation."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_creates_missing_folders_in_chain():
"""Should create all folders when none exist."""
from app.services.folder_service import (
ensure_folder_hierarchy_with_depth_validation,
)
session = AsyncMock()
# All lookups return None (no existing folders)
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
session.execute.return_value = mock_result
folder_instances = []
def track_add(obj):
folder_instances.append(obj)
session.add = track_add
with (
patch(
"app.services.folder_service.validate_folder_depth", new_callable=AsyncMock
),
patch(
"app.services.folder_service.generate_folder_position",
new_callable=AsyncMock,
return_value="a0",
),
):
# Mock flush to assign IDs
call_count = 0
async def mock_flush():
nonlocal call_count
call_count += 1
if folder_instances:
folder_instances[-1].id = call_count
session.flush = mock_flush
segments = [
{"name": "Slack", "metadata": {"ai_sort": True, "ai_sort_level": 1}},
{"name": "2025-03-15", "metadata": {"ai_sort": True, "ai_sort_level": 2}},
]
result = await ensure_folder_hierarchy_with_depth_validation(
session, 1, segments
)
assert len(folder_instances) == 2
assert folder_instances[0].name == "Slack"
assert folder_instances[1].name == "2025-03-15"
assert result is folder_instances[-1]
@pytest.mark.asyncio
async def test_reuses_existing_folder():
"""When a folder already exists, it should be reused, not created."""
from app.services.folder_service import (
ensure_folder_hierarchy_with_depth_validation,
)
session = AsyncMock()
existing_folder = MagicMock()
existing_folder.id = 42
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = existing_folder
session.execute.return_value = mock_result
segments = [{"name": "Existing", "metadata": None}]
result = await ensure_folder_hierarchy_with_depth_validation(session, 1, segments)
assert result is existing_folder
session.add.assert_not_called()
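# Assumed control flow (simplified, illustrative names) behind the two tests
# above: walk the segment chain, reuse a folder when the lookup finds one,
# otherwise create it and flush so the new row's id can parent the next level.
# The lookup query, the constructor arguments, and the elided depth validation
# are placeholders, not the service's real code.
async def _sketch_ensure_hierarchy(session, search_space_id, segments, folder_cls):
    parent_id = None
    folder = None
    for segment in segments:
        lookup = await session.execute(...)  # select-by-name/parent elided
        existing = lookup.scalar_one_or_none()
        if existing is not None:
            folder = existing
        else:
            folder = folder_cls(
                name=segment["name"],
                parent_id=parent_id,
                search_space_id=search_space_id,
            )
            session.add(folder)
            await session.flush()
        parent_id = folder.id
    return folder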