feat: made agent file sytem optimized

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-03-28 16:39:46 -07:00
parent ee0b59c0fa
commit 2cc2d339e6
67 changed files with 8011 additions and 5591 deletions

@ -1 +0,0 @@
Subproject commit a32ce7ff6b2112cf48170d2279a1953eded61987

View file

@ -169,13 +169,3 @@ LANGSMITH_TRACING=true
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
LANGSMITH_API_KEY=lsv2_pt_.....
LANGSMITH_PROJECT=surfsense
# Agent Specific Configuration
# Daytona Sandbox (secure cloud code execution for deep agent)
# Set DAYTONA_SANDBOX_ENABLED=TRUE to give the agent an isolated execute tool
DAYTONA_SANDBOX_ENABLED=TRUE
DAYTONA_API_KEY=dtn_asdasfasfafas
DAYTONA_API_URL=https://app.daytona.io/api
DAYTONA_TARGET=us
# Directory for locally-persisted sandbox files (after sandbox deletion)
SANDBOX_FILES_DIR=sandbox_files

View file

@ -1,11 +1,12 @@
"""
SurfSense New Chat Agent Module.
This module provides the SurfSense deep agent with configurable tools
for knowledge base search, podcast generation, and more.
This module provides the SurfSense deep agent with configurable tools,
middleware, and preloaded knowledge-base filesystem behavior.
Directory Structure:
- tools/: All agent tools (knowledge_base, podcast, generate_image, etc.)
- tools/: All agent tools (podcast, generate_image, web, memory, etc.)
- middleware/: Custom middleware (knowledge search, filesystem, dedup, etc.)
- chat_deepagent.py: Main agent factory
- system_prompt.py: System prompts and instructions
- context.py: Context schema for the agent
@ -23,6 +24,13 @@ from .context import SurfSenseContextSchema
# LLM config
from .llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
# Middleware
from .middleware import (
DedupHITLToolCallsMiddleware,
KnowledgeBaseSearchMiddleware,
SurfSenseFilesystemMiddleware,
)
# System prompt
from .system_prompt import (
SURFSENSE_CITATION_INSTRUCTIONS,
@ -39,7 +47,6 @@ from .tools import (
build_tools,
create_generate_podcast_tool,
create_scrape_webpage_tool,
create_search_knowledge_base_tool,
format_documents_for_context,
get_all_tool_names,
get_default_enabled_tools,
@ -53,8 +60,12 @@ __all__ = [
# System prompt
"SURFSENSE_CITATION_INSTRUCTIONS",
"SURFSENSE_SYSTEM_PROMPT",
# Middleware
"DedupHITLToolCallsMiddleware",
"KnowledgeBaseSearchMiddleware",
# Context
"SurfSenseContextSchema",
"SurfSenseFilesystemMiddleware",
"ToolDefinition",
"build_surfsense_system_prompt",
"build_tools",
@ -63,7 +74,6 @@ __all__ = [
# Tool factories
"create_generate_podcast_tool",
"create_scrape_webpage_tool",
"create_search_knowledge_base_tool",
# Agent factory
"create_surfsense_deep_agent",
# Knowledge base utilities

View file

@ -4,6 +4,13 @@ SurfSense deep agent implementation.
This module provides the factory function for creating SurfSense deep agents
with configurable tools via the tools registry and configurable prompts
via NewLLMConfig.
We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
(from deepagents) so that the middleware stack is fully under our control.
This lets us swap in ``SurfSenseFilesystemMiddleware`` a customisable
subclass of the default ``FilesystemMiddleware`` while preserving every
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
summarisation, prompt-caching, etc.).
"""
import asyncio
@ -12,8 +19,15 @@ import time
from collections.abc import Sequence
from typing import Any
from deepagents import create_deep_agent
from deepagents.backends.protocol import SandboxBackendProtocol
from deepagents import SubAgent, SubAgentMiddleware, __version__ as deepagents_version
from deepagents.backends import StateBackend
from deepagents.graph import BASE_AGENT_PROMPT
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
from deepagents.middleware.summarization import create_summarization_middleware
from langchain.agents import create_agent
from langchain.agents.middleware import TodoListMiddleware
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.types import Checkpointer
@ -21,8 +35,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.context import SurfSenseContextSchema
from app.agents.new_chat.llm_config import AgentConfig
from app.agents.new_chat.middleware.dedup_tool_calls import (
from app.agents.new_chat.middleware import (
DedupHITLToolCallsMiddleware,
KnowledgeBaseSearchMiddleware,
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt,
@ -40,15 +56,15 @@ _perf_log = get_perf_logger()
# =============================================================================
# Maps SearchSourceConnectorType enum values to the searchable document/connector types
# used by the knowledge_base and web_search tools.
# used by pre-search middleware and web_search.
# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to
# the web_search tool; all others go to search_knowledge_base.
# the web_search tool; all others are considered local/indexed data.
_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
# Live search connectors (handled by web_search tool)
"TAVILY_API": "TAVILY_API",
"LINKUP_API": "LINKUP_API",
"BAIDU_SEARCH_API": "BAIDU_SEARCH_API",
# Local/indexed connectors (handled by search_knowledge_base tool)
# Local/indexed connectors (handled by KB pre-search middleware)
"SLACK_CONNECTOR": "SLACK_CONNECTOR",
"TEAMS_CONNECTOR": "TEAMS_CONNECTOR",
"NOTION_CONNECTOR": "NOTION_CONNECTOR",
@ -141,13 +157,11 @@ async def create_surfsense_deep_agent(
additional_tools: Sequence[BaseTool] | None = None,
firecrawl_api_key: str | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_backend: SandboxBackendProtocol | None = None,
):
"""
Create a SurfSense deep agent with configurable tools and prompts.
The agent comes with built-in tools that can be configured:
- search_knowledge_base: Search the user's personal knowledge base
- generate_podcast: Generate audio podcasts from content
- generate_image: Generate images from text descriptions using AI models
- scrape_webpage: Extract content from webpages
@ -179,9 +193,6 @@ async def create_surfsense_deep_agent(
These are always added regardless of enabled/disabled settings.
firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
Falls back to Chromium/Trafilatura if not provided.
sandbox_backend: Optional sandbox backend (e.g. DaytonaSandbox) for
secure code execution. When provided, the agent gets an
isolated ``execute`` tool for running shell commands.
Returns:
CompiledStateGraph: The configured deep agent
@ -205,7 +216,7 @@ async def create_surfsense_deep_agent(
# Create agent with only specific tools
agent = create_surfsense_deep_agent(
llm, search_space_id, db_session, ...,
enabled_tools=["search_knowledge_base", "scrape_webpage"]
enabled_tools=["scrape_webpage"]
)
# Create agent without podcast generation
@ -357,6 +368,10 @@ async def create_surfsense_deep_agent(
]
modified_disabled_tools.extend(confluence_tools)
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
if "search_knowledge_base" not in modified_disabled_tools:
modified_disabled_tools.append("search_knowledge_base")
# Build tools using the async registry (includes MCP tools)
_t0 = time.perf_counter()
tools = await build_tools_async(
@ -373,7 +388,6 @@ async def create_surfsense_deep_agent(
# Build system prompt based on agent_config, scoped to the tools actually enabled
_t0 = time.perf_counter()
_sandbox_enabled = sandbox_backend is not None
_enabled_tool_names = {t.name for t in tools}
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
if agent_config is not None:
@ -382,14 +396,12 @@ async def create_surfsense_deep_agent(
use_default_system_instructions=agent_config.use_default_system_instructions,
citations_enabled=agent_config.citations_enabled,
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
)
else:
system_prompt = build_surfsense_system_prompt(
thread_visibility=thread_visibility,
sandbox_enabled=_sandbox_enabled,
enabled_tool_names=_enabled_tool_names,
disabled_tool_names=_user_disabled_tool_names,
)
@ -397,24 +409,69 @@ async def create_surfsense_deep_agent(
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
)
# Build optional kwargs for the deep agent
deep_agent_kwargs: dict[str, Any] = {}
if sandbox_backend is not None:
deep_agent_kwargs["backend"] = sandbox_backend
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
# General-purpose subagent middleware
gp_middleware = [
TodoListMiddleware(),
SurfSenseFilesystemMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
),
create_summarization_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
**GENERAL_PURPOSE_SUBAGENT,
"model": llm,
"tools": tools,
"middleware": gp_middleware,
}
# Main agent middleware
deepagent_middleware = [
TodoListMiddleware(),
KnowledgeBaseSearchMiddleware(
search_space_id=search_space_id,
available_connectors=available_connectors,
available_document_types=available_document_types,
),
SurfSenseFilesystemMiddleware(
search_space_id=search_space_id,
created_by_id=user_id,
),
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
create_summarization_middleware(llm, StateBackend),
PatchToolCallsMiddleware(),
DedupHITLToolCallsMiddleware(),
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
]
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
_t0 = time.perf_counter()
agent = await asyncio.to_thread(
create_deep_agent,
model=llm,
create_agent,
llm,
system_prompt=final_system_prompt,
tools=tools,
system_prompt=system_prompt,
middleware=deepagent_middleware,
context_schema=SurfSenseContextSchema,
checkpointer=checkpointer,
middleware=[DedupHITLToolCallsMiddleware()],
**deep_agent_kwargs,
)
agent = agent.with_config(
{
"recursion_limit": 10_000,
"metadata": {
"ls_integration": "deepagents",
"versions": {"deepagents": deepagents_version},
},
}
)
_perf_log.info(
"[create_agent] Graph compiled (create_deep_agent) in %.3fs",
"[create_agent] Graph compiled (create_agent) in %.3fs",
time.perf_counter() - _t0,
)

View file

@ -0,0 +1,17 @@
"""Middleware components for the SurfSense new chat agent."""
from app.agents.new_chat.middleware.dedup_tool_calls import (
DedupHITLToolCallsMiddleware,
)
from app.agents.new_chat.middleware.filesystem import (
SurfSenseFilesystemMiddleware,
)
from app.agents.new_chat.middleware.knowledge_search import (
KnowledgeBaseSearchMiddleware,
)
__all__ = [
"DedupHITLToolCallsMiddleware",
"KnowledgeBaseSearchMiddleware",
"SurfSenseFilesystemMiddleware",
]

View file

@ -0,0 +1,694 @@
"""Custom filesystem middleware for the SurfSense agent.
This middleware customizes prompts and persists write/edit operations for
`/documents/*` files into SurfSense's `Document`/`Chunk` tables.
"""
from __future__ import annotations
import asyncio
import re
from datetime import UTC, datetime
from typing import Annotated, Any
from deepagents import FilesystemMiddleware
from deepagents.backends.protocol import EditResult, WriteResult
from deepagents.backends.utils import validate_path
from deepagents.middleware.filesystem import FilesystemState
from fractional_indexing import generate_key_between
from langchain.tools import ToolRuntime
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.types import Command
from sqlalchemy import delete, select
from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session
from app.indexing_pipeline.document_chunker import chunk_text
from app.utils.document_converters import (
embed_texts,
generate_content_hash,
generate_unique_identifier_hash,
)
# =============================================================================
# System Prompt (injected into every model call by wrap_model_call)
# =============================================================================
SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
- Read files before editing understand existing content before making changes.
- Mimic existing style, naming conventions, and patterns.
## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`, `save_document`
All file paths must start with a `/`.
- ls: list files and directories at a given path.
- read_file: read a file from the filesystem.
- write_file: create a temporary file in the session (not persisted).
- edit_file: edit a file in the session (not persisted for /documents/ files).
- glob: find files matching a pattern (e.g., "**/*.xml").
- grep: search for text within files.
- save_document: **permanently** save a new document to the user's knowledge
base. Use only when the user explicitly asks to save/create a document.
## Reading Documents Efficiently
Documents are formatted as XML. Each document contains:
- `<document_metadata>` title, type, URL, etc.
- `<chunk_index>` a table of every chunk with its **line range** and a
`matched="true"` flag for chunks that matched the search query.
- `<document_content>` the actual chunks in original document order.
**Workflow**: when reading a large document, read the first ~20 lines to see
the `<chunk_index>`, identify chunks marked `matched="true"`, then use
`read_file(path, offset=<start_line>, limit=<lines>)` to jump directly to
those sections instead of reading the entire file sequentially.
Use `<chunk id='...'>` values as citation IDs in your answers.
"""
# =============================================================================
# Per-Tool Descriptions (shown to the LLM as the tool's docstring)
# =============================================================================
SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
"""
SURFSENSE_READ_FILE_TOOL_DESCRIPTION = """Reads a file from the filesystem.
Usage:
- By default, reads up to 100 lines from the beginning.
- Use `offset` and `limit` for pagination when files are large.
- Results include line numbers.
- Documents contain a `<chunk_index>` near the top listing every chunk with
its line range and a `matched="true"` flag for search-relevant chunks.
Read the index first, then jump to matched chunks with
`read_file(path, offset=<start_line>, limit=<num_lines>)`.
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
"""
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
Use this to create scratch/working files during the conversation. Files created
here are ephemeral and will not be saved to the user's knowledge base.
To permanently save a document to the user's knowledge base, use the
`save_document` tool instead.
"""
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
IMPORTANT:
- Read the file before editing.
- Preserve exact indentation and formatting.
- Edits to documents under `/documents/` are session-only (not persisted to the
database) because those files use an XML citation wrapper around the original
content.
"""
SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern.
Supports standard glob patterns: `*`, `**`, `?`.
Returns absolute file paths.
"""
SURFSENSE_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files.
Use this to locate relevant document files/chunks before reading full files.
"""
SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION = """Permanently saves a document to the user's knowledge base.
This is an expensive operation it creates a new Document record in the
database, chunks the content, and generates embeddings for search.
Use ONLY when the user explicitly asks to save/create/store a document.
Do NOT use this for scratch work; use `write_file` for temporary files.
Args:
title: The document title (e.g., "Meeting Notes 2025-06-01").
content: The plain-text or markdown content to save. Do NOT include XML
citation wrappers pass only the actual document text.
folder_path: Optional folder path under /documents/ (e.g., "Work/Notes").
Folders are created automatically if they don't exist.
"""
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
"""SurfSense-specific filesystem middleware with DB persistence for docs."""
def __init__(
self,
*,
search_space_id: int | None = None,
created_by_id: str | None = None,
tool_token_limit_before_evict: int | None = 20000,
) -> None:
self._search_space_id = search_space_id
self._created_by_id = created_by_id
super().__init__(
system_prompt=SURFSENSE_FILESYSTEM_SYSTEM_PROMPT,
custom_tool_descriptions={
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
"read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION,
"write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION,
"edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION,
"glob": SURFSENSE_GLOB_TOOL_DESCRIPTION,
"grep": SURFSENSE_GREP_TOOL_DESCRIPTION,
},
tool_token_limit_before_evict=tool_token_limit_before_evict,
)
# Remove the execute tool (no sandbox backend)
self.tools = [t for t in self.tools if t.name != "execute"]
self.tools.append(self._create_save_document_tool())
@staticmethod
def _run_async_blocking(coro: Any) -> Any:
"""Run async coroutine from sync code path when no event loop is running."""
try:
loop = asyncio.get_running_loop()
if loop.is_running():
return "Error: sync filesystem persistence not supported inside an active event loop."
except RuntimeError:
pass
return asyncio.run(coro)
@staticmethod
def _parse_virtual_path(file_path: str) -> tuple[list[str], str]:
"""Parse /documents/... path into folder parts and a document title."""
if not file_path.startswith("/documents/"):
return [], ""
rel = file_path[len("/documents/") :].strip("/")
if not rel:
return [], ""
parts = [part for part in rel.split("/") if part]
file_name = parts[-1]
title = file_name[:-4] if file_name.lower().endswith(".xml") else file_name
return parts[:-1], title
async def _ensure_folder_hierarchy(
self,
*,
folder_parts: list[str],
search_space_id: int,
) -> int | None:
"""Ensure folder hierarchy exists and return leaf folder ID."""
if not folder_parts:
return None
async with shielded_async_session() as session:
parent_id: int | None = None
for name in folder_parts:
result = await session.execute(
select(Folder).where(
Folder.search_space_id == search_space_id,
Folder.parent_id == parent_id
if parent_id is not None
else Folder.parent_id.is_(None),
Folder.name == name,
)
)
folder = result.scalar_one_or_none()
if folder is None:
sibling_result = await session.execute(
select(Folder.position)
.where(
Folder.search_space_id == search_space_id,
Folder.parent_id == parent_id
if parent_id is not None
else Folder.parent_id.is_(None),
)
.order_by(Folder.position.desc())
.limit(1)
)
last_position = sibling_result.scalar_one_or_none()
folder = Folder(
name=name,
position=generate_key_between(last_position, None),
parent_id=parent_id,
search_space_id=search_space_id,
created_by_id=self._created_by_id,
updated_at=datetime.now(UTC),
)
session.add(folder)
await session.flush()
parent_id = folder.id
await session.commit()
return parent_id
async def _persist_new_document(
self, *, file_path: str, content: str
) -> dict[str, Any] | str:
"""Persist a new NOTE document from a newly written file.
Returns a dict with document metadata on success, or an error string.
"""
if self._search_space_id is None:
return {}
folder_parts, title = self._parse_virtual_path(file_path)
if not title:
return "Error: write_file for document persistence requires path under /documents/<name>.xml"
folder_id = await self._ensure_folder_hierarchy(
folder_parts=folder_parts,
search_space_id=self._search_space_id,
)
async with shielded_async_session() as session:
content_hash = generate_content_hash(content, self._search_space_id)
existing = await session.execute(
select(Document.id).where(Document.content_hash == content_hash)
)
if existing.scalar_one_or_none() is not None:
return "Error: A document with identical content already exists."
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
file_path,
self._search_space_id,
)
doc = Document(
title=title,
document_type=DocumentType.NOTE,
document_metadata={"virtual_path": file_path},
content=content,
content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash,
source_markdown=content,
search_space_id=self._search_space_id,
folder_id=folder_id,
created_by_id=self._created_by_id,
updated_at=datetime.now(UTC),
)
session.add(doc)
await session.flush()
summary_embedding = embed_texts([content])[0]
doc.embedding = summary_embedding
chunk_texts = chunk_text(content)
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
chunks = [
Chunk(document_id=doc.id, content=text, embedding=embedding)
for text, embedding in zip(
chunk_texts, chunk_embeddings, strict=True
)
]
session.add_all(chunks)
await session.commit()
return {
"id": doc.id,
"title": title,
"documentType": DocumentType.NOTE.value,
"searchSpaceId": self._search_space_id,
"folderId": folder_id,
"createdById": str(self._created_by_id)
if self._created_by_id
else None,
}
async def _persist_edited_document(
self, *, file_path: str, updated_content: str
) -> str | None:
"""Persist edits for an existing NOTE document and recreate chunks."""
if self._search_space_id is None:
return None
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.NOTE,
file_path,
self._search_space_id,
)
doc_id_from_xml: int | None = None
match = re.search(r"<document_id>\s*(\d+)\s*</document_id>", updated_content)
if match:
doc_id_from_xml = int(match.group(1))
async with shielded_async_session() as session:
doc_result = await session.execute(
select(Document).where(
Document.search_space_id == self._search_space_id,
Document.unique_identifier_hash == unique_identifier_hash,
)
)
document = doc_result.scalar_one_or_none()
if document is None and doc_id_from_xml is not None:
by_id_result = await session.execute(
select(Document).where(
Document.search_space_id == self._search_space_id,
Document.id == doc_id_from_xml,
)
)
document = by_id_result.scalar_one_or_none()
if document is None:
return "Error: Could not map edited file to an existing document."
document.content = updated_content
document.source_markdown = updated_content
document.content_hash = generate_content_hash(
updated_content, self._search_space_id
)
document.updated_at = datetime.now(UTC)
if not document.document_metadata:
document.document_metadata = {}
document.document_metadata["virtual_path"] = file_path
summary_embedding = embed_texts([updated_content])[0]
document.embedding = summary_embedding
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
chunk_texts = chunk_text(updated_content)
if chunk_texts:
chunk_embeddings = embed_texts(chunk_texts)
session.add_all(
[
Chunk(
document_id=document.id, content=text, embedding=embedding
)
for text, embedding in zip(
chunk_texts, chunk_embeddings, strict=True
)
]
)
await session.commit()
return None
def _create_save_document_tool(self) -> BaseTool:
"""Create save_document tool that persists a new document to the KB."""
def sync_save_document(
title: Annotated[str, "Title for the new document."],
content: Annotated[
str,
"Plain-text or markdown content to save. Do NOT include XML wrappers.",
],
runtime: ToolRuntime[None, FilesystemState],
folder_path: Annotated[
str,
"Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
] = "",
) -> Command | str:
if not content.strip():
return "Error: content cannot be empty."
file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
if not file_name.lower().endswith(".xml"):
file_name = f"{file_name}.xml"
folder = folder_path.strip().strip("/") if folder_path else ""
virtual_path = (
f"/documents/{folder}/{file_name}"
if folder
else f"/documents/{file_name}"
)
persist_result = self._run_async_blocking(
self._persist_new_document(file_path=virtual_path, content=content)
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
return f"Document '{title}' saved to knowledge base (path: {virtual_path})."
async def async_save_document(
title: Annotated[str, "Title for the new document."],
content: Annotated[
str,
"Plain-text or markdown content to save. Do NOT include XML wrappers.",
],
runtime: ToolRuntime[None, FilesystemState],
folder_path: Annotated[
str,
"Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
] = "",
) -> Command | str:
if not content.strip():
return "Error: content cannot be empty."
file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
if not file_name.lower().endswith(".xml"):
file_name = f"{file_name}.xml"
folder = folder_path.strip().strip("/") if folder_path else ""
virtual_path = (
f"/documents/{folder}/{file_name}"
if folder
else f"/documents/{file_name}"
)
persist_result = await self._persist_new_document(
file_path=virtual_path, content=content
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
return f"Document '{title}' saved to knowledge base (path: {virtual_path})."
return StructuredTool.from_function(
name="save_document",
description=SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION,
func=sync_save_document,
coroutine=async_save_document,
)
def _create_write_file_tool(self) -> BaseTool:
"""Create write_file — ephemeral for /documents/*, persisted otherwise."""
tool_description = (
self._custom_tool_descriptions.get("write_file")
or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION
)
def sync_write_file(
file_path: Annotated[
str,
"Absolute path where the file should be created. Must be absolute, not relative.",
],
content: Annotated[
str,
"The text content to write to the file. This parameter is required.",
],
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = resolved_backend.write(validated_path, content)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
persist_result = self._run_async_blocking(
self._persist_new_document(
file_path=validated_path, content=content
)
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Updated file {res.path}",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Updated file {res.path}"
async def async_write_file(
file_path: Annotated[
str,
"Absolute path where the file should be created. Must be absolute, not relative.",
],
content: Annotated[
str,
"The text content to write to the file. This parameter is required.",
],
runtime: ToolRuntime[None, FilesystemState],
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: WriteResult = await resolved_backend.awrite(validated_path, content)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
persist_result = await self._persist_new_document(
file_path=validated_path,
content=content,
)
if isinstance(persist_result, str):
return persist_result
if isinstance(persist_result, dict) and persist_result.get("id"):
dispatch_custom_event("document_created", persist_result)
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Updated file {res.path}",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Updated file {res.path}"
return StructuredTool.from_function(
name="write_file",
description=tool_description,
func=sync_write_file,
coroutine=async_write_file,
)
@staticmethod
def _is_kb_document(path: str) -> bool:
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
return path.startswith("/documents/")
def _create_edit_file_tool(self) -> BaseTool:
"""Create edit_file with DB persistence (skipped for KB documents)."""
tool_description = (
self._custom_tool_descriptions.get("edit_file")
or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION
)
def sync_edit_file(
file_path: Annotated[
str,
"Absolute path to the file to edit. Must be absolute, not relative.",
],
old_string: Annotated[
str,
"The exact text to find and replace. Must be unique in the file unless replace_all is True.",
],
new_string: Annotated[
str,
"The text to replace old_string with. Must be different from old_string.",
],
runtime: ToolRuntime[None, FilesystemState],
*,
replace_all: Annotated[
bool,
"If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = resolved_backend.edit(
validated_path,
old_string,
new_string,
replace_all=replace_all,
)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = resolved_backend.read(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
persist_result = self._run_async_blocking(
self._persist_edited_document(
file_path=validated_path,
updated_content=updated_content,
)
)
if isinstance(persist_result, str):
return persist_result
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"
async def async_edit_file(
file_path: Annotated[
str,
"Absolute path to the file to edit. Must be absolute, not relative.",
],
old_string: Annotated[
str,
"The exact text to find and replace. Must be unique in the file unless replace_all is True.",
],
new_string: Annotated[
str,
"The text to replace old_string with. Must be different from old_string.",
],
runtime: ToolRuntime[None, FilesystemState],
*,
replace_all: Annotated[
bool,
"If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
] = False,
) -> Command | str:
resolved_backend = self._get_backend(runtime)
try:
validated_path = validate_path(file_path)
except ValueError as exc:
return f"Error: {exc}"
res: EditResult = await resolved_backend.aedit(
validated_path,
old_string,
new_string,
replace_all=replace_all,
)
if res.error:
return res.error
if not self._is_kb_document(validated_path):
read_result = await resolved_backend.aread(
validated_path, offset=0, limit=200000
)
if read_result.error or read_result.file_data is None:
return f"Error: could not reload edited file '{validated_path}' for persistence."
updated_content = read_result.file_data["content"]
persist_error = await self._persist_edited_document(
file_path=validated_path,
updated_content=updated_content,
)
if persist_error:
return persist_error
if res.files_update is not None:
return Command(
update={
"files": res.files_update,
"messages": [
ToolMessage(
content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
tool_call_id=runtime.tool_call_id,
)
],
}
)
return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"
return StructuredTool.from_function(
name="edit_file",
description=tool_description,
func=sync_edit_file,
coroutine=async_edit_file,
)

View file

@ -0,0 +1,414 @@
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
This middleware runs before the main agent loop and seeds a virtual filesystem
(`files` state) with relevant documents retrieved via hybrid search. On each
turn the filesystem is *expanded* new results merge with documents loaded
during prior turns and a synthetic ``ls`` result is injected into the message
history so the LLM is immediately aware of the current filesystem structure.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import uuid
from collections.abc import Sequence
from typing import Any
from langchain.agents.middleware import AgentMiddleware, AgentState
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.runtime import Runtime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.utils.document_converters import embed_texts
from app.utils.perf import get_perf_logger
logger = logging.getLogger(__name__)
_perf_log = get_perf_logger()
def _extract_text_from_message(message: BaseMessage) -> str:
"""Extract plain text from a message content."""
content = getattr(message, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "\n".join(p for p in parts if p)
return str(content)
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
"""Convert arbitrary text into a filesystem-safe filename."""
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
name = re.sub(r"\s+", " ", name)
if not name:
name = fallback
if len(name) > 180:
name = name[:180].rstrip()
if not name.lower().endswith(".xml"):
name = f"{name}.xml"
return name
def _build_document_xml(
document: dict[str, Any],
matched_chunk_ids: set[int] | None = None,
) -> str:
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
The ``<chunk_index>`` at the top of each document lists every chunk with its
line range inside ``<document_content>`` and flags chunks that directly
matched the search query (``matched="true"``). This lets the LLM jump
straight to the most relevant section via ``read_file(offset=, limit=)``
instead of reading sequentially from the start.
"""
matched = matched_chunk_ids or set()
doc_meta = document.get("document") or {}
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
url = (
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
)
metadata_json = json.dumps(metadata, ensure_ascii=False)
# --- 1. Metadata header (fixed structure) ---
metadata_lines: list[str] = [
"<document>",
"<document_metadata>",
f" <document_id>{document_id}</document_id>",
f" <document_type>{document_type}</document_type>",
f" <title><![CDATA[{title}]]></title>",
f" <url><![CDATA[{url}]]></url>",
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
"</document_metadata>",
"",
]
# --- 2. Pre-build chunk XML strings to compute line counts ---
chunks = document.get("chunks") or []
chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string)
if isinstance(chunks, list):
for chunk in chunks:
if not isinstance(chunk, dict):
continue
chunk_id = chunk.get("chunk_id") or chunk.get("id")
chunk_content = str(chunk.get("content", "")).strip()
if not chunk_content:
continue
if chunk_id is None:
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
else:
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
chunk_entries.append((chunk_id, xml))
# --- 3. Compute line numbers for every chunk ---
# Layout (1-indexed lines for read_file):
# metadata_lines -> len(metadata_lines) lines
# <chunk_index> -> 1 line
# index entries -> len(chunk_entries) lines
# </chunk_index> -> 1 line
# (empty line) -> 1 line
# <document_content> -> 1 line
# chunk xml lines…
# </document_content> -> 1 line
# </document> -> 1 line
index_overhead = (
1 + len(chunk_entries) + 1 + 1 + 1
) # tags + empty + <document_content>
first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed
current_line = first_chunk_line
index_entry_lines: list[str] = []
for cid, xml_str in chunk_entries:
num_lines = xml_str.count("\n") + 1
end_line = current_line + num_lines - 1
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
if cid is not None:
index_entry_lines.append(
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
)
else:
index_entry_lines.append(
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
)
current_line = end_line + 1
# --- 4. Assemble final XML ---
lines = metadata_lines.copy()
lines.append("<chunk_index>")
lines.extend(index_entry_lines)
lines.append("</chunk_index>")
lines.append("")
lines.append("<document_content>")
for _, xml_str in chunk_entries:
lines.append(xml_str)
lines.extend(["</document_content>", "</document>"])
return "\n".join(lines)
async def _get_folder_paths(
session: AsyncSession, search_space_id: int
) -> dict[int, str]:
"""Return a map of folder_id -> virtual folder path under /documents."""
result = await session.execute(
select(Folder.id, Folder.name, Folder.parent_id).where(
Folder.search_space_id == search_space_id
)
)
rows = result.all()
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
cache: dict[int, str] = {}
def resolve_path(folder_id: int) -> str:
if folder_id in cache:
return cache[folder_id]
parts: list[str] = []
cursor: int | None = folder_id
visited: set[int] = set()
while cursor is not None and cursor in by_id and cursor not in visited:
visited.add(cursor)
entry = by_id[cursor]
parts.append(
_safe_filename(str(entry["name"]), fallback="folder").removesuffix(
".xml"
)
)
cursor = entry["parent_id"]
parts.reverse()
path = "/documents/" + "/".join(parts) if parts else "/documents"
cache[folder_id] = path
return path
for folder_id in by_id:
resolve_path(folder_id)
return cache
def _build_synthetic_ls(
existing_files: dict[str, Any] | None,
new_files: dict[str, Any],
) -> tuple[AIMessage, ToolMessage]:
"""Build a synthetic ls("/documents") tool-call + result for the LLM context.
Paths are listed with *new* (rank-ordered) files first, then existing files
that were already in state from prior turns.
"""
merged: dict[str, Any] = {**(existing_files or {}), **new_files}
doc_paths = [
p for p, v in merged.items() if p.startswith("/documents/") and v is not None
]
new_set = set(new_files)
new_paths = [p for p in doc_paths if p in new_set]
old_paths = [p for p in doc_paths if p not in new_set]
ordered = new_paths + old_paths
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
ai_msg = AIMessage(
content="",
tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
)
tool_msg = ToolMessage(
content=str(ordered) if ordered else "No documents found.",
tool_call_id=tool_call_id,
)
return ai_msg, tool_msg
def _resolve_search_types(
available_connectors: list[str] | None,
available_document_types: list[str] | None,
) -> list[str] | None:
"""Build a flat list of document-type strings for the chunk retriever.
Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
old documents indexed under Composio names are still found.
Returns ``None`` when no filtering is desired (search all types).
"""
types: set[str] = set()
if available_document_types:
types.update(available_document_types)
if available_connectors:
types.update(available_connectors)
if not types:
return None
expanded: set[str] = set(types)
for t in types:
legacy = NATIVE_TO_LEGACY_DOCTYPE.get(t)
if legacy:
expanded.add(legacy)
return list(expanded) if expanded else None
async def search_knowledge_base(
*,
query: str,
search_space_id: int,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
) -> list[dict[str, Any]]:
"""Run a single unified hybrid search against the knowledge base.
Uses one ``ChucksHybridSearchRetriever`` call across all document types
instead of fanning out per-connector. This reduces the number of DB
queries from ~10 to 2 (one RRF query + one chunk fetch).
"""
if not query:
return []
[embedding] = embed_texts([query])
doc_types = _resolve_search_types(available_connectors, available_document_types)
retriever_top_k = min(top_k * 3, 30)
async with shielded_async_session() as session:
retriever = ChucksHybridSearchRetriever(session)
results = await retriever.hybrid_search(
query_text=query,
top_k=retriever_top_k,
search_space_id=search_space_id,
document_type=doc_types,
query_embedding=embedding.tolist(),
)
return results[:top_k]
async def build_scoped_filesystem(
*,
documents: Sequence[dict[str, Any]],
search_space_id: int,
) -> dict[str, dict[str, str]]:
"""Build a StateBackend-compatible files dict from search results."""
async with shielded_async_session() as session:
folder_paths = await _get_folder_paths(session, search_space_id)
doc_ids = [
(doc.get("document") or {}).get("id")
for doc in documents
if isinstance(doc, dict)
]
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
folder_by_doc_id: dict[int, int | None] = {}
if doc_ids:
doc_rows = await session.execute(
select(Document.id, Document.folder_id).where(
Document.search_space_id == search_space_id,
Document.id.in_(doc_ids),
)
)
folder_by_doc_id = {
row.id: row.folder_id for row in doc_rows.all() if row.id is not None
}
files: dict[str, dict[str, str]] = {}
for document in documents:
doc_meta = document.get("document") or {}
title = str(doc_meta.get("title") or "untitled")
doc_id = doc_meta.get("id")
folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None
base_folder = folder_paths.get(folder_id, "/documents")
file_name = _safe_filename(title)
path = f"{base_folder}/{file_name}"
matched_ids = set(document.get("matched_chunk_ids") or [])
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
files[path] = {
"content": xml_content.split("\n"),
"encoding": "utf-8",
"created_at": "",
"modified_at": "",
}
return files
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
tools = ()
def __init__(
self,
*,
search_space_id: int,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
top_k: int = 10,
) -> None:
self.search_space_id = search_space_id
self.available_connectors = available_connectors
self.available_document_types = available_document_types
self.top_k = top_k
def before_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
try:
loop = asyncio.get_running_loop()
if loop.is_running():
return None
except RuntimeError:
pass
return asyncio.run(self.abefore_agent(state, runtime))
async def abefore_agent( # type: ignore[override]
self,
state: AgentState,
runtime: Runtime[Any],
) -> dict[str, Any] | None:
del runtime
messages = state.get("messages") or []
if not messages:
return None
last_message = messages[-1]
if not isinstance(last_message, HumanMessage):
return None
user_text = _extract_text_from_message(last_message).strip()
if not user_text:
return None
t0 = _perf_log and asyncio.get_event_loop().time()
existing_files = state.get("files")
search_results = await search_knowledge_base(
query=user_text,
search_space_id=self.search_space_id,
available_connectors=self.available_connectors,
available_document_types=self.available_document_types,
top_k=self.top_k,
)
new_files = await build_scoped_filesystem(
documents=search_results,
search_space_id=self.search_space_id,
)
ai_msg, tool_msg = _build_synthetic_ls(existing_files, new_files)
if t0 is not None:
_perf_log.info(
"[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d",
asyncio.get_event_loop().time() - t0,
user_text[:80],
len(new_files),
len(new_files) + len(existing_files or {}),
)
return {"files": new_files, "messages": [ai_msg, tool_msg]}

View file

@ -25,6 +25,21 @@ When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVE
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
<knowledge_base_only_policy>
CRITICAL RULE KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs.
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission.
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
1. Inform the user that you could not find relevant information in their knowledge base.
2. Ask the user: "Would you like me to answer from my general knowledge instead?"
3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes.
- This policy does NOT apply to:
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
* Formatting, summarization, or analysis of content already present in the conversation
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
</knowledge_base_only_policy>
</system_instruction>
"""
@ -41,6 +56,21 @@ When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVE
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
<knowledge_base_only_policy>
CRITICAL RULE KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs.
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission.
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
1. Inform the team that you could not find relevant information in the shared knowledge base.
2. Ask: "Would you like me to answer from my general knowledge instead?"
3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes.
- This policy does NOT apply to:
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
* Formatting, summarization, or analysis of content already present in the conversation
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
</knowledge_base_only_policy>
</system_instruction>
"""
@ -67,15 +97,6 @@ _TOOLS_PREAMBLE = """
<tools>
You have access to the following tools:
CRITICAL BEHAVIORAL RULE SEARCH FIRST, ANSWER LATER:
For ANY user query that is ambiguous, open-ended, or could potentially have relevant context in the
knowledge base, you MUST call `search_knowledge_base` BEFORE attempting to answer from your own
general knowledge. This includes (but is not limited to) questions about concepts, topics, projects,
people, events, recommendations, or anything the user might have stored notes/documents about.
Only fall back to your own general knowledge if the search returns NO relevant results.
Do NOT skip the search and answer directly the user's knowledge base may contain personalized,
up-to-date, or domain-specific information that is more relevant than your general training data.
IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it.
Do NOT claim you can do something if the corresponding tool is not listed.
@ -92,29 +113,6 @@ _TOOL_INSTRUCTIONS["search_surfsense_docs"] = """
- Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123])
"""
_TOOL_INSTRUCTIONS["search_knowledge_base"] = """
- search_knowledge_base: Search the user's personal knowledge base for relevant information.
- DEFAULT ACTION: For any user question or ambiguous query, ALWAYS call this tool first to check
for relevant context before answering from general knowledge. When in doubt, search.
- IMPORTANT: When searching for information (meetings, schedules, notes, tasks, etc.), ALWAYS search broadly
across ALL sources first by omitting connectors_to_search. The user may store information in various places
including calendar apps, note-taking apps (Obsidian, Notion), chat apps (Slack, Discord), and more.
- This tool searches ONLY local/indexed data (uploaded files, Notion, Slack, browser extension captures, etc.).
For real-time web search (current events, news, live data), use the `web_search` tool instead.
- FALLBACK BEHAVIOR: If the search returns no relevant results, you MAY then answer using your own
general knowledge, but clearly indicate that no matching information was found in the knowledge base.
- Only narrow to specific connectors if the user explicitly asks (e.g., "check my Slack" or "in my calendar").
- Personal notes in Obsidian, Notion, or NOTE often contain schedules, meeting times, reminders, and other
important information that may not be in calendars.
- Args:
- query: The search query - be specific and include key terms
- top_k: Number of results to retrieve (default: 10)
- start_date: Optional ISO date/datetime (e.g. "2025-12-12" or "2025-12-12T00:00:00+00:00")
- end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
- connectors_to_search: Optional list of connector enums to search. If omitted, searches all.
- Returns: Formatted string with relevant documents and their content
"""
_TOOL_INSTRUCTIONS["generate_podcast"] = """
- generate_podcast: Generate an audio podcast from provided content.
- Use this when the user asks to create, generate, or make a podcast.
@ -163,8 +161,8 @@ _TOOL_INSTRUCTIONS["generate_report"] = """
* For source_strategy="kb_search": Can be empty or minimal the tool handles searching internally.
* For source_strategy="auto": Include what you have; the tool searches KB if it's not enough.
- source_strategy: Controls how the tool collects source material. One of:
* "conversation" The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content. Do NOT call search_knowledge_base separately.
* "kb_search" The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries. Do NOT call search_knowledge_base separately.
* "conversation" The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content.
* "kb_search" The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries.
* "auto" Use source_content if sufficient, otherwise fall back to internal KB search using search_queries.
* "provided" Use only what is in source_content (default, backward-compatible).
- search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated.
@ -176,11 +174,11 @@ _TOOL_INSTRUCTIONS["generate_report"] = """
- The report is generated immediately in Markdown and displayed inline in the chat.
- Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report.
- SOURCE STRATEGY DECISION (HIGH PRIORITY follow this exactly):
* If the conversation already has substantive Q&A / discussion on the topic use source_strategy="conversation" with a comprehensive summary as source_content. Do NOT call search_knowledge_base first.
* If the user wants a report on a topic not yet discussed use source_strategy="kb_search" with targeted search_queries. Do NOT call search_knowledge_base first.
* If the conversation already has substantive Q&A / discussion on the topic use source_strategy="conversation" with a comprehensive summary as source_content.
* If the user wants a report on a topic not yet discussed use source_strategy="kb_search" with targeted search_queries.
* If you have some content but might need more use source_strategy="auto" with both source_content and search_queries.
* When revising an existing report (parent_report_id set) and the conversation has relevant context use source_strategy="conversation". The revision will use the previous report content plus your source_content.
* NEVER call search_knowledge_base and then pass its results to generate_report. The tool handles KB search internally.
* NEVER run a separate KB lookup step and then pass those results to generate_report. The tool handles KB search internally.
- AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat.
"""
@ -204,7 +202,7 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
* When search_knowledge_base returned insufficient data and the user wants more
* When preloaded `/documents/` data is insufficient and the user wants more
- Trigger scenarios:
* "Read this article and summarize it"
* "What does this page say about X?"
@ -366,23 +364,6 @@ _MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
# Per-tool examples keyed by tool name. Only examples for enabled tools are included.
_TOOL_EXAMPLES: dict[str, str] = {}
_TOOL_EXAMPLES["search_knowledge_base"] = """
- User: "What time is the team meeting today?"
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
- DO NOT limit to just calendar - the info might be in notes!
- User: "When is my gym session?"
- Call: `search_knowledge_base(query="gym session time schedule")` (searches ALL sources)
- User: "Fetch all my notes and what's in them?"
- Call: `search_knowledge_base(query="*", top_k=50, connectors_to_search=["NOTE"])`
- User: "What did I discuss on Slack last week about the React migration?"
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
- User: "Check my Obsidian notes for meeting notes"
- Call: `search_knowledge_base(query="meeting notes", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
- User: "search me current usd to inr rate"
- Call: `web_search(query="current USD to INR exchange rate")`
- Then answer using the returned live web results with citations.
"""
_TOOL_EXAMPLES["search_surfsense_docs"] = """
- User: "How do I install SurfSense?"
- Call: `search_surfsense_docs(query="installation setup")`
@ -400,8 +381,7 @@ _TOOL_EXAMPLES["generate_podcast"] = """
- User: "Create a podcast summary of this conversation"
- Call: `generate_podcast(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
- User: "Make a podcast about quantum computing"
- First search: `search_knowledge_base(query="quantum computing")`
- Then: `generate_podcast(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", podcast_title="Quantum Computing Explained")`
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")`
"""
_TOOL_EXAMPLES["generate_video_presentation"] = """
@ -410,8 +390,7 @@ _TOOL_EXAMPLES["generate_video_presentation"] = """
- User: "Create slides summarizing this conversation"
- Call: `generate_video_presentation(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")`
- User: "Make a video presentation about quantum computing"
- First search: `search_knowledge_base(query="quantum computing")`
- Then: `generate_video_presentation(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", video_title="Quantum Computing Explained")`
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")`
"""
_TOOL_EXAMPLES["generate_report"] = """
@ -471,7 +450,6 @@ _TOOL_EXAMPLES["web_search"] = """
# All tool names that have prompt instructions (order matters for prompt readability)
_ALL_TOOL_NAMES_ORDERED = [
"search_surfsense_docs",
"search_knowledge_base",
"web_search",
"generate_podcast",
"generate_video_presentation",
@ -650,87 +628,6 @@ However, from your video learning, it's important to note that asyncio is not su
</citation_instructions>
"""
# Sandbox / code execution instructions — appended when sandbox backend is enabled.
# Inspired by Claude's computer-use prompt, scoped to code execution & data analytics.
SANDBOX_EXECUTION_INSTRUCTIONS = """
<code_execution>
You have access to a secure, isolated Linux sandbox environment for running code and shell commands.
This gives you the `execute` tool alongside the standard filesystem tools (`ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`).
## CRITICAL — CODE-FIRST RULE
ALWAYS prefer executing code over giving a text-only response when the user's request involves ANY of the following:
- **Creating a chart, plot, graph, or visualization** Write Python code and generate the actual file. NEVER describe percentages or data in text and offer to "paste into Excel". Just produce the chart.
- **Data analysis, statistics, or computation** Write code to compute the answer. Do not do math by hand in text.
- **Generating or transforming files** (CSV, PDF, images, etc.) Write code to create the file.
- **Running, testing, or debugging code** Execute it in the sandbox.
This applies even when you first retrieve data from the knowledge base. After `search_knowledge_base` returns relevant data, **immediately proceed to write and execute code** if the user's request matches any of the categories above. Do NOT stop at a text summary and wait for the user to ask you to "use Python" — that extra round-trip is a poor experience.
Example (CORRECT):
User: "Create a pie chart of my benefits"
1. search_knowledge_base retrieve benefits data
2. Immediately execute Python code (matplotlib) to generate the pie chart
3. Return the downloadable file + brief description
Example (WRONG):
User: "Create a pie chart of my benefits"
1. search_knowledge_base retrieve benefits data
2. Print a text table with percentages and ask the user if they want a chart NEVER do this
## When to Use Code Execution
Use the sandbox when the task benefits from actually running code rather than just describing it:
- **Data analysis**: Load CSVs/JSON, compute statistics, filter/aggregate data, pivot tables
- **Visualization**: Generate charts and plots (matplotlib, plotly, seaborn)
- **Calculations**: Math, financial modeling, unit conversions, simulations
- **Code validation**: Run and test code snippets the user provides or asks about
- **File processing**: Parse, transform, or convert data files
- **Quick prototyping**: Demonstrate working code for the user's problem
- **Package exploration**: Install and test libraries the user is evaluating
## When NOT to Use Code Execution
Do not use the sandbox for:
- Answering factual questions from your own knowledge
- Summarizing or explaining concepts
- Simple formatting or text generation tasks
- Tasks that don't require running code to answer
## Package Management
- Use `pip install <package>` to install Python packages as needed
- Common data/analytics packages (pandas, numpy, matplotlib, scipy, scikit-learn) may need to be installed on first use
- Always verify a package installed successfully before using it
## Working Guidelines
- **Working directory**: The shell starts in the sandbox user's home directory (e.g. `/home/daytona`). Use **relative paths** or `/tmp/` for all files you create. NEVER write directly to `/home/` — that is the parent directory and is not writable. Use `pwd` if you need to discover the current working directory.
- **Iterative approach**: For complex tasks, break work into steps write code, run it, check output, refine
- **Error handling**: If code fails, read the error, fix the issue, and retry. Don't just report the error without attempting a fix.
- **Show results**: When generating plots or outputs, present the key findings directly in your response. For plots, save to a file and describe the results.
- **Be efficient**: Install packages once per session. Combine related commands when possible.
- **Large outputs**: If command output is very large, use `head`, `tail`, or save to a file and read selectively.
## Sharing Generated Files
When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
- **DO NOT use markdown image syntax** for files created inside the sandbox. Sandbox files are not accessible via public URLs and will show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
- Always describe what the file contains in your response text so the user knows what they are downloading.
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.
## Data Analytics Best Practices
When the user asks you to analyze data:
1. First, inspect the data structure (`head`, `shape`, `dtypes`, `describe()`)
2. Clean and validate before computing (handle nulls, check types)
3. Perform the analysis and present results clearly
4. Offer follow-up insights or visualizations when appropriate
</code_execution>
"""
# Anti-citation prompt - used when citations are disabled
# This explicitly tells the model NOT to include citations
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
@ -756,7 +653,6 @@ Your goal is to provide helpful, informative answers in a clean, readable format
def build_surfsense_system_prompt(
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_enabled: bool = False,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
) -> str:
@ -767,12 +663,10 @@ def build_surfsense_system_prompt(
- Default system instructions
- Tools instructions (only for enabled tools)
- Citation instructions enabled
- Sandbox execution instructions (when sandbox_enabled=True)
Args:
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
@ -786,13 +680,7 @@ def build_surfsense_system_prompt(
visibility, enabled_tool_names, disabled_tool_names
)
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
return (
system_instructions
+ tools_instructions
+ citation_instructions
+ sandbox_instructions
)
return system_instructions + tools_instructions + citation_instructions
def build_configurable_system_prompt(
@ -801,18 +689,16 @@ def build_configurable_system_prompt(
citations_enabled: bool = True,
today: datetime | None = None,
thread_visibility: ChatVisibility | None = None,
sandbox_enabled: bool = False,
enabled_tool_names: set[str] | None = None,
disabled_tool_names: set[str] | None = None,
) -> str:
"""
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
The prompt is composed of up to four parts:
The prompt is composed of three parts:
1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
2. Tools Instructions - only for enabled tools, with a note about disabled ones
3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
4. Sandbox Execution Instructions - when sandbox_enabled=True
Args:
custom_system_instructions: Custom system instructions to use. If empty/None and
@ -824,7 +710,6 @@ def build_configurable_system_prompt(
anti-citation instructions (False).
today: Optional datetime for today's date (defaults to current UTC date)
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
@ -856,14 +741,7 @@ def build_configurable_system_prompt(
else SURFSENSE_NO_CITATION_INSTRUCTIONS
)
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
return (
system_instructions
+ tools_instructions
+ citation_instructions
+ sandbox_instructions
)
return system_instructions + tools_instructions + citation_instructions
def get_default_system_instructions() -> str:

View file

@ -5,7 +5,6 @@ This module contains all the tools available to the SurfSense agent.
To add a new tool, see the documentation in registry.py.
Available tools:
- search_knowledge_base: Search the user's personal knowledge base
- search_surfsense_docs: Search Surfsense documentation for usage help
- generate_podcast: Generate audio podcasts from content
- generate_video_presentation: Generate video presentations with slides and narration
@ -20,7 +19,6 @@ Available tools:
from .generate_image import create_generate_image_tool
from .knowledge_base import (
CONNECTOR_DESCRIPTIONS,
create_search_knowledge_base_tool,
format_documents_for_context,
search_knowledge_base_async,
)
@ -52,7 +50,6 @@ __all__ = [
"create_recall_memory_tool",
"create_save_memory_tool",
"create_scrape_webpage_tool",
"create_search_knowledge_base_tool",
"create_search_surfsense_docs_tool",
"format_documents_for_context",
"get_all_tool_names",

View file

@ -273,9 +273,7 @@ def create_update_calendar_event_tool(
final_new_start_datetime, context
)
if final_new_end_datetime is not None:
update_body["end"] = _build_time_body(
final_new_end_datetime, context
)
update_body["end"] = _build_time_body(final_new_end_datetime, context)
if final_new_description is not None:
update_body["description"] = final_new_description
if final_new_location is not None:

View file

@ -5,7 +5,6 @@ This module provides:
- Connector constants and normalization
- Async knowledge base search across multiple connectors
- Document formatting for LLM context
- Tool factory for creating search_knowledge_base tools
"""
import asyncio
@ -16,8 +15,6 @@ import time
from datetime import datetime
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
@ -619,9 +616,76 @@ async def search_knowledge_base_async(
perf = get_perf_logger()
t0 = time.perf_counter()
deduplicated = await search_knowledge_base_raw_async(
query=query,
search_space_id=search_space_id,
db_session=db_session,
connector_service=connector_service,
connectors_to_search=connectors_to_search,
top_k=top_k,
start_date=start_date,
end_date=end_date,
available_connectors=available_connectors,
available_document_types=available_document_types,
)
if not deduplicated:
return "No documents found in the knowledge base. The search space has no indexed content yet."
# Use browse chunk cap for degenerate queries, otherwise adaptive chunking.
max_chunks_per_doc = (
_BROWSE_MAX_CHUNKS_PER_DOC if _is_degenerate_query(query) else 0
)
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(
deduplicated,
max_chars=output_budget,
max_chunks_per_doc=max_chunks_per_doc,
)
if len(result) > output_budget:
perf.warning(
"[kb_search] output STILL exceeds budget after format (%d > %d), "
"hard truncation should have fired",
len(result),
output_budget,
)
perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
"budget=%d max_input_tokens=%s space=%d",
time.perf_counter() - t0,
len(deduplicated),
len(deduplicated),
len(result),
output_budget,
max_input_tokens,
search_space_id,
)
return result
async def search_knowledge_base_raw_async(
query: str,
search_space_id: int,
db_session: AsyncSession,
connector_service: ConnectorService,
connectors_to_search: list[str] | None = None,
top_k: int = 10,
start_date: datetime | None = None,
end_date: datetime | None = None,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
query_embedding: list[float] | None = None,
) -> list[dict[str, Any]]:
"""Search knowledge base and return raw document dicts (no XML formatting)."""
perf = get_perf_logger()
t0 = time.perf_counter()
all_documents: list[dict[str, Any]] = []
# Resolve date range (default last 2 years)
# Preserve the public signature for compatibility even if values are unused.
_ = (db_session, connector_service)
from app.agents.new_chat.utils import resolve_date_range
resolved_start_date, resolved_end_date = resolve_date_range(
@ -631,144 +695,76 @@ async def search_knowledge_base_async(
connectors = _normalize_connectors(connectors_to_search, available_connectors)
# --- Optimization 1: skip connectors that have zero indexed documents ---
if available_document_types:
doc_types_set = set(available_document_types)
before_count = len(connectors)
connectors = [
c
for c in connectors
if c in doc_types_set
or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
]
skipped = before_count - len(connectors)
if skipped:
perf.info(
"[kb_search] skipped %d empty connectors (had %d, now %d)",
skipped,
before_count,
len(connectors),
)
perf.info(
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
len(connectors),
connectors[:5],
search_space_id,
top_k,
)
# --- Fast-path: no connectors left after filtering ---
if not connectors:
perf.info(
"[kb_search] TOTAL in %.3fs — no connectors to search, returning empty",
time.perf_counter() - t0,
)
return "No documents found in the knowledge base. The search space has no indexed content yet."
return []
# --- Fast-path: degenerate queries (*, **, empty, etc.) ---
# Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
# yields an empty tsquery, so both retrieval signals are useless.
# Fall back to a recency-ordered browse that returns diverse results.
if _is_degenerate_query(query):
perf.info(
"[kb_search] degenerate query %r detected - falling back to recency browse",
"[kb_search_raw] degenerate query %r detected - recency browse",
query,
)
browse_connectors = connectors if connectors else [None] # type: ignore[list-item]
expanded_browse = []
for c in browse_connectors:
if c is not None and c in NATIVE_TO_LEGACY_DOCTYPE:
expanded_browse.append([c, NATIVE_TO_LEGACY_DOCTYPE[c]])
for connector in browse_connectors:
if connector is not None and connector in NATIVE_TO_LEGACY_DOCTYPE:
expanded_browse.append([connector, NATIVE_TO_LEGACY_DOCTYPE[connector]])
else:
expanded_browse.append(c)
expanded_browse.append(connector)
browse_results = await asyncio.gather(
*[
_browse_recent_documents(
search_space_id=search_space_id,
document_type=c,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
for c in expanded_browse
]
)
for docs in browse_results:
all_documents.extend(docs)
# Skip dedup + formatting below (browse already returns unique docs)
# but still cap output budget.
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(
all_documents,
max_chars=output_budget,
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
)
perf.info(
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
"budget=%d space=%d",
time.perf_counter() - t0,
len(all_documents),
len(result),
output_budget,
search_space_id,
)
return result
# --- Optimization 2: compute the query embedding once, share across all local searches ---
from app.config import config as app_config
t_embed = time.perf_counter()
precomputed_embedding = app_config.embedding_model_instance.embed(query)
perf.info(
"[kb_search] shared embedding computed in %.3fs",
time.perf_counter() - t_embed,
)
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
try:
t_conn = time.perf_counter()
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
chunks = await svc._combined_rrf_search(
query_text=query,
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
query_embedding=precomputed_embedding,
)
perf.info(
"[kb_search] connector=%s results=%d in %.3fs",
connector,
len(chunks),
time.perf_counter() - t_conn,
)
return chunks
except Exception as e:
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
return []
for connector in expanded_browse
]
)
for docs in browse_results:
all_documents.extend(docs)
else:
if query_embedding is None:
from app.config import config as app_config
t_gather = time.perf_counter()
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
perf.info(
"[kb_search] all connectors gathered in %.3fs",
time.perf_counter() - t_gather,
)
for chunks in connector_results:
all_documents.extend(chunks)
query_embedding = app_config.embedding_model_instance.embed(query)
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
try:
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
return await svc._combined_rrf_search(
query_text=query,
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
query_embedding=query_embedding,
)
except Exception as exc:
perf.warning("[kb_search_raw] connector=%s FAILED: %s", connector, exc)
return []
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
for docs in connector_results:
all_documents.extend(docs)
# Deduplicate primarily by document ID. Only fall back to content hashing
# when a document has no ID.
seen_doc_ids: set[Any] = set()
seen_content_hashes: set[int] = set()
deduplicated: list[dict[str, Any]] = []
@ -785,7 +781,6 @@ async def search_knowledge_base_async(
chunk_texts.append(chunk_content)
if chunk_texts:
return hash("||".join(chunk_texts))
flat_content = (document.get("content") or "").strip()
if flat_content:
return hash(flat_content)
@ -793,216 +788,24 @@ async def search_knowledge_base_async(
for doc in all_documents:
doc_id = (doc.get("document", {}) or {}).get("id")
if doc_id is not None:
if doc_id in seen_doc_ids:
continue
seen_doc_ids.add(doc_id)
deduplicated.append(doc)
continue
content_hash = _content_fingerprint(doc)
if content_hash is not None and content_hash in seen_content_hashes:
continue
if content_hash is not None:
if content_hash in seen_content_hashes:
continue
seen_content_hashes.add(content_hash)
deduplicated.append(doc)
# Sort by RRF score so the most relevant documents from ANY connector
# appear first, preventing budget truncation from hiding top results.
deduplicated.sort(key=lambda d: d.get("score", 0), reverse=True)
output_budget = _compute_tool_output_budget(max_input_tokens)
result = format_documents_for_context(deduplicated, max_chars=output_budget)
if len(result) > output_budget:
perf.warning(
"[kb_search] output STILL exceeds budget after format (%d > %d), "
"hard truncation should have fired",
len(result),
output_budget,
)
deduplicated.sort(key=lambda doc: doc.get("score", 0), reverse=True)
perf.info(
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
"budget=%d max_input_tokens=%s space=%d",
"[kb_search_raw] done in %.3fs total=%d deduped=%d",
time.perf_counter() - t0,
len(all_documents),
len(deduplicated),
len(result),
output_budget,
max_input_tokens,
search_space_id,
)
return result
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
"""
Build the connector documentation section for the tool docstring.
Args:
available_connectors: List of available connector types, or None for all
Returns:
Formatted docstring section listing available connectors
"""
connectors = available_connectors if available_connectors else list(_ALL_CONNECTORS)
lines = []
for connector in connectors:
# Skip internal names, prefer user-facing aliases
if connector == "CRAWLED_URL":
# Show as WEBCRAWLER_CONNECTOR for user-facing docs
description = CONNECTOR_DESCRIPTIONS.get(connector, connector)
lines.append(f"- WEBCRAWLER_CONNECTOR: {description}")
else:
description = CONNECTOR_DESCRIPTIONS.get(connector, connector)
lines.append(f"- {connector}: {description}")
return "\n".join(lines)
# =============================================================================
# Tool Input Schema
# =============================================================================
class SearchKnowledgeBaseInput(BaseModel):
"""Input schema for the search_knowledge_base tool."""
query: str = Field(
description=(
"The search query - use specific natural language terms. "
"NEVER use wildcards like '*' or '**'; instead describe what you want "
"(e.g. 'recent meeting notes' or 'project architecture overview')."
),
)
top_k: int = Field(
default=10,
description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
)
start_date: str | None = Field(
default=None,
description="Optional ISO date/datetime (e.g. '2025-12-12' or '2025-12-12T00:00:00+00:00')",
)
end_date: str | None = Field(
default=None,
description="Optional ISO date/datetime (e.g. '2025-12-19' or '2025-12-19T23:59:59+00:00')",
)
connectors_to_search: list[str] | None = Field(
default=None,
description="Optional list of connector enums to search. If omitted, searches all available.",
)
def create_search_knowledge_base_tool(
search_space_id: int,
db_session: AsyncSession,
connector_service: ConnectorService,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
max_input_tokens: int | None = None,
) -> StructuredTool:
"""
Factory function to create the search_knowledge_base tool with injected dependencies.
Args:
search_space_id: The user's search space ID
db_session: Database session
connector_service: Initialized connector service
available_connectors: Optional list of connector types available in the search space.
Used to dynamically generate the tool docstring.
available_document_types: Optional list of document types that have data in the search space.
Used to inform the LLM about what data exists.
max_input_tokens: Model context window (tokens) from litellm model info.
Used to dynamically size tool output.
Returns:
A configured StructuredTool instance
"""
# Build connector documentation dynamically
connector_docs = _build_connector_docstring(available_connectors)
# Build context about available document types
doc_types_info = ""
if available_document_types:
doc_types_info = f"""
## Document types with indexed content in this search space
The following document types have content available for search:
{", ".join(available_document_types)}
Focus searches on these types for best results."""
# Build the dynamic description for the tool
# This is what the LLM sees when deciding whether/how to use the tool
dynamic_description = f"""Search the user's personal knowledge base for relevant information.
Use this tool to find documents, notes, files, web pages, and other content the user has indexed.
This searches ONLY local/indexed data (uploaded files, Notion, Slack, browser extension captures, etc.).
For real-time web search (current events, news, live data), use the `web_search` tool instead.
IMPORTANT:
- Always craft specific, descriptive search queries using natural language keywords.
Good: "quarterly sales report Q3", "Python API authentication design".
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
- Prefer multiple focused searches over a single broad one with high top_k.
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
- If `connectors_to_search` is omitted/empty, the system will search broadly.
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
## Available connector enums for `connectors_to_search`
{connector_docs}
NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type `CRAWLED_URL`."""
# Capture for closure
_available_connectors = available_connectors
_available_document_types = available_document_types
async def _search_knowledge_base_impl(
query: str,
top_k: int = 10,
start_date: str | None = None,
end_date: str | None = None,
connectors_to_search: list[str] | None = None,
) -> str:
"""Implementation function for knowledge base search."""
from app.agents.new_chat.utils import parse_date_or_datetime
parsed_start: datetime | None = None
parsed_end: datetime | None = None
if start_date:
parsed_start = parse_date_or_datetime(start_date)
if end_date:
parsed_end = parse_date_or_datetime(end_date)
return await search_knowledge_base_async(
query=query,
search_space_id=search_space_id,
db_session=db_session,
connector_service=connector_service,
connectors_to_search=connectors_to_search,
top_k=top_k,
start_date=parsed_start,
end_date=parsed_end,
available_connectors=_available_connectors,
available_document_types=_available_document_types,
max_input_tokens=max_input_tokens,
)
# Create StructuredTool with dynamic description
# This properly sets the description that the LLM sees
tool = StructuredTool(
name="search_knowledge_base",
description=dynamic_description,
coroutine=_search_knowledge_base_impl,
args_schema=SearchKnowledgeBaseInput,
)
return tool
return deduplicated

View file

@ -71,7 +71,6 @@ from .jira import (
create_delete_jira_issue_tool,
create_update_jira_issue_tool,
)
from .knowledge_base import create_search_knowledge_base_tool
from .linear import (
create_create_linear_issue_tool,
create_delete_linear_issue_tool,
@ -128,23 +127,6 @@ class ToolDefinition:
# Registry of all built-in tools
# Contributors: Add your new tools here!
BUILTIN_TOOLS: list[ToolDefinition] = [
# Core tool - searches the user's knowledge base
# Now supports dynamic connector/document type discovery
ToolDefinition(
name="search_knowledge_base",
description="Search the user's personal knowledge base for relevant information",
factory=lambda deps: create_search_knowledge_base_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
connector_service=deps["connector_service"],
# Optional: dynamically discovered connectors/document types
available_connectors=deps.get("available_connectors"),
available_document_types=deps.get("available_document_types"),
max_input_tokens=deps.get("max_input_tokens"),
),
requires=["search_space_id", "db_session", "connector_service"],
# Note: available_connectors and available_document_types are optional
),
# Podcast generation tool
ToolDefinition(
name="generate_podcast",
@ -168,8 +150,8 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["search_space_id", "db_session", "thread_id"],
),
# Report generation tool (inline, short-lived sessions for DB ops)
# Supports internal KB search via source_strategy so the agent doesn't
# need to call search_knowledge_base separately before generating.
# Supports internal KB search via source_strategy so the agent does not
# need a separate search step before generating.
ToolDefinition(
name="generate_report",
description="Generate a structured report from provided content and export it",
@ -551,7 +533,7 @@ def build_tools(
tools = build_tools(deps)
# Use only specific tools
tools = build_tools(deps, enabled_tools=["search_knowledge_base"])
tools = build_tools(deps, enabled_tools=["generate_report"])
# Use defaults but disable podcast
tools = build_tools(deps, disabled_tools=["generate_podcast"])

View file

@ -584,8 +584,8 @@ def create_generate_report_tool(
search_space_id: The user's search space ID
thread_id: The chat thread ID for associating the report
connector_service: Optional connector service for internal KB search.
When provided, the tool can search the knowledge base without the
agent having to call search_knowledge_base separately.
When provided, the tool can search the knowledge base internally
(used by the "kb_search" and "auto" source strategies).
available_connectors: Optional list of connector types available in the
search space (used to scope internal KB searches).
@ -639,12 +639,13 @@ def create_generate_report_tool(
SOURCE STRATEGY (how to collect source material):
- source_strategy="conversation" The conversation already has
enough context (prior Q&A, pasted text, uploaded files, scraped
webpages). Pass a thorough summary as source_content.
NEVER call search_knowledge_base separately first.
enough context (prior Q&A, filesystem exploration, pasted text,
uploaded files, scraped webpages). Pass a thorough summary as
source_content.
- source_strategy="kb_search" Search the knowledge base
internally. Provide 1-5 targeted search_queries. The tool
handles searching do NOT call search_knowledge_base first.
handles searching internally do NOT manually read and dump
/documents/ files into source_content.
- source_strategy="provided" Use only what is in source_content
(default, backward-compatible).
- source_strategy="auto" Use source_content if it has enough
@ -1064,6 +1065,7 @@ def create_generate_report_tool(
"title": topic,
"word_count": metadata.get("word_count", 0),
"is_revision": bool(parent_report_content),
"report_markdown": report_content,
"message": f"Report generated successfully: {topic}",
}

View file

@ -137,9 +137,7 @@ async def _filter_changes_by_folder(
continue
parents = file.get("parents", [])
if folder_id in parents:
filtered.append(change)
elif await _is_descendant_of(client, parents, folder_id):
if folder_id in parents or await _is_descendant_of(client, parents, folder_id):
filtered.append(change)
return filtered

View file

@ -157,7 +157,9 @@ class GoogleDriveClient:
@staticmethod
def _sync_download_file(
service, file_id: str, credentials: Credentials,
service,
file_id: str,
credentials: Credentials,
) -> tuple[bytes | None, str | None]:
"""Blocking download — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
@ -180,7 +182,9 @@ class GoogleDriveClient:
except Exception as e:
return None, f"Error downloading file: {e!s}"
finally:
logger.info(f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
logger.info(
f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
)
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
"""
@ -194,12 +198,18 @@ class GoogleDriveClient:
"""
service = await self.get_service()
return await asyncio.to_thread(
self._sync_download_file, service, file_id, self._resolved_credentials,
self._sync_download_file,
service,
file_id,
self._resolved_credentials,
)
@staticmethod
def _sync_download_file_to_disk(
service, file_id: str, dest_path: str, chunksize: int,
service,
file_id: str,
dest_path: str,
chunksize: int,
credentials: Credentials,
) -> str | None:
"""Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
@ -223,10 +233,15 @@ class GoogleDriveClient:
except Exception as e:
return f"Error downloading file: {e!s}"
finally:
logger.info(f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
logger.info(
f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
)
async def download_file_to_disk(
self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024,
self,
file_id: str,
dest_path: str,
chunksize: int = 5 * 1024 * 1024,
) -> str | None:
"""Stream file directly to disk in chunks, avoiding full in-memory buffering.
@ -235,12 +250,19 @@ class GoogleDriveClient:
service = await self.get_service()
return await asyncio.to_thread(
self._sync_download_file_to_disk,
service, file_id, dest_path, chunksize, self._resolved_credentials,
service,
file_id,
dest_path,
chunksize,
self._resolved_credentials,
)
@staticmethod
def _sync_export_google_file(
service, file_id: str, mime_type: str, credentials: Credentials,
service,
file_id: str,
mime_type: str,
credentials: Credentials,
) -> tuple[bytes | None, str | None]:
"""Blocking export — runs on a worker thread via ``to_thread``."""
thread = threading.current_thread().name
@ -261,7 +283,9 @@ class GoogleDriveClient:
except Exception as e:
return None, f"Error exporting file: {e!s}"
finally:
logger.info(f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
logger.info(
f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
)
async def export_google_file(
self, file_id: str, mime_type: str
@ -278,7 +302,10 @@ class GoogleDriveClient:
"""
service = await self.get_service()
return await asyncio.to_thread(
self._sync_export_google_file, service, file_id, mime_type,
self._sync_export_google_file,
service,
file_id,
mime_type,
self._resolved_credentials,
)

View file

@ -1,6 +1,7 @@
"""Content extraction for Google Drive files."""
import asyncio
import contextlib
import logging
import os
import tempfile
@ -72,7 +73,11 @@ async def download_and_extract_content(
if is_google_workspace_file(mime_type):
export_mime = get_export_mime_type(mime_type)
if not export_mime:
return None, drive_metadata, f"Cannot export Google Workspace type: {mime_type}"
return (
None,
drive_metadata,
f"Cannot export Google Workspace type: {mime_type}",
)
content_bytes, error = await client.export_google_file(file_id, export_mime)
if error:
return None, drive_metadata, error
@ -83,9 +88,7 @@ async def download_and_extract_content(
temp_file_path = tmp.name
else:
extension = (
Path(file_name).suffix
or get_extension_from_mime(mime_type)
or ".bin"
Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
)
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
temp_file_path = tmp.name
@ -102,10 +105,8 @@ async def download_and_extract_content(
return None, drive_metadata, str(e)
finally:
if temp_file_path and os.path.exists(temp_file_path):
try:
with contextlib.suppress(Exception):
os.unlink(temp_file_path)
except Exception:
pass
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
@ -117,9 +118,10 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
return f.read()
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
from app.config import config as app_config
from litellm import atranscription
from app.config import config as app_config
stt_service_type = (
"local"
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
@ -127,10 +129,15 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
t0 = time.monotonic()
logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}")
logger.info(
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
)
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
logger.info(
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
)
text = result.get("text", "")
else:
with open(file_path, "rb") as audio_file:
@ -153,6 +160,7 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
if app_config.ETL_SERVICE == "UNSTRUCTURED":
from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown
loader = UnstructuredLoader(
@ -172,7 +180,9 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
parse_with_llamacloud_retry,
)
result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50)
result = await parse_with_llamacloud_retry(
file_path=file_path, estimated_pages=50
)
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
if not markdown_documents:
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
@ -183,9 +193,13 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
converter = DocumentConverter()
t0 = time.monotonic()
logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}")
logger.info(
f"[docling] START file={filename} thread={threading.current_thread().name}"
)
result = await asyncio.to_thread(converter.convert, file_path)
logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
logger.info(
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
)
return result.document.export_to_markdown()
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
@ -249,9 +263,7 @@ async def download_and_process_file(
return None, error
extension = (
Path(file_name).suffix
or get_extension_from_mime(mime_type)
or ".bin"
Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
)
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
@ -290,9 +302,7 @@ async def download_and_process_file(
connector_info["metadata"]["md5_checksum"] = file["md5Checksum"]
if is_google_workspace_file(mime_type):
export_ext = get_extension_from_mime(
get_export_mime_type(mime_type) or ""
)
export_ext = get_extension_from_mime(get_export_mime_type(mime_type) or "")
connector_info["metadata"]["exported_as"] = (
export_ext.lstrip(".") if export_ext else "pdf"
)

View file

@ -13,7 +13,9 @@ def compute_identifier_hash(
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
"""Return a stable SHA-256 hash identifying a document by its source identity."""
return compute_identifier_hash(doc.document_type.value, doc.unique_id, doc.search_space_id)
return compute_identifier_hash(
doc.document_type.value, doc.unique_id, doc.search_space_id
)
def compute_content_hash(doc: ConnectorDocument) -> str:

View file

@ -1,15 +1,23 @@
import asyncio
import contextlib
import hashlib
import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import UTC, datetime
from sqlalchemy import delete, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, DocumentStatus
from app.db import (
NATIVE_TO_LEGACY_DOCTYPE,
Chunk,
Document,
DocumentStatus,
DocumentType,
)
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_chunker import chunk_text
from app.indexing_pipeline.document_embedder import embed_texts
@ -52,12 +60,114 @@ from app.indexing_pipeline.pipeline_logger import (
from app.utils.perf import get_perf_logger
@dataclass
class PlaceholderInfo:
"""Minimal info to create a placeholder document row for instant UI feedback.
These are created immediately when items are discovered (before content
extraction) so users see them in the UI via Zero sync right away.
"""
title: str
document_type: DocumentType
unique_id: str
search_space_id: int
connector_id: int | None
created_by_id: str
metadata: dict = field(default_factory=dict)
class IndexingPipelineService:
"""Single pipeline for indexing connector documents. All connectors use this service."""
def __init__(self, session: AsyncSession) -> None:
self.session = session
async def create_placeholder_documents(
self, placeholders: list[PlaceholderInfo]
) -> int:
"""Create placeholder document rows with pending status for instant UI feedback.
These rows appear immediately in the UI via Zero sync. They are later
updated by prepare_for_indexing() when actual content is available.
Returns the number of placeholders successfully created.
Failures are logged but never block the main indexing flow.
NOTE: This method commits on ``self.session`` so the rows become
visible to Zero sync immediately. Any pending ORM mutations on the
session are committed together, which is consistent with how other
mid-flow commits work in the indexing codebase (e.g. rename-only
updates in ``_should_skip_file``, ``migrate_legacy_docs``).
"""
if not placeholders:
return 0
_logger = logging.getLogger(__name__)
uid_hashes: dict[str, PlaceholderInfo] = {}
for p in placeholders:
try:
uid_hash = compute_identifier_hash(
p.document_type.value, p.unique_id, p.search_space_id
)
uid_hashes.setdefault(uid_hash, p)
except Exception:
_logger.debug(
"Skipping placeholder hash for %s", p.unique_id, exc_info=True
)
if not uid_hashes:
return 0
result = await self.session.execute(
select(Document.unique_identifier_hash).where(
Document.unique_identifier_hash.in_(list(uid_hashes.keys()))
)
)
existing_hashes: set[str] = set(result.scalars().all())
created = 0
for uid_hash, p in uid_hashes.items():
if uid_hash in existing_hashes:
continue
try:
content_hash = hashlib.sha256(
f"placeholder:{uid_hash}".encode()
).hexdigest()
document = Document(
title=p.title,
document_type=p.document_type,
content="Pending...",
content_hash=content_hash,
unique_identifier_hash=uid_hash,
document_metadata=p.metadata or {},
search_space_id=p.search_space_id,
connector_id=p.connector_id,
created_by_id=p.created_by_id,
updated_at=datetime.now(UTC),
status=DocumentStatus.pending(),
)
self.session.add(document)
created += 1
except Exception:
_logger.debug("Skipping placeholder for %s", p.unique_id, exc_info=True)
if created > 0:
try:
await self.session.commit()
_logger.info(
"Created %d placeholder document(s) for instant UI feedback",
created,
)
except IntegrityError:
await self.session.rollback()
_logger.debug("Placeholder commit failed (race condition), continuing")
created = 0
return created
async def migrate_legacy_docs(
self, connector_docs: list[ConnectorDocument]
) -> None:
@ -77,9 +187,7 @@ class IndexingPipelineService:
legacy_type, doc.unique_id, doc.search_space_id
)
result = await self.session.execute(
select(Document).filter(
Document.unique_identifier_hash == legacy_hash
)
select(Document).filter(Document.unique_identifier_hash == legacy_hash)
)
existing = result.scalars().first()
if existing is None:
@ -101,9 +209,7 @@ class IndexingPipelineService:
Indexers that need heartbeat callbacks or custom per-document logic
should call prepare_for_indexing() + index() directly instead.
"""
doc_map = {
compute_unique_identifier_hash(cd): cd for cd in connector_docs
}
doc_map = {compute_unique_identifier_hash(cd): cd for cd in connector_docs}
documents = await self.prepare_for_indexing(connector_docs)
results: list[Document] = []
for document in documents:
@ -166,6 +272,21 @@ class IndexingPipelineService:
log_document_requeued(ctx)
continue
dup_check = await self.session.execute(
select(Document.id).filter(
Document.content_hash == content_hash,
Document.id != existing.id,
)
)
if dup_check.scalars().first() is not None:
if not DocumentStatus.is_state(
existing.status, DocumentStatus.READY
):
existing.status = DocumentStatus.failed(
"Duplicate content — already indexed by another document"
)
continue
existing.title = connector_doc.title
existing.content_hash = content_hash
existing.source_markdown = connector_doc.source_markdown
@ -349,9 +470,7 @@ class IndexingPipelineService:
perf = get_perf_logger()
t_total = time.perf_counter()
doc_map = {
compute_unique_identifier_hash(cd): cd for cd in connector_docs
}
doc_map = {compute_unique_identifier_hash(cd): cd for cd in connector_docs}
documents = await self.prepare_for_indexing(connector_docs)
if not documents:
@ -383,9 +502,7 @@ class IndexingPipelineService:
session_maker = get_celery_session_maker()
async with session_maker() as isolated_session:
try:
refetched = await isolated_session.get(
Document, document.id
)
refetched = await isolated_session.get(Document, document.id)
if refetched is None:
async with lock:
failed_count += 1
@ -393,9 +510,7 @@ class IndexingPipelineService:
llm = await get_llm(isolated_session)
iso_pipeline = IndexingPipelineService(isolated_session)
result = await iso_pipeline.index(
refetched, connector_doc, llm
)
result = await iso_pipeline.index(refetched, connector_doc, llm)
async with lock:
if DocumentStatus.is_state(

View file

@ -5,7 +5,7 @@ from datetime import datetime
from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
_MAX_FETCH_CHUNKS_PER_DOC = 20
class ChucksHybridSearchRetriever:
@ -185,7 +185,7 @@ class ChucksHybridSearchRetriever:
- chunks: list[{chunk_id, content}] for citation-aware prompting
- document: {id, title, document_type, metadata}
"""
from sqlalchemy import func, select, text
from sqlalchemy import func, or_, select, text
from sqlalchemy.orm import joinedload
from app.config import config
@ -360,64 +360,81 @@ class ChucksHybridSearchRetriever:
if not doc_ids:
return []
# Fetch chunks for selected documents. We cap per document to avoid
# loading hundreds of chunks for a single large file while still
# ensuring the chunks that matched the RRF query are always included.
chunk_query = (
select(Chunk)
.options(joinedload(Chunk.document))
.join(Document, Chunk.document_id == Document.id)
.where(Document.id.in_(doc_ids))
.where(*base_conditions)
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunk_query)
raw_chunks = chunks_result.scalars().all()
# Collect document metadata from the small RRF result set (already
# loaded via joinedload) so the bulk chunk fetch can skip the expensive
# Document JOIN entirely.
matched_chunk_ids: set[int] = {
item["chunk_id"] for item in serialized_chunk_results
}
doc_meta_cache: dict[int, dict] = {}
for item in serialized_chunk_results:
did = item["document"]["id"]
if did not in doc_meta_cache:
doc_meta_cache[did] = item["document"]
doc_chunk_counts: dict[int, int] = {}
all_chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
all_chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# SQL-level per-document chunk limit using ROW_NUMBER().
# Avoids loading hundreds of chunks per large document only to
# discard them in Python.
numbered = (
select(
Chunk.id.label("chunk_id"),
func.row_number()
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
.label("rn"),
)
.where(Chunk.document_id.in_(doc_ids))
.subquery("numbered")
)
# Assemble final doc-grouped results in the same order as doc_ids
matched_list = list(matched_chunk_ids)
if matched_list:
chunk_filter = or_(
numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC,
Chunk.id.in_(matched_list),
)
else:
chunk_filter = numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC
# Select only the columns we need (skip Chunk.embedding ~12KB/row).
chunk_query = (
select(Chunk.id, Chunk.content, Chunk.document_id)
.join(numbered, Chunk.id == numbered.c.chunk_id)
.where(chunk_filter)
.order_by(Chunk.document_id, Chunk.id)
)
t_fetch = time.perf_counter()
chunks_result = await self.db_session.execute(chunk_query)
fetched_chunks = chunks_result.all()
perf.debug(
"[chunk_search] chunk fetch in %.3fs rows=%d",
time.perf_counter() - t_fetch,
len(fetched_chunks),
)
# Assemble final doc-grouped results in the same order as doc_ids,
# using pre-cached doc metadata instead of joinedload.
doc_map: dict[int, dict] = {
doc_id: {
"document_id": doc_id,
"content": "",
"score": float(doc_scores.get(doc_id, 0.0)),
"chunks": [],
"document": {},
"source": None,
"matched_chunk_ids": [],
"document": doc_meta_cache.get(doc_id, {}),
"source": (doc_meta_cache.get(doc_id) or {}).get("document_type"),
}
for doc_id in doc_ids
}
for chunk in all_chunks:
doc = chunk.document
doc_id = doc.id
for row in fetched_chunks:
doc_id = row.document_id
if doc_id not in doc_map:
continue
doc_entry = doc_map[doc_id]
doc_entry["document"] = {
"id": doc.id,
"title": doc.title,
"document_type": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
"metadata": doc.document_metadata or {},
}
doc_entry["source"] = (
doc.document_type.value if getattr(doc, "document_type", None) else None
)
doc_entry["chunks"].append({"chunk_id": chunk.id, "content": chunk.content})
doc_entry["chunks"].append({"chunk_id": row.id, "content": row.content})
if row.id in matched_chunk_ids:
doc_entry["matched_chunk_ids"].append(row.id)
# Fill concatenated content (useful for reranking)
final_docs: list[dict] = []

View file

@ -4,7 +4,7 @@ from datetime import datetime
from app.utils.perf import get_perf_logger
_MAX_FETCH_CHUNKS_PER_DOC = 30
_MAX_FETCH_CHUNKS_PER_DOC = 20
class DocumentHybridSearchRetriever:
@ -289,57 +289,77 @@ class DocumentHybridSearchRetriever:
if not documents_with_scores:
return []
# Collect document IDs for chunk fetching
# Collect document IDs and pre-cache metadata from the small RRF
# result set so the bulk chunk fetch can skip joinedload entirely.
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
# Fetch chunks for these documents, capped per document to avoid
# loading hundreds of chunks for a single large file.
chunks_query = (
select(Chunk)
.options(joinedload(Chunk.document))
.where(Chunk.document_id.in_(doc_ids))
.order_by(Chunk.document_id, Chunk.id)
)
chunks_result = await self.db_session.execute(chunks_query)
raw_chunks = chunks_result.scalars().all()
doc_chunk_counts: dict[int, int] = {}
chunks: list = []
for chunk in raw_chunks:
did = chunk.document_id
count = doc_chunk_counts.get(did, 0)
if count < _MAX_FETCH_CHUNKS_PER_DOC:
chunks.append(chunk)
doc_chunk_counts[did] = count + 1
# Assemble doc-grouped results
doc_map: dict[int, dict] = {
doc.id: {
"document_id": doc.id,
"content": "",
"score": float(score),
"chunks": [],
"document": {
"id": doc.id,
"title": doc.title,
"document_type": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
"metadata": doc.document_metadata or {},
},
"source": doc.document_type.value
doc_meta_cache: dict[int, dict] = {}
doc_score_cache: dict[int, float] = {}
doc_source_cache: dict[int, str | None] = {}
for doc, score in documents_with_scores:
doc_meta_cache[doc.id] = {
"id": doc.id,
"title": doc.title,
"document_type": doc.document_type.value
if getattr(doc, "document_type", None)
else None,
"metadata": doc.document_metadata or {},
}
for doc, score in documents_with_scores
doc_score_cache[doc.id] = float(score)
doc_source_cache[doc.id] = (
doc.document_type.value if getattr(doc, "document_type", None) else None
)
# SQL-level per-document chunk limit using ROW_NUMBER().
# Avoids loading hundreds of chunks per large document only to
# discard them in Python.
numbered = (
select(
Chunk.id.label("chunk_id"),
func.row_number()
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
.label("rn"),
)
.where(Chunk.document_id.in_(doc_ids))
.subquery("numbered")
)
# Select only the columns we need (skip Chunk.embedding ~12KB/row).
chunks_query = (
select(Chunk.id, Chunk.content, Chunk.document_id)
.join(numbered, Chunk.id == numbered.c.chunk_id)
.where(numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC)
.order_by(Chunk.document_id, Chunk.id)
)
t_fetch = time.perf_counter()
chunks_result = await self.db_session.execute(chunks_query)
fetched_chunks = chunks_result.all()
perf.debug(
"[doc_search] chunk fetch in %.3fs rows=%d",
time.perf_counter() - t_fetch,
len(fetched_chunks),
)
# Assemble doc-grouped results using pre-cached metadata.
doc_map: dict[int, dict] = {
doc_id: {
"document_id": doc_id,
"content": "",
"score": doc_score_cache.get(doc_id, 0.0),
"chunks": [],
"matched_chunk_ids": [],
"document": doc_meta_cache.get(doc_id, {}),
"source": doc_source_cache.get(doc_id),
}
for doc_id in doc_ids
}
for chunk in chunks:
doc_id = chunk.document_id
for row in fetched_chunks:
doc_id = row.document_id
if doc_id not in doc_map:
continue
doc_map[doc_id]["chunks"].append(
{"chunk_id": chunk.id, "content": chunk.content}
{"chunk_id": row.id, "content": row.content}
)
# Fill concatenated content (useful for reranking)

View file

@ -7,7 +7,6 @@ import asyncio
import io
import logging
import os
import re
import tempfile
from datetime import UTC, datetime
from typing import Any
@ -22,9 +21,9 @@ from sqlalchemy.orm import selectinload
from app.db import Document, DocumentType, Permission, User, get_async_session
from app.routes.reports_routes import (
ExportFormat,
_FILE_EXTENSIONS,
_MEDIA_TYPES,
ExportFormat,
_normalize_latex_delimiters,
_strip_wrapping_code_fences,
)
@ -238,9 +237,7 @@ async def save_document(
}
@router.get(
"/search-spaces/{search_space_id}/documents/{document_id}/export"
)
@router.get("/search-spaces/{search_space_id}/documents/{document_id}/export")
async def export_document(
search_space_id: int,
document_id: int,
@ -284,9 +281,7 @@ async def export_document(
markdown_content = "\n\n".join(chunk.content for chunk in chunks)
if not markdown_content or not markdown_content.strip():
raise HTTPException(
status_code=400, detail="Document has no content to export"
)
raise HTTPException(status_code=400, detail="Document has no content to export")
markdown_content = _strip_wrapping_code_fences(markdown_content)
markdown_content = _normalize_latex_delimiters(markdown_content)
@ -308,8 +303,10 @@ async def export_document(
extra_args=[
"--standalone",
f"--template={typst_template}",
"-V", "mainfont:Libertinus Serif",
"-V", "codefont:DejaVu Sans Mono",
"-V",
"mainfont:Libertinus Serif",
"-V",
"codefont:DejaVu Sans Mono",
*meta_args,
],
)
@ -318,7 +315,11 @@ async def export_document(
if format == ExportFormat.DOCX:
return _pandoc_to_tempfile(
format.value,
["--standalone", f"--reference-doc={get_reference_docx_path()}", *meta_args],
[
"--standalone",
f"--reference-doc={get_reference_docx_path()}",
*meta_args,
],
)
if format == ExportFormat.HTML:
@ -327,7 +328,8 @@ async def export_document(
"html5",
format=input_fmt,
extra_args=[
"--standalone", "--embed-resources",
"--standalone",
"--embed-resources",
f"--css={get_html_css_path()}",
"--syntax-highlighting=pygments",
*meta_args,
@ -343,13 +345,17 @@ async def export_document(
if format == ExportFormat.LATEX:
tex_str: str = pypandoc.convert_text(
markdown_content, "latex", format=input_fmt,
markdown_content,
"latex",
format=input_fmt,
extra_args=["--standalone", *meta_args],
)
return tex_str.encode("utf-8")
plain_str: str = pypandoc.convert_text(
markdown_content, "plain", format=input_fmt,
markdown_content,
"plain",
format=input_fmt,
extra_args=["--wrap=auto", "--columns=80"],
)
return plain_str.encode("utf-8")
@ -359,8 +365,11 @@ async def export_document(
os.close(fd)
try:
pypandoc.convert_text(
markdown_content, output_format, format=input_fmt,
extra_args=extra_args, outputfile=tmp_path,
markdown_content,
output_format,
format=input_fmt,
extra_args=extra_args,
outputfile=tmp_path,
)
with open(tmp_path, "rb") as f:
return f.read()
@ -375,8 +384,7 @@ async def export_document(
raise HTTPException(status_code=500, detail=f"Export failed: {e!s}") from e
safe_title = (
"".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title)
.strip()[:80]
"".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title).strip()[:80]
or "document"
)
ext = _FILE_EXTENSIONS[format]

View file

@ -2406,7 +2406,11 @@ async def run_google_drive_indexing(
if items.files:
try:
file_tuples = [(f.id, f.name) for f in items.files]
indexed_count, _skipped, file_errors = await index_google_drive_selected_files(
(
indexed_count,
_skipped,
file_errors,
) = await index_google_drive_selected_files(
session,
connector_id,
search_space_id,

View file

@ -9,6 +9,7 @@ Supports loading LLM configurations from:
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
"""
import ast
import asyncio
import contextlib
import gc
@ -36,10 +37,6 @@ from app.agents.new_chat.llm_config import (
load_agent_config,
load_llm_config_from_yaml,
)
from app.agents.new_chat.sandbox import (
get_or_create_sandbox,
is_sandbox_enabled,
)
from app.db import (
ChatVisibility,
Document,
@ -212,7 +209,7 @@ class StreamResult:
accumulated_text: str = ""
is_interrupted: bool = False
interrupt_value: dict[str, Any] | None = None
sandbox_files: list[str] = field(default_factory=list)
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
async def _stream_agent_events(
@ -281,6 +278,8 @@ async def _stream_agent_events(
if event_type == "on_chat_model_stream":
if active_tool_depth > 0:
continue # Suppress inner-tool LLM tokens from leaking into chat
if "surfsense:internal" in event.get("tags", []):
continue # Suppress middleware-internal LLM tokens (e.g. KB search classification)
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
content = chunk.content
@ -319,19 +318,114 @@ async def _stream_agent_events(
tool_step_ids[run_id] = tool_step_id
last_active_step_id = tool_step_id
if tool_name == "search_knowledge_base":
query = (
tool_input.get("query", "")
if tool_name == "ls":
ls_path = (
tool_input.get("path", "/")
if isinstance(tool_input, dict)
else str(tool_input)
)
last_active_step_title = "Searching knowledge base"
last_active_step_title = "Listing files"
last_active_step_items = [ls_path]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Listing files",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "read_file":
fp = (
tool_input.get("file_path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Reading file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Reading file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "write_file":
fp = (
tool_input.get("file_path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Writing file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Writing file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "edit_file":
fp = (
tool_input.get("file_path", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_fp = fp if len(fp) <= 80 else "" + fp[-77:]
last_active_step_title = "Editing file"
last_active_step_items = [display_fp]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Editing file",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "glob":
pat = (
tool_input.get("pattern", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
base_path = (
tool_input.get("path", "/") if isinstance(tool_input, dict) else "/"
)
last_active_step_title = "Searching files"
last_active_step_items = [f"{pat} in {base_path}"]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Searching files",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "grep":
pat = (
tool_input.get("pattern", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
grep_path = (
tool_input.get("path", "") if isinstance(tool_input, dict) else ""
)
display_pat = pat[:60] + ("" if len(pat) > 60 else "")
last_active_step_title = "Searching content"
last_active_step_items = [
f"Query: {query[:100]}{'...' if len(query) > 100 else ''}"
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Searching knowledge base",
title="Searching content",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "save_document":
doc_title = (
tool_input.get("title", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
display_title = doc_title[:60] + ("" if len(doc_title) > 60 else "")
last_active_step_title = "Saving document"
last_active_step_items = [display_title]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Saving document",
status="in_progress",
items=last_active_step_items,
)
@ -441,10 +535,22 @@ async def _stream_agent_events(
else streaming_service.generate_tool_call_id()
)
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
# Sanitize tool_input: strip runtime-injected non-serializable
# values (e.g. LangChain ToolRuntime) before sending over SSE.
if isinstance(tool_input, dict):
_safe_input: dict[str, Any] = {}
for _k, _v in tool_input.items():
try:
json.dumps(_v)
_safe_input[_k] = _v
except (TypeError, ValueError, OverflowError):
pass
else:
_safe_input = {"input": tool_input}
yield streaming_service.format_tool_input_available(
tool_call_id,
tool_name,
tool_input if isinstance(tool_input, dict) else {"input": tool_input},
_safe_input,
)
elif event_type == "on_tool_end":
@ -475,16 +581,55 @@ async def _stream_agent_events(
)
completed_step_ids.add(original_step_id)
if tool_name == "search_knowledge_base":
result_info = "Search completed"
if isinstance(tool_output, dict):
result_len = tool_output.get("result_length", 0)
if result_len > 0:
result_info = f"Found relevant information ({result_len} chars)"
completed_items = [*last_active_step_items, result_info]
if tool_name == "read_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Searching knowledge base",
title="Reading file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "write_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Writing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "edit_file":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Editing file",
status="completed",
items=last_active_step_items,
)
elif tool_name == "glob":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Searching files",
status="completed",
items=last_active_step_items,
)
elif tool_name == "grep":
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Searching content",
status="completed",
items=last_active_step_items,
)
elif tool_name == "save_document":
result_str = (
tool_output.get("result", "")
if isinstance(tool_output, dict)
else str(tool_output)
)
is_error = "Error" in result_str
completed_items = [
*last_active_step_items,
result_str[:80] if is_error else "Saved to knowledge base",
]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Saving document",
status="completed",
items=completed_items,
)
@ -690,14 +835,23 @@ async def _stream_agent_events(
ls_output = str(tool_output) if tool_output else ""
file_names: list[str] = []
if ls_output:
for line in ls_output.strip().split("\n"):
line = line.strip()
if line:
name = line.rstrip("/").split("/")[-1]
if name and len(name) <= 40:
file_names.append(name)
elif name:
file_names.append(name[:37] + "...")
paths: list[str] = []
try:
parsed = ast.literal_eval(ls_output)
if isinstance(parsed, list):
paths = [str(p) for p in parsed]
except (ValueError, SyntaxError):
paths = [
line.strip()
for line in ls_output.strip().split("\n")
if line.strip()
]
for p in paths:
name = p.rstrip("/").split("/")[-1]
if name and len(name) <= 40:
file_names.append(name)
elif name:
file_names.append(name[:37] + "...")
if file_names:
if len(file_names) <= 5:
completed_items = [f"[{name}]" for name in file_names]
@ -708,7 +862,7 @@ async def _stream_agent_events(
completed_items = ["No files found"]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Exploring files",
title="Listing files",
status="completed",
items=completed_items,
)
@ -832,14 +986,6 @@ async def _stream_agent_events(
f"Scrape failed: {error_msg}",
"error",
)
elif tool_name == "search_knowledge_base":
yield streaming_service.format_tool_output_available(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
)
yield streaming_service.format_terminal_info(
"Knowledge base search completed", "success"
)
elif tool_name == "generate_report":
# Stream the full report result so frontend can render the ReportCard
yield streaming_service.format_tool_output_available(
@ -973,6 +1119,19 @@ async def _stream_agent_events(
items=last_active_step_items,
)
elif (
event_type == "on_custom_event" and event.get("name") == "document_created"
):
data = event.get("data", {})
if data.get("id"):
yield streaming_service.format_data(
"documents-updated",
{
"action": "created",
"document": data,
},
)
elif event_type in ("on_chain_end", "on_agent_end"):
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
@ -995,38 +1154,6 @@ async def _stream_agent_events(
yield streaming_service.format_interrupt_request(result.interrupt_value)
def _try_persist_and_delete_sandbox(
thread_id: int,
sandbox_files: list[str],
) -> None:
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
from app.agents.new_chat.sandbox import (
is_sandbox_enabled,
persist_and_delete_sandbox,
)
if not is_sandbox_enabled():
return
async def _run() -> None:
try:
await persist_and_delete_sandbox(thread_id, sandbox_files)
except Exception:
logging.getLogger(__name__).warning(
"persist_and_delete_sandbox failed for thread %s",
thread_id,
exc_info=True,
)
try:
loop = asyncio.get_running_loop()
task = loop.create_task(_run())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except RuntimeError:
pass
async def stream_new_chat(
user_query: str,
search_space_id: int,
@ -1141,22 +1268,6 @@ async def stream_new_chat(
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
except Exception as sandbox_err:
logging.getLogger(__name__).warning(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
agent = await create_surfsense_deep_agent(
@ -1170,7 +1281,6 @@ async def stream_new_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
disabled_tools=disabled_tools,
)
_perf_log.info(
@ -1531,8 +1641,6 @@ async def stream_new_chat(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
@ -1541,7 +1649,7 @@ async def stream_new_chat(
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
agent = llm = connector_service = sandbox_backend = None
agent = llm = connector_service = None
input_state = stream_result = None
session = None
@ -1627,22 +1735,6 @@ async def stream_resume_chat(
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
)
sandbox_backend = None
_t0 = time.perf_counter()
if is_sandbox_enabled():
try:
sandbox_backend = await get_or_create_sandbox(chat_id)
except Exception as sandbox_err:
logging.getLogger(__name__).warning(
"Sandbox creation failed, continuing without execute tool: %s",
sandbox_err,
)
_perf_log.info(
"[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
time.perf_counter() - _t0,
sandbox_backend is not None,
)
visibility = thread_visibility or ChatVisibility.PRIVATE
_t0 = time.perf_counter()
@ -1657,7 +1749,6 @@ async def stream_resume_chat(
agent_config=agent_config,
firecrawl_api_key=firecrawl_api_key,
thread_visibility=visibility,
sandbox_backend=sandbox_backend,
)
_perf_log.info(
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
@ -1742,15 +1833,13 @@ async def stream_resume_chat(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
agent = llm = connector_service = sandbox_backend = None
agent = llm = connector_service = None
stream_result = None
session = None

View file

@ -10,7 +10,10 @@ from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
@ -194,6 +197,27 @@ async def index_confluence_pages(
await confluence_client.close()
return 0, 0, None
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=page.get("title", ""),
document_type=DocumentType.CONFLUENCE_CONNECTOR,
unique_id=page.get("id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"page_id": page.get("id", ""),
"connector_id": connector_id,
"connector_type": "Confluence",
},
)
for page in pages
if page.get("id") and page.get("title")
]
await pipeline.create_placeholder_documents(placeholders)
documents_skipped = 0
duplicate_content_count = 0
connector_docs: list[ConnectorDocument] = []
@ -202,7 +226,7 @@ async def index_confluence_pages(
try:
page_id = page.get("id")
page_title = page.get("title", "")
space_id = page.get("spaceId", "")
page.get("spaceId", "")
if not page_id or not page_title:
logger.warning(
@ -265,11 +289,12 @@ async def index_confluence_pages(
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error building ConnectorDocument for page: {e!s}", exc_info=True)
logger.error(
f"Error building ConnectorDocument for page: {e!s}", exc_info=True
)
documents_skipped += 1
continue
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s: AsyncSession):

View file

@ -16,7 +16,10 @@ from app.connectors.google_calendar_connector import GoogleCalendarConnector
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.google_credentials import (
@ -73,9 +76,7 @@ def _build_connector_doc(
"connector_type": "Google Calendar",
}
fallback_summary = (
f"Google Calendar Event: {event_summary}\n\n{event_markdown}"
)
fallback_summary = f"Google Calendar Event: {event_summary}\n\n{event_markdown}"
return ConnectorDocument(
title=event_summary,
@ -344,6 +345,27 @@ async def index_google_calendar_events(
logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True)
return 0, 0, f"Error fetching Google Calendar events: {e!s}"
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=event.get("summary", "No Title"),
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
unique_id=event.get("id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"event_id": event.get("id", ""),
"connector_id": connector_id,
"connector_type": "Google Calendar",
},
)
for event in events
if event.get("id")
]
await pipeline.create_placeholder_documents(placeholders)
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
@ -391,13 +413,13 @@ async def index_google_calendar_events(
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error building ConnectorDocument for event: {e!s}", exc_info=True)
logger.error(
f"Error building ConnectorDocument for event: {e!s}", exc_info=True
)
documents_skipped += 1
continue
# ── Pipeline: migrate legacy docs + parallel index ─────────────
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):

View file

@ -29,7 +29,10 @@ from app.connectors.google_drive.file_types import should_skip_file as skip_mime
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.tasks.connector_indexers.base import (
@ -57,6 +60,7 @@ logger = logging.getLogger(__name__)
# Helpers
# ---------------------------------------------------------------------------
async def _should_skip_file(
session: AsyncSession,
file: dict,
@ -97,11 +101,14 @@ async def _should_skip_file(
result = await session.execute(
select(Document).where(
Document.search_space_id == search_space_id,
Document.document_type.in_([
DocumentType.GOOGLE_DRIVE_FILE,
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]),
cast(Document.document_metadata["google_drive_file_id"], String) == file_id,
Document.document_type.in_(
[
DocumentType.GOOGLE_DRIVE_FILE,
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
),
cast(Document.document_metadata["google_drive_file_id"], String)
== file_id,
)
)
existing = result.scalar_one_or_none()
@ -191,6 +198,50 @@ def _build_connector_doc(
)
async def _create_drive_placeholders(
session: AsyncSession,
files: list[dict],
*,
connector_id: int,
search_space_id: int,
user_id: str,
) -> None:
"""Create placeholder document rows for discovered Drive files.
Called immediately after file discovery (Phase 1) so documents appear
in the UI via Zero sync before the slow download/ETL phase begins.
"""
if not files:
return
placeholders = []
for file in files:
file_id = file.get("id")
file_name = file.get("name", "Unknown")
if not file_id:
continue
placeholders.append(
PlaceholderInfo(
title=file_name,
document_type=DocumentType.GOOGLE_DRIVE_FILE,
unique_id=file_id,
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"google_drive_file_id": file_id,
"FILE_NAME": file_name,
"connector_id": connector_id,
"connector_type": "Google Drive",
},
)
)
if placeholders:
pipeline = IndexingPipelineService(session)
await pipeline.create_placeholder_documents(placeholders)
async def _download_files_parallel(
drive_client: GoogleDriveClient,
files: list[dict],
@ -246,9 +297,7 @@ async def _download_files_parallel(
failed = 0
for outcome in outcomes:
if isinstance(outcome, Exception):
failed += 1
elif outcome is None:
if isinstance(outcome, Exception) or outcome is None:
failed += 1
else:
results.append(outcome)
@ -300,14 +349,18 @@ async def _process_single_file(
if not documents:
return 0, 1, 0
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
from app.indexing_pipeline.document_hashing import (
compute_unique_identifier_hash,
)
doc_map = {compute_unique_identifier_hash(doc): doc}
for document in documents:
connector_doc = doc_map.get(document.unique_identifier_hash)
if not connector_doc:
continue
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
await pipeline.index(document, connector_doc, user_llm)
logger.info(f"Successfully indexed Google Drive file: {file_name}")
@ -335,11 +388,14 @@ async def _remove_document(session: AsyncSession, file_id: str, search_space_id:
result = await session.execute(
select(Document).where(
Document.search_space_id == search_space_id,
Document.document_type.in_([
DocumentType.GOOGLE_DRIVE_FILE,
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]),
cast(Document.document_metadata["google_drive_file_id"], String) == file_id,
Document.document_type.in_(
[
DocumentType.GOOGLE_DRIVE_FILE,
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
]
),
cast(Document.document_metadata["google_drive_file_id"], String)
== file_id,
)
)
existing = result.scalar_one_or_none()
@ -383,7 +439,9 @@ async def _download_and_index(
return await get_user_long_context_llm(s, user_id, search_space_id)
_, batch_indexed, batch_failed = await pipeline.index_batch_parallel(
connector_docs, _get_llm, max_concurrency=3,
connector_docs,
_get_llm,
max_concurrency=3,
on_heartbeat=on_heartbeat,
)
@ -430,10 +488,22 @@ async def _index_selected_files(
files_to_download.append(file)
batch_indexed, failed = await _download_and_index(
drive_client, session, files_to_download,
connector_id=connector_id, search_space_id=search_space_id,
user_id=user_id, enable_summary=enable_summary,
await _create_drive_placeholders(
session,
files_to_download,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
batch_indexed, _failed = await _download_and_index(
drive_client,
session,
files_to_download,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=enable_summary,
on_heartbeat=on_heartbeat,
)
@ -444,6 +514,7 @@ async def _index_selected_files(
# Scan strategies
# ---------------------------------------------------------------------------
async def _index_full_scan(
drive_client: GoogleDriveClient,
session: AsyncSession,
@ -464,7 +535,11 @@ async def _index_full_scan(
await task_logger.log_task_progress(
log_entry,
f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})",
{"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders},
{
"stage": "full_scan",
"folder_id": folder_id,
"include_subfolders": include_subfolders,
},
)
# ------------------------------------------------------------------
@ -483,7 +558,10 @@ async def _index_full_scan(
while files_processed < max_files:
files, next_token, error = await get_files_in_folder(
drive_client, cur_id, include_subfolders=True, page_token=page_token,
drive_client,
cur_id,
include_subfolders=True,
page_token=page_token,
)
if error:
logger.error(f"Error listing files in {cur_name}: {error}")
@ -500,7 +578,9 @@ async def _index_full_scan(
mime = file.get("mimeType", "")
if mime == "application/vnd.google-apps.folder":
if include_subfolders:
folders_to_process.append((file["id"], file.get("name", "Unknown")))
folders_to_process.append(
(file["id"], file.get("name", "Unknown"))
)
continue
files_processed += 1
@ -521,24 +601,45 @@ async def _index_full_scan(
if not files_processed and first_error:
err_lower = first_error.lower()
if "401" in first_error or "invalid credentials" in err_lower or "authError" in first_error:
if (
"401" in first_error
or "invalid credentials" in err_lower
or "authError" in first_error
):
raise Exception(
f"Google Drive authentication failed. Please re-authenticate. (Error: {first_error})"
)
raise Exception(f"Failed to list Google Drive files: {first_error}")
# ------------------------------------------------------------------
# Phase 1.5: create placeholders for instant UI feedback
# ------------------------------------------------------------------
await _create_drive_placeholders(
session,
files_to_download,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
# ------------------------------------------------------------------
# Phase 2+3 (parallel): download, ETL, index
# ------------------------------------------------------------------
batch_indexed, failed = await _download_and_index(
drive_client, session, files_to_download,
connector_id=connector_id, search_space_id=search_space_id,
user_id=user_id, enable_summary=enable_summary,
drive_client,
session,
files_to_download,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=enable_summary,
on_heartbeat=on_heartbeat_callback,
)
indexed = renamed_count + batch_indexed
logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed")
logger.info(
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
)
return indexed, skipped
@ -565,7 +666,9 @@ async def _index_with_delta_sync(
{"stage": "delta_sync", "start_token": start_page_token},
)
changes, _final_token, error = await fetch_all_changes(drive_client, start_page_token, folder_id)
changes, _final_token, error = await fetch_all_changes(
drive_client, start_page_token, folder_id
)
if error:
err_lower = error.lower()
if "401" in error or "invalid credentials" in err_lower or "authError" in error:
@ -614,18 +717,35 @@ async def _index_with_delta_sync(
files_to_download.append(file)
# ------------------------------------------------------------------
# Phase 1.5: create placeholders for instant UI feedback
# ------------------------------------------------------------------
await _create_drive_placeholders(
session,
files_to_download,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
# ------------------------------------------------------------------
# Phase 2+3 (parallel): download, ETL, index
# ------------------------------------------------------------------
batch_indexed, failed = await _download_and_index(
drive_client, session, files_to_download,
connector_id=connector_id, search_space_id=search_space_id,
user_id=user_id, enable_summary=enable_summary,
drive_client,
session,
files_to_download,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=enable_summary,
on_heartbeat=on_heartbeat_callback,
)
indexed = renamed_count + batch_indexed
logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed")
logger.info(
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
)
return indexed, skipped
@ -633,6 +753,7 @@ async def _index_with_delta_sync(
# Public entry points
# ---------------------------------------------------------------------------
async def index_google_drive_files(
session: AsyncSession,
connector_id: int,
@ -653,8 +774,11 @@ async def index_google_drive_files(
source="connector_indexing_task",
message=f"Starting Google Drive indexing for connector {connector_id}",
metadata={
"connector_id": connector_id, "user_id": str(user_id),
"folder_id": folder_id, "use_delta_sync": use_delta_sync, "max_files": max_files,
"connector_id": connector_id,
"user_id": str(user_id),
"folder_id": folder_id,
"use_delta_sync": use_delta_sync,
"max_files": max_files,
},
)
@ -666,11 +790,14 @@ async def index_google_drive_files(
break
if not connector:
error_msg = f"Google Drive connector with ID {connector_id} not found"
await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
await task_logger.log_task_failure(
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
)
return 0, 0, error_msg
await task_logger.log_task_progress(
log_entry, f"Initializing Google Drive client for connector {connector_id}",
log_entry,
f"Initializing Google Drive client for connector {connector_id}",
{"stage": "client_initialization"},
)
@ -679,24 +806,39 @@ async def index_google_drive_files(
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"})
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing Composio account",
{"error_type": "MissingComposioAccount"},
)
return 0, 0, error_msg
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
await task_logger.log_task_failure(
log_entry, "SECRET_KEY not configured but credentials are encrypted",
"Missing SECRET_KEY", {"error_type": "MissingSecretKey"},
log_entry,
"SECRET_KEY not configured but credentials are encrypted",
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return (
0,
0,
"SECRET_KEY not configured but credentials are marked as encrypted",
)
return 0, 0, "SECRET_KEY not configured but credentials are marked as encrypted"
connector_enable_summary = getattr(connector, "enable_summary", True)
drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
if not folder_id:
error_msg = "folder_id is required for Google Drive indexing"
await task_logger.log_task_failure(log_entry, error_msg, {"error_type": "MissingParameter"})
await task_logger.log_task_failure(
log_entry, error_msg, {"error_type": "MissingParameter"}
)
return 0, 0, error_msg
target_folder_id = folder_id
@ -704,29 +846,64 @@ async def index_google_drive_files(
folder_tokens = connector.config.get("folder_tokens", {})
start_page_token = folder_tokens.get(target_folder_id)
can_use_delta = use_delta_sync and start_page_token and connector.last_indexed_at
can_use_delta = (
use_delta_sync and start_page_token and connector.last_indexed_at
)
if can_use_delta:
logger.info(f"Using delta sync for connector {connector_id}")
documents_indexed, documents_skipped = await _index_with_delta_sync(
drive_client, session, connector, connector_id, search_space_id, user_id,
target_folder_id, start_page_token, task_logger, log_entry, max_files,
include_subfolders, on_heartbeat_callback, connector_enable_summary,
drive_client,
session,
connector,
connector_id,
search_space_id,
user_id,
target_folder_id,
start_page_token,
task_logger,
log_entry,
max_files,
include_subfolders,
on_heartbeat_callback,
connector_enable_summary,
)
logger.info("Running reconciliation scan after delta sync")
ri, rs = await _index_full_scan(
drive_client, session, connector, connector_id, search_space_id, user_id,
target_folder_id, target_folder_name, task_logger, log_entry, max_files,
include_subfolders, on_heartbeat_callback, connector_enable_summary,
drive_client,
session,
connector,
connector_id,
search_space_id,
user_id,
target_folder_id,
target_folder_name,
task_logger,
log_entry,
max_files,
include_subfolders,
on_heartbeat_callback,
connector_enable_summary,
)
documents_indexed += ri
documents_skipped += rs
else:
logger.info(f"Using full scan for connector {connector_id}")
documents_indexed, documents_skipped = await _index_full_scan(
drive_client, session, connector, connector_id, search_space_id, user_id,
target_folder_id, target_folder_name, task_logger, log_entry, max_files,
include_subfolders, on_heartbeat_callback, connector_enable_summary,
drive_client,
session,
connector,
connector_id,
search_space_id,
user_id,
target_folder_id,
target_folder_name,
task_logger,
log_entry,
max_files,
include_subfolders,
on_heartbeat_callback,
connector_enable_summary,
)
if documents_indexed > 0 or can_use_delta:
@ -745,26 +922,34 @@ async def index_google_drive_files(
log_entry,
f"Successfully completed Google Drive indexing for connector {connector_id}",
{
"files_processed": documents_indexed, "files_skipped": documents_skipped,
"sync_type": "delta" if can_use_delta else "full", "folder": target_folder_name,
"files_processed": documents_indexed,
"files_skipped": documents_skipped,
"sync_type": "delta" if can_use_delta else "full",
"folder": target_folder_name,
},
)
logger.info(f"Google Drive indexing completed: {documents_indexed} indexed, {documents_skipped} skipped")
logger.info(
f"Google Drive indexing completed: {documents_indexed} indexed, {documents_skipped} skipped"
)
return documents_indexed, documents_skipped, None
except SQLAlchemyError as db_error:
await session.rollback()
await task_logger.log_task_failure(
log_entry, f"Database error during Google Drive indexing for connector {connector_id}",
str(db_error), {"error_type": "SQLAlchemyError"},
log_entry,
f"Database error during Google Drive indexing for connector {connector_id}",
str(db_error),
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
log_entry, f"Failed to index Google Drive files for connector {connector_id}",
str(e), {"error_type": type(e).__name__},
log_entry,
f"Failed to index Google Drive files for connector {connector_id}",
str(e),
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Google Drive files: {e!s}", exc_info=True)
return 0, 0, f"Failed to index Google Drive files: {e!s}"
@ -784,7 +969,12 @@ async def index_google_drive_single_file(
task_name="google_drive_single_file_indexing",
source="connector_indexing_task",
message=f"Starting Google Drive single file indexing for file {file_id}",
metadata={"connector_id": connector_id, "user_id": str(user_id), "file_id": file_id, "file_name": file_name},
metadata={
"connector_id": connector_id,
"user_id": str(user_id),
"file_id": file_id,
"file_name": file_name,
},
)
try:
@ -795,7 +985,9 @@ async def index_google_drive_single_file(
break
if not connector:
error_msg = f"Google Drive connector with ID {connector_id} not found"
await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
await task_logger.log_task_failure(
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
)
return 0, error_msg
pre_built_credentials = None
@ -803,43 +995,65 @@ async def index_google_drive_single_file(
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"})
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing Composio account",
{"error_type": "MissingComposioAccount"},
)
return 0, error_msg
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
await task_logger.log_task_failure(
log_entry, "SECRET_KEY not configured but credentials are encrypted",
"Missing SECRET_KEY", {"error_type": "MissingSecretKey"},
log_entry,
"SECRET_KEY not configured but credentials are encrypted",
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return (
0,
"SECRET_KEY not configured but credentials are marked as encrypted",
)
return 0, "SECRET_KEY not configured but credentials are marked as encrypted"
connector_enable_summary = getattr(connector, "enable_summary", True)
drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
file, error = await get_file_by_id(drive_client, file_id)
if error or not file:
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
await task_logger.log_task_failure(log_entry, error_msg, {"error_type": "FileNotFound"})
await task_logger.log_task_failure(
log_entry, error_msg, {"error_type": "FileNotFound"}
)
return 0, error_msg
display_name = file_name or file.get("name", "Unknown")
indexed, _skipped, failed = await _process_single_file(
drive_client, session, file,
connector_id, search_space_id, user_id, connector_enable_summary,
drive_client,
session,
file,
connector_id,
search_space_id,
user_id,
connector_enable_summary,
)
await session.commit()
if failed > 0:
error_msg = f"Failed to index file {display_name}"
await task_logger.log_task_failure(log_entry, error_msg, {"file_name": display_name, "file_id": file_id})
await task_logger.log_task_failure(
log_entry, error_msg, {"file_name": display_name, "file_id": file_id}
)
return 0, error_msg
if indexed > 0:
await task_logger.log_task_success(
log_entry, f"Successfully indexed file {display_name}",
log_entry,
f"Successfully indexed file {display_name}",
{"file_name": display_name, "file_id": file_id},
)
return 1, None
@ -848,12 +1062,22 @@ async def index_google_drive_single_file(
except SQLAlchemyError as db_error:
await session.rollback()
await task_logger.log_task_failure(log_entry, "Database error during file indexing", str(db_error), {"error_type": "SQLAlchemyError"})
await task_logger.log_task_failure(
log_entry,
"Database error during file indexing",
str(db_error),
{"error_type": "SQLAlchemyError"},
)
logger.error(f"Database error: {db_error!s}", exc_info=True)
return 0, f"Database error: {db_error!s}"
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(log_entry, "Failed to index Google Drive file", str(e), {"error_type": type(e).__name__})
await task_logger.log_task_failure(
log_entry,
"Failed to index Google Drive file",
str(e),
{"error_type": type(e).__name__},
)
logger.error(f"Failed to index Google Drive file: {e!s}", exc_info=True)
return 0, f"Failed to index Google Drive file: {e!s}"
@ -878,7 +1102,11 @@ async def index_google_drive_selected_files(
task_name="google_drive_selected_files_indexing",
source="connector_indexing_task",
message=f"Starting Google Drive batch file indexing for {len(files)} files",
metadata={"connector_id": connector_id, "user_id": str(user_id), "file_count": len(files)},
metadata={
"connector_id": connector_id,
"user_id": str(user_id),
"file_count": len(files),
},
)
try:
@ -889,7 +1117,9 @@ async def index_google_drive_selected_files(
break
if not connector:
error_msg = f"Google Drive connector with ID {connector_id} not found"
await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
await task_logger.log_task_failure(
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
)
return 0, 0, [error_msg]
pre_built_credentials = None
@ -897,25 +1127,41 @@ async def index_google_drive_selected_files(
connected_account_id = connector.config.get("composio_connected_account_id")
if not connected_account_id:
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"})
await task_logger.log_task_failure(
log_entry,
error_msg,
"Missing Composio account",
{"error_type": "MissingComposioAccount"},
)
return 0, 0, [error_msg]
pre_built_credentials = build_composio_credentials(connected_account_id)
else:
token_encrypted = connector.config.get("_token_encrypted", False)
if token_encrypted and not config.SECRET_KEY:
error_msg = "SECRET_KEY not configured but credentials are marked as encrypted"
error_msg = (
"SECRET_KEY not configured but credentials are marked as encrypted"
)
await task_logger.log_task_failure(
log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"},
log_entry,
error_msg,
"Missing SECRET_KEY",
{"error_type": "MissingSecretKey"},
)
return 0, 0, [error_msg]
connector_enable_summary = getattr(connector, "enable_summary", True)
drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
)
indexed, skipped, errors = await _index_selected_files(
drive_client, session, files,
connector_id=connector_id, search_space_id=search_space_id,
user_id=user_id, enable_summary=connector_enable_summary,
drive_client,
session,
files,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
enable_summary=connector_enable_summary,
on_heartbeat=on_heartbeat_callback,
)
@ -935,18 +1181,24 @@ async def index_google_drive_selected_files(
{"indexed": indexed, "skipped": skipped},
)
logger.info(f"Selected files indexing: {indexed} indexed, {skipped} skipped, {len(errors)} errors")
logger.info(
f"Selected files indexing: {indexed} indexed, {skipped} skipped, {len(errors)} errors"
)
return indexed, skipped, errors
except SQLAlchemyError as db_error:
await session.rollback()
error_msg = f"Database error: {db_error!s}"
await task_logger.log_task_failure(log_entry, error_msg, str(db_error), {"error_type": "SQLAlchemyError"})
await task_logger.log_task_failure(
log_entry, error_msg, str(db_error), {"error_type": "SQLAlchemyError"}
)
logger.error(error_msg, exc_info=True)
return 0, 0, [error_msg]
except Exception as e:
await session.rollback()
error_msg = f"Failed to index Google Drive files: {e!s}"
await task_logger.log_task_failure(log_entry, error_msg, str(e), {"error_type": type(e).__name__})
await task_logger.log_task_failure(
log_entry, error_msg, str(e), {"error_type": type(e).__name__}
)
logger.error(error_msg, exc_info=True)
return 0, 0, [error_msg]

View file

@ -16,7 +16,10 @@ from app.connectors.google_gmail_connector import GoogleGmailConnector
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.google_credentials import (
@ -282,6 +285,34 @@ async def index_google_gmail_messages(
logger.info(f"Found {len(messages)} Google gmail messages to index")
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
def _gmail_subject(msg: dict) -> str:
for h in msg.get("payload", {}).get("headers", []):
if h.get("name", "").lower() == "subject":
return h.get("value", "No Subject")
return "No Subject"
placeholders = [
PlaceholderInfo(
title=_gmail_subject(msg),
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
unique_id=msg.get("id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"message_id": msg.get("id", ""),
"connector_id": connector_id,
"connector_type": "Google Gmail",
},
)
for msg in messages
if msg.get("id")
]
await pipeline.create_placeholder_documents(placeholders)
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
@ -327,13 +358,14 @@ async def index_google_gmail_messages(
connector_docs.append(doc)
except Exception as e:
logger.error(f"Error building ConnectorDocument for message: {e!s}", exc_info=True)
logger.error(
f"Error building ConnectorDocument for message: {e!s}",
exc_info=True,
)
documents_skipped += 1
continue
# ── Pipeline: migrate legacy docs + parallel index ─────────────
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):

View file

@ -10,7 +10,10 @@ from app.connectors.jira_history import JiraHistoryConnector
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
@ -191,6 +194,27 @@ async def index_jira_issues(
await jira_client.close()
return 0, 0, None
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=f"{issue.get('key', '')}: {issue.get('id', '')}",
document_type=DocumentType.JIRA_CONNECTOR,
unique_id=issue.get("key", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"issue_id": issue.get("key", ""),
"connector_id": connector_id,
"connector_type": "Jira",
},
)
for issue in issues
if issue.get("key") and issue.get("id")
]
await pipeline.create_placeholder_documents(placeholders)
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
duplicate_content_count = 0
@ -253,7 +277,6 @@ async def index_jira_issues(
documents_skipped += 1
continue
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s: AsyncSession):

View file

@ -14,7 +14,10 @@ from app.connectors.linear_connector import LinearConnector
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
@ -199,9 +202,7 @@ async def index_linear_issues(
logger.info(f"Retrieved {len(issues)} issues from Linear API")
except Exception as e:
logger.error(
f"Exception when calling Linear API: {e!s}", exc_info=True
)
logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True)
return 0, 0, f"Failed to get Linear issues: {e!s}"
if not issues:
@ -213,6 +214,28 @@ async def index_linear_issues(
await session.commit()
return 0, 0, None
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=f"{issue.get('identifier', '')}: {issue.get('title', '')}",
document_type=DocumentType.LINEAR_CONNECTOR,
unique_id=issue.get("id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"issue_id": issue.get("id", ""),
"issue_identifier": issue.get("identifier", ""),
"connector_id": connector_id,
"connector_type": "Linear",
},
)
for issue in issues
if issue.get("id") and issue.get("title")
]
await pipeline.create_placeholder_documents(placeholders)
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
@ -238,9 +261,7 @@ async def index_linear_issues(
continue
formatted_issue = linear_client.format_issue(issue)
issue_content = linear_client.format_issue_to_markdown(
formatted_issue
)
issue_content = linear_client.format_issue_to_markdown(formatted_issue)
if not issue_content:
logger.warning(
@ -284,8 +305,6 @@ async def index_linear_issues(
continue
# ── Pipeline: migrate legacy docs + parallel index ────────────
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):
@ -302,9 +321,7 @@ async def index_linear_issues(
# ── Finalize ──────────────────────────────────────────────────
await update_connector_last_indexed(session, connector, update_last_indexed)
logger.info(
f"Final commit: Total {documents_indexed} Linear issues processed"
)
logger.info(f"Final commit: Total {documents_indexed} Linear issues processed")
try:
await session.commit()
logger.info(

View file

@ -15,7 +15,10 @@ from app.connectors.notion_history import NotionHistoryConnector
from app.db import DocumentType, SearchSourceConnectorType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import compute_content_hash
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
from app.services.llm_service import get_user_long_context_llm
from app.services.task_logging_service import TaskLoggingService
from app.utils.notion_utils import process_blocks
@ -245,13 +248,32 @@ async def index_notion_pages(
{"pages_found": 0},
)
logger.info("No Notion pages found to index")
await update_connector_last_indexed(
session, connector, update_last_indexed
)
await update_connector_last_indexed(session, connector, update_last_indexed)
await session.commit()
await notion_client.close()
return 0, 0, None
# ── Create placeholders for instant UI feedback ───────────────
pipeline = IndexingPipelineService(session)
placeholders = [
PlaceholderInfo(
title=page.get("title", f"Untitled page ({page.get('page_id', '')})"),
document_type=DocumentType.NOTION_CONNECTOR,
unique_id=page.get("page_id", ""),
search_space_id=search_space_id,
connector_id=connector_id,
created_by_id=user_id,
metadata={
"page_id": page.get("page_id", ""),
"connector_id": connector_id,
"connector_type": "Notion",
},
)
for page in pages
if page.get("page_id")
]
await pipeline.create_placeholder_documents(placeholders)
# ── Build ConnectorDocuments ──────────────────────────────────
connector_docs: list[ConnectorDocument] = []
documents_skipped = 0
@ -282,9 +304,7 @@ async def index_notion_pages(
markdown_content += process_blocks(page_content)
if not markdown_content.strip():
logger.warning(
f"Skipping page with empty markdown: {page_title}"
)
logger.warning(f"Skipping page with empty markdown: {page_title}")
documents_skipped += 1
continue
@ -322,8 +342,6 @@ async def index_notion_pages(
continue
# ── Pipeline: migrate legacy docs + parallel index ────────────
pipeline = IndexingPipelineService(session)
await pipeline.migrate_legacy_docs(connector_docs)
async def _get_llm(s):

View file

@ -46,7 +46,6 @@ dependencies = [
"redis>=5.2.1",
"firecrawl-py>=4.9.0",
"boto3>=1.35.0",
"langchain-community>=0.3.31",
"litellm>=1.80.10",
"langchain-litellm>=0.3.5",
"fake-useragent>=2.2.0",
@ -60,20 +59,21 @@ dependencies = [
"sse-starlette>=3.1.1,<3.1.2",
"gitingest>=0.3.1",
"composio>=0.10.9",
"langchain>=1.2.6",
"langgraph>=1.0.5",
"unstructured[all-docs]>=0.18.31",
"unstructured-client>=0.42.3",
"langchain-unstructured>=1.0.1",
"slowapi>=0.1.9",
"pypandoc_binary>=1.16.2",
"typst>=0.14.0",
"deepagents>=0.4.3",
"daytona>=0.146.0",
"langchain-daytona>=0.0.2",
"pypandoc>=1.16.2",
"notion-markdown>=0.7.0",
"fractional-indexing>=0.1.3",
"langchain>=1.2.13",
"langgraph>=1.1.3",
"langchain-community>=0.4.1",
"deepagents>=0.4.12",
]
[dependency-groups]

View file

@ -14,7 +14,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
def _cal_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
return ConnectorDocument(
title=f"Event {unique_id}",
source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
@ -34,7 +36,9 @@ def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_calendar_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -63,7 +67,9 @@ async def test_calendar_pipeline_creates_ready_document(
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_calendar_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -101,7 +107,9 @@ async def test_calendar_legacy_doc_migrated(
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == original_id))
result = await db_session.execute(
select(Document).filter(Document.id == original_id)
)
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR

View file

@ -14,7 +14,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
def _drive_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
return ConnectorDocument(
title=f"File {unique_id}.pdf",
source_markdown=f"## Document Content\n\nText from file {unique_id}",
@ -33,7 +35,9 @@ def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_drive_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -62,7 +66,9 @@ async def test_drive_pipeline_creates_ready_document(
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_drive_legacy_doc_migrated(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -100,7 +106,9 @@ async def test_drive_legacy_doc_migrated(
service = IndexingPipelineService(session=db_session)
await service.migrate_legacy_docs([connector_doc])
result = await db_session.execute(select(Document).filter(Document.id == original_id))
result = await db_session.execute(
select(Document).filter(Document.id == original_id)
)
row = result.scalars().first()
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
@ -111,7 +119,9 @@ async def test_drive_legacy_doc_migrated(
async def test_should_skip_file_skips_failed_document(
db_session, db_search_space, db_user,
db_session,
db_search_space,
db_user,
):
"""A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index."""
import importlib
@ -162,7 +172,12 @@ async def test_should_skip_file_skips_failed_document(
db_session.add(failed_doc)
await db_session.flush()
incoming_file = {"id": file_id, "name": "Failed File.pdf", "mimeType": "application/pdf", "md5Checksum": md5}
incoming_file = {
"id": file_id,
"name": "Failed File.pdf",
"mimeType": "application/pdf",
"md5Checksum": md5,
}
should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)

View file

@ -8,7 +8,6 @@ from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import (
compute_identifier_hash,
compute_unique_identifier_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
@ -17,7 +16,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
pytestmark = pytest.mark.integration
def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
def _gmail_doc(
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
) -> ConnectorDocument:
"""Build a Gmail-style ConnectorDocument like the real indexer does."""
return ConnectorDocument(
title=f"Subject for {unique_id}",
@ -37,7 +38,9 @@ def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_
)
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_gmail_pipeline_creates_ready_document(
db_session, db_search_space, db_connector, db_user, mocker
):
@ -67,7 +70,9 @@ async def test_gmail_pipeline_creates_ready_document(
assert row.source_markdown == doc.source_markdown
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_gmail_legacy_doc_migrated_then_reused(
db_session, db_search_space, db_connector, db_user, mocker
):

View file

@ -9,7 +9,9 @@ from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineServ
pytestmark = pytest.mark.integration
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_index_batch_creates_ready_documents(
db_session, db_search_space, make_connector_document, mocker
):
@ -47,7 +49,9 @@ async def test_index_batch_creates_ready_documents(
assert row.embedding is not None
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
@pytest.mark.usefixtures(
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
)
async def test_index_batch_empty_returns_empty(db_session, mocker):
"""index_batch with empty input returns an empty list."""
service = IndexingPipelineService(session=db_session)

View file

@ -0,0 +1,106 @@
"""Shared fixtures for retriever integration tests."""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config as app_config
from app.db import Chunk, Document, DocumentType, SearchSpace, User
EMBEDDING_DIM = app_config.embedding_model_instance.dimension
DUMMY_EMBEDDING = [0.1] * EMBEDDING_DIM
def _make_document(
*,
title: str,
document_type: DocumentType,
content: str,
search_space_id: int,
created_by_id: str,
) -> Document:
uid = uuid.uuid4().hex[:12]
return Document(
title=title,
document_type=document_type,
content=content,
content_hash=f"content-{uid}",
unique_identifier_hash=f"uid-{uid}",
source_markdown=content,
search_space_id=search_space_id,
created_by_id=created_by_id,
embedding=DUMMY_EMBEDDING,
updated_at=datetime.now(UTC),
status={"state": "ready"},
)
def _make_chunk(*, content: str, document_id: int) -> Chunk:
return Chunk(
content=content,
document_id=document_id,
embedding=DUMMY_EMBEDDING,
)
@pytest_asyncio.fixture
async def seed_large_doc(
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
):
"""Insert a document with 35 chunks (more than _MAX_FETCH_CHUNKS_PER_DOC=20).
Also inserts a small 3-chunk document for diversity testing.
Returns a dict with ``large_doc``, ``small_doc``, ``search_space``, ``user``,
and ``large_chunk_ids`` (all 35 chunk IDs).
"""
user_id = str(db_user.id)
space_id = db_search_space.id
large_doc = _make_document(
title="Large PDF Document",
document_type=DocumentType.FILE,
content="large document about quarterly performance reviews and budgets",
search_space_id=space_id,
created_by_id=user_id,
)
small_doc = _make_document(
title="Small Note",
document_type=DocumentType.NOTE,
content="quarterly performance review summary note",
search_space_id=space_id,
created_by_id=user_id,
)
db_session.add_all([large_doc, small_doc])
await db_session.flush()
large_chunks = []
for i in range(35):
chunk = _make_chunk(
content=f"chunk {i} about quarterly performance review section {i}",
document_id=large_doc.id,
)
large_chunks.append(chunk)
small_chunks = [
_make_chunk(
content="quarterly performance review summary note content",
document_id=small_doc.id,
),
]
db_session.add_all(large_chunks + small_chunks)
await db_session.flush()
return {
"large_doc": large_doc,
"small_doc": small_doc,
"large_chunk_ids": [c.id for c in large_chunks],
"small_chunk_ids": [c.id for c in small_chunks],
"search_space": db_search_space,
"user": db_user,
}

View file

@ -0,0 +1,116 @@
"""Integration tests for optimized ChucksHybridSearchRetriever.
Verifies the SQL ROW_NUMBER per-doc chunk limit, column pruning,
and doc metadata caching from RRF results.
"""
import pytest
from app.retriever.chunks_hybrid_search import (
_MAX_FETCH_CHUNKS_PER_DOC,
ChucksHybridSearchRetriever,
)
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
large_doc_id = seed_large_doc["large_doc"].id
for result in results:
if result["document"].get("id") == large_doc_id:
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
break
else:
pytest.fail("Large doc not found in search results")
async def test_doc_metadata_populated_from_rrf(db_session, seed_large_doc):
"""Document metadata (title, type, etc.) should be present even without joinedload."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
doc = result["document"]
assert "id" in doc
assert "title" in doc
assert doc["title"]
assert "document_type" in doc
assert doc["document_type"] is not None
async def test_matched_chunk_ids_tracked(db_session, seed_large_doc):
"""matched_chunk_ids should contain the chunk IDs that appeared in the RRF results."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
matched = result.get("matched_chunk_ids", [])
chunk_ids_in_result = {c["chunk_id"] for c in result["chunks"]}
for mid in matched:
assert mid in chunk_ids_in_result, (
f"matched_chunk_id {mid} not found in chunks"
)
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
"""Chunks within each document should be ordered by chunk ID (original order)."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"
async def test_score_is_positive_float(db_session, seed_large_doc):
"""Each result should have a positive float score from RRF."""
space_id = seed_large_doc["search_space"].id
retriever = ChucksHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
assert isinstance(result["score"], float)
assert result["score"] > 0

View file

@ -0,0 +1,76 @@
"""Integration tests for optimized DocumentHybridSearchRetriever.
Verifies the SQL ROW_NUMBER per-doc chunk limit and column pruning.
"""
import pytest
from app.retriever.documents_hybrid_search import (
_MAX_FETCH_CHUNKS_PER_DOC,
DocumentHybridSearchRetriever,
)
from .conftest import DUMMY_EMBEDDING
pytestmark = pytest.mark.integration
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
large_doc_id = seed_large_doc["large_doc"].id
for result in results:
if result["document"].get("id") == large_doc_id:
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
break
else:
pytest.fail("Large doc not found in search results")
async def test_doc_metadata_populated(db_session, seed_large_doc):
"""Document metadata should be present from the RRF results."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
assert len(results) >= 1
for result in results:
doc = result["document"]
assert "id" in doc
assert "title" in doc
assert doc["title"]
assert "document_type" in doc
assert doc["document_type"] is not None
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
"""Chunks within each document should be ordered by chunk ID."""
space_id = seed_large_doc["search_space"].id
retriever = DocumentHybridSearchRetriever(db_session)
results = await retriever.hybrid_search(
query_text="quarterly performance review",
top_k=10,
search_space_id=space_id,
query_embedding=DUMMY_EMBEDDING,
)
for result in results:
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"

View file

@ -42,14 +42,11 @@ def _to_markdown(page: dict) -> str:
if comments:
comments_content = "\n\n## Comments\n\n"
for comment in comments:
comment_body = (
comment.get("body", {}).get("storage", {}).get("value", "")
)
comment_body = comment.get("body", {}).get("storage", {}).get("value", "")
comment_author = comment.get("version", {}).get("authorId", "Unknown")
comment_date = comment.get("version", {}).get("createdAt", "")
comments_content += (
f"**Comment by {comment_author}** ({comment_date}):\n"
f"{comment_body}\n\n"
f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
)
return f"# {page_title}\n\n{page_content}{comments_content}"
@ -138,22 +135,32 @@ def confluence_mocks(monkeypatch):
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
confluence_client = _mock_confluence_client(pages=[_make_page()])
monkeypatch.setattr(
_mod, "ConfluenceHistoryConnector", MagicMock(return_value=confluence_client),
_mod,
"ConfluenceHistoryConnector",
MagicMock(return_value=confluence_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
@ -162,15 +169,20 @@ def confluence_mocks(monkeypatch):
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {

View file

@ -41,6 +41,7 @@ def mock_drive_client():
@pytest.fixture
def patch_extract(monkeypatch):
"""Provide a helper to set the download_and_extract_content mock."""
def _patch(side_effect=None, return_value=None):
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
monkeypatch.setattr(
@ -48,11 +49,13 @@ def patch_extract(monkeypatch):
mock,
)
return mock
return _patch
async def test_single_file_returns_one_connector_document(
mock_drive_client, patch_extract,
mock_drive_client,
patch_extract,
):
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
@ -73,7 +76,8 @@ async def test_single_file_returns_one_connector_document(
async def test_multiple_files_all_produce_documents(
mock_drive_client, patch_extract,
mock_drive_client,
patch_extract,
):
"""All files are downloaded and converted to ConnectorDocuments."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
@ -96,7 +100,8 @@ async def test_multiple_files_all_produce_documents(
async def test_one_download_exception_does_not_block_others(
mock_drive_client, patch_extract,
mock_drive_client,
patch_extract,
):
"""A RuntimeError in one download still lets the other files succeed."""
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
@ -123,7 +128,8 @@ async def test_one_download_exception_does_not_block_others(
async def test_etl_error_counts_as_download_failure(
mock_drive_client, patch_extract,
mock_drive_client,
patch_extract,
):
"""download_and_extract_content returning an error is counted as failed."""
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
@ -148,7 +154,8 @@ async def test_etl_error_counts_as_download_failure(
async def test_concurrency_bounded_by_semaphore(
mock_drive_client, monkeypatch,
mock_drive_client,
monkeypatch,
):
"""Peak concurrent downloads never exceeds max_concurrency."""
lock = asyncio.Lock()
@ -189,7 +196,8 @@ async def test_concurrency_bounded_by_semaphore(
async def test_heartbeat_fires_during_parallel_downloads(
mock_drive_client, monkeypatch,
mock_drive_client,
monkeypatch,
):
"""on_heartbeat is called at least once when downloads take time."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
@ -231,8 +239,13 @@ async def test_heartbeat_fires_during_parallel_downloads(
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
# ---------------------------------------------------------------------------
def _folder_dict(file_id: str, name: str) -> dict:
return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"}
return {
"id": file_id,
"name": name,
"mimeType": "application/vnd.google-apps.folder",
}
@pytest.fixture
@ -259,12 +272,17 @@ def full_scan_mocks(mock_drive_client, monkeypatch):
batch_mock = AsyncMock(return_value=([], 0, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
_mod,
"get_user_long_context_llm",
AsyncMock(return_value=MagicMock()),
)
return {
@ -312,12 +330,16 @@ async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
]
monkeypatch.setattr(
_mod, "get_files_in_folder",
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
full_scan_mocks["skip_results"]["rename1"] = (True, "File renamed: 'old''renamed.txt'")
full_scan_mocks["skip_results"]["rename1"] = (
True,
"File renamed: 'old''renamed.txt'",
)
mock_docs = [MagicMock(), MagicMock()]
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
@ -341,7 +363,8 @@ async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
monkeypatch.setattr(
_mod, "get_files_in_folder",
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
@ -355,14 +378,16 @@ async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
async def test_full_scan_uses_max_concurrency_3_for_indexing(
full_scan_mocks, monkeypatch,
full_scan_mocks,
monkeypatch,
):
"""index_batch_parallel is called with max_concurrency=3."""
import app.tasks.connector_indexers.google_drive_indexer as _mod
page_files = [_make_file_dict("f1", "file1.txt")]
monkeypatch.setattr(
_mod, "get_files_in_folder",
_mod,
"get_files_in_folder",
AsyncMock(return_value=(page_files, None, None)),
)
@ -382,6 +407,7 @@ async def test_full_scan_uses_max_concurrency_3_for_indexing(
# Slice 7 -- _index_with_delta_sync three-phase pipeline
# ---------------------------------------------------------------------------
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
"""Removed/trashed changes call _remove_document; the rest go through
_download_files_parallel and index_batch_parallel."""
@ -396,7 +422,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
]
monkeypatch.setattr(
_mod, "fetch_all_changes",
_mod,
"fetch_all_changes",
AsyncMock(return_value=(changes, "new-token", None)),
)
@ -408,7 +435,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
"mod2": "modified",
}
monkeypatch.setattr(
_mod, "categorize_change",
_mod,
"categorize_change",
lambda change: change_types[change["fileId"]],
)
@ -420,7 +448,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
monkeypatch.setattr(
_mod, "_should_skip_file",
_mod,
"_should_skip_file",
AsyncMock(return_value=(False, None)),
)
@ -431,11 +460,16 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
batch_mock = AsyncMock(return_value=([], 2, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
monkeypatch.setattr(
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
_mod,
"get_user_long_context_llm",
AsyncMock(return_value=MagicMock()),
)
mock_session = AsyncMock()
@ -472,6 +506,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
# _index_selected_files -- parallel indexing of user-selected files
# ---------------------------------------------------------------------------
@pytest.fixture
def selected_files_mocks(mock_drive_client, monkeypatch):
"""Wire up mocks for _index_selected_files tests."""
@ -496,6 +531,14 @@ def selected_files_mocks(mock_drive_client, monkeypatch):
download_and_index_mock = AsyncMock(return_value=(0, 0))
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
pipeline_mock = MagicMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
"drive_client": mock_drive_client,
"session": mock_session,
@ -526,7 +569,8 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks, [("f1", "report.pdf")],
selected_files_mocks,
[("f1", "report.pdf")],
)
assert indexed == 1
@ -538,11 +582,13 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
"""get_file_by_id failing for one file collects an error; others still indexed."""
selected_files_mocks["get_file_results"]["f1"] = (
_make_file_dict("f1", "first.txt"), None,
_make_file_dict("f1", "first.txt"),
None,
)
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
selected_files_mocks["get_file_results"]["f3"] = (
_make_file_dict("f3", "third.txt"), None,
_make_file_dict("f3", "third.txt"),
None,
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
@ -561,30 +607,46 @@ async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
async def test_selected_files_skip_rename_counting(selected_files_mocks):
"""Unchanged files are skipped, renames counted as indexed,
and only new files are sent to _download_and_index."""
for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"),
("n1", "new1.txt"), ("n2", "new2.txt")]:
for fid, fname in [
("s1", "unchanged.txt"),
("r1", "renamed.txt"),
("n1", "new1.txt"),
("n2", "new2.txt"),
]:
selected_files_mocks["get_file_results"][fid] = (
_make_file_dict(fid, fname), None,
_make_file_dict(fid, fname),
None,
)
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'")
selected_files_mocks["skip_results"]["r1"] = (
True,
"File renamed: 'old' \u2192 'renamed.txt'",
)
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
indexed, skipped, errors = await _run_selected(
selected_files_mocks,
[("s1", "unchanged.txt"), ("r1", "renamed.txt"),
("n1", "new1.txt"), ("n2", "new2.txt")],
[
("s1", "unchanged.txt"),
("r1", "renamed.txt"),
("n1", "new1.txt"),
("n2", "new2.txt"),
],
)
assert indexed == 3 # 1 renamed + 2 batch
assert skipped == 1 # 1 unchanged
assert indexed == 3 # 1 renamed + 2 batch
assert skipped == 1 # 1 unchanged
assert errors == []
mock = selected_files_mocks["download_and_index_mock"]
mock.assert_called_once()
call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2]
call_files = (
mock.call_args[1].get("files")
if "files" in (mock.call_args[1] or {})
else mock.call_args[0][2]
)
assert len(call_files) == 2
assert {f["id"] for f in call_files} == {"n1", "n2"}
@ -593,6 +655,7 @@ async def test_selected_files_skip_rename_counting(selected_files_mocks):
# asyncio.to_thread verification — prove blocking calls run in parallel
# ---------------------------------------------------------------------------
async def test_client_download_file_runs_in_thread_parallel():
"""Calling download_file concurrently via asyncio.gather should overlap
blocking work on separate threads, proving to_thread is effective.
@ -602,11 +665,11 @@ async def test_client_download_file_runs_in_thread_parallel():
"""
from app.connectors.google_drive.client import GoogleDriveClient
BLOCK_SECONDS = 0.2
NUM_CALLS = 3
block_seconds = 0.2
num_calls = 3
def _blocking_download(service, file_id, credentials):
time.sleep(BLOCK_SECONDS)
time.sleep(block_seconds)
return b"fake-content", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
@ -615,11 +678,13 @@ async def test_client_download_file_runs_in_thread_parallel():
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient, "_sync_download_file", staticmethod(_blocking_download),
GoogleDriveClient,
"_sync_download_file",
staticmethod(_blocking_download),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.download_file(f"file-{i}") for i in range(NUM_CALLS))
*(client.download_file(f"file-{i}") for i in range(num_calls))
)
elapsed = time.monotonic() - start
@ -627,7 +692,7 @@ async def test_client_download_file_runs_in_thread_parallel():
assert content == b"fake-content"
assert error is None
serial_minimum = BLOCK_SECONDS * NUM_CALLS
serial_minimum = block_seconds * num_calls
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"downloads are not running in parallel"
@ -638,11 +703,11 @@ async def test_client_export_google_file_runs_in_thread_parallel():
"""Same strategy for export_google_file — verify to_thread parallelism."""
from app.connectors.google_drive.client import GoogleDriveClient
BLOCK_SECONDS = 0.2
NUM_CALLS = 3
block_seconds = 0.2
num_calls = 3
def _blocking_export(service, file_id, mime_type, credentials):
time.sleep(BLOCK_SECONDS)
time.sleep(block_seconds)
return b"exported", None
client = GoogleDriveClient.__new__(GoogleDriveClient)
@ -651,12 +716,16 @@ async def test_client_export_google_file_runs_in_thread_parallel():
client._service_lock = asyncio.Lock()
with patch.object(
GoogleDriveClient, "_sync_export_google_file", staticmethod(_blocking_export),
GoogleDriveClient,
"_sync_export_google_file",
staticmethod(_blocking_export),
):
start = time.monotonic()
results = await asyncio.gather(
*(client.export_google_file(f"file-{i}", "application/pdf")
for i in range(NUM_CALLS))
*(
client.export_google_file(f"file-{i}", "application/pdf")
for i in range(num_calls)
)
)
elapsed = time.monotonic() - start
@ -664,7 +733,7 @@ async def test_client_export_google_file_runs_in_thread_parallel():
assert content == b"exported"
assert error is None
serial_minimum = BLOCK_SECONDS * NUM_CALLS
serial_minimum = block_seconds * num_calls
assert elapsed < serial_minimum, (
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
f"exports are not running in parallel"

View file

@ -145,22 +145,32 @@ def jira_mocks(monkeypatch):
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
jira_client = _mock_jira_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod, "JiraHistoryConnector", MagicMock(return_value=jira_client),
_mod,
"JiraHistoryConnector",
MagicMock(return_value=jira_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
@ -169,15 +179,20 @@ def jira_mocks(monkeypatch):
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {

View file

@ -128,13 +128,17 @@ def _mock_linear_client(issues=None, error=None):
client.get_issues_by_date_range = AsyncMock(
return_value=(issues if issues is not None else [], error),
)
client.format_issue = MagicMock(side_effect=lambda i: _make_formatted_issue(
issue_id=i.get("id", ""),
identifier=i.get("identifier", ""),
title=i.get("title", ""),
))
client.format_issue = MagicMock(
side_effect=lambda i: _make_formatted_issue(
issue_id=i.get("id", ""),
identifier=i.get("identifier", ""),
title=i.get("title", ""),
)
)
client.format_issue_to_markdown = MagicMock(
side_effect=lambda fi: f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
side_effect=lambda fi: (
f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
)
)
return client
@ -147,24 +151,34 @@ def linear_mocks(monkeypatch):
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
linear_client = _mock_linear_client(issues=[_make_issue()])
monkeypatch.setattr(
_mod, "LinearConnector", MagicMock(return_value=linear_client),
_mod,
"LinearConnector",
MagicMock(return_value=linear_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
mock_task_logger = MagicMock()
@ -173,15 +187,20 @@ def linear_mocks(monkeypatch):
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
@ -255,7 +274,7 @@ async def test_issues_with_missing_id_are_skipped(linear_mocks):
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
indexed, skipped, _ = await _run_index(linear_mocks)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
@ -271,7 +290,7 @@ async def test_issues_with_missing_title_are_skipped(linear_mocks):
]
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
indexed, skipped, _ = await _run_index(linear_mocks)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1
@ -305,7 +324,7 @@ async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch):
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
indexed, skipped, _ = await _run_index(linear_mocks)
_indexed, skipped, _ = await _run_index(linear_mocks)
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
assert len(connector_docs) == 1

View file

@ -107,28 +107,40 @@ def notion_mocks(monkeypatch):
mock_connector = _mock_connector()
monkeypatch.setattr(
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
_mod,
"get_connector_by_id",
AsyncMock(return_value=mock_connector),
)
notion_client = _mock_notion_client(pages=[_make_page()])
monkeypatch.setattr(
_mod, "NotionHistoryConnector", MagicMock(return_value=notion_client),
_mod,
"NotionHistoryConnector",
MagicMock(return_value=notion_client),
)
monkeypatch.setattr(
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
_mod,
"check_duplicate_document_by_hash",
AsyncMock(return_value=None),
)
monkeypatch.setattr(
_mod, "update_connector_last_indexed", AsyncMock(),
_mod,
"update_connector_last_indexed",
AsyncMock(),
)
monkeypatch.setattr(
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
_mod,
"calculate_date_range",
MagicMock(return_value=("2025-01-01", "2025-12-31")),
)
monkeypatch.setattr(
_mod, "process_blocks", MagicMock(return_value="Converted markdown content"),
_mod,
"process_blocks",
MagicMock(return_value="Converted markdown content"),
)
mock_task_logger = MagicMock()
@ -137,15 +149,20 @@ def notion_mocks(monkeypatch):
mock_task_logger.log_task_success = AsyncMock()
mock_task_logger.log_task_failure = AsyncMock()
monkeypatch.setattr(
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
_mod,
"TaskLoggingService",
MagicMock(return_value=mock_task_logger),
)
batch_mock = AsyncMock(return_value=([], 1, 0))
pipeline_mock = MagicMock()
pipeline_mock.index_batch_parallel = batch_mock
pipeline_mock.migrate_legacy_docs = AsyncMock()
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
monkeypatch.setattr(
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
_mod,
"IndexingPipelineService",
MagicMock(return_value=pipeline_mock),
)
return {
@ -216,7 +233,10 @@ async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch):
"""Pages without page_id are skipped and not passed to the pipeline."""
pages = [
_make_page(page_id="valid-1"),
{"title": "No ID page", "content": [{"type": "paragraph", "content": "text", "children": []}]},
{
"title": "No ID page",
"content": [{"type": "paragraph", "content": "text", "children": []}],
},
]
notion_mocks["notion_client"].get_all_pages.return_value = pages

View file

@ -0,0 +1,131 @@
"""Unit tests for IndexingPipelineService.create_placeholder_documents."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from sqlalchemy.exc import IntegrityError
from app.db import DocumentStatus, DocumentType
from app.indexing_pipeline.document_hashing import compute_identifier_hash
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
PlaceholderInfo,
)
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_placeholder(**overrides) -> PlaceholderInfo:
defaults = {
"title": "Test Doc",
"document_type": DocumentType.GOOGLE_DRIVE_FILE,
"unique_id": "file-001",
"search_space_id": 1,
"connector_id": 42,
"created_by_id": "00000000-0000-0000-0000-000000000001",
}
defaults.update(overrides)
return PlaceholderInfo(**defaults)
def _uid_hash(p: PlaceholderInfo) -> str:
return compute_identifier_hash(
p.document_type.value, p.unique_id, p.search_space_id
)
def _session_with_existing_hashes(existing: set[str] | None = None):
"""Build an AsyncMock session whose batch-query returns *existing* hashes."""
session = AsyncMock()
result = MagicMock()
result.scalars.return_value.all.return_value = list(existing or [])
session.execute = AsyncMock(return_value=result)
session.add = MagicMock()
return session
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
async def test_empty_input_returns_zero_without_db_calls():
session = AsyncMock()
pipeline = IndexingPipelineService(session)
result = await pipeline.create_placeholder_documents([])
assert result == 0
session.execute.assert_not_awaited()
session.commit.assert_not_awaited()
async def test_creates_documents_with_pending_status_and_commits():
session = _session_with_existing_hashes(set())
pipeline = IndexingPipelineService(session)
p = _make_placeholder(title="My File", unique_id="file-abc")
result = await pipeline.create_placeholder_documents([p])
assert result == 1
session.add.assert_called_once()
doc = session.add.call_args[0][0]
assert doc.title == "My File"
assert doc.document_type == DocumentType.GOOGLE_DRIVE_FILE
assert doc.content == "Pending..."
assert DocumentStatus.is_state(doc.status, DocumentStatus.PENDING)
assert doc.search_space_id == 1
assert doc.connector_id == 42
session.commit.assert_awaited_once()
async def test_existing_documents_are_skipped():
"""Placeholders whose unique_identifier_hash already exists are not re-created."""
existing_p = _make_placeholder(unique_id="already-there")
new_p = _make_placeholder(unique_id="brand-new")
existing_hash = _uid_hash(existing_p)
session = _session_with_existing_hashes({existing_hash})
pipeline = IndexingPipelineService(session)
result = await pipeline.create_placeholder_documents([existing_p, new_p])
assert result == 1
doc = session.add.call_args[0][0]
assert doc.unique_identifier_hash == _uid_hash(new_p)
async def test_duplicate_unique_ids_within_input_are_deduped():
"""Same unique_id passed twice only produces one placeholder."""
p1 = _make_placeholder(unique_id="dup-id", title="First")
p2 = _make_placeholder(unique_id="dup-id", title="Second")
session = _session_with_existing_hashes(set())
pipeline = IndexingPipelineService(session)
result = await pipeline.create_placeholder_documents([p1, p2])
assert result == 1
session.add.assert_called_once()
async def test_integrity_error_on_commit_returns_zero():
"""IntegrityError during commit (race condition) is swallowed gracefully."""
session = _session_with_existing_hashes(set())
session.commit = AsyncMock(side_effect=IntegrityError("dup", {}, None))
pipeline = IndexingPipelineService(session)
p = _make_placeholder()
result = await pipeline.create_placeholder_documents([p])
assert result == 0
session.rollback.assert_awaited_once()

View file

@ -19,9 +19,7 @@ def pipeline(mock_session):
return IndexingPipelineService(mock_session)
async def test_calls_prepare_then_index_per_document(
pipeline, make_connector_document
):
async def test_calls_prepare_then_index_per_document(pipeline, make_connector_document):
"""index_batch calls prepare_for_indexing, then index() for each returned doc."""
doc1 = make_connector_document(
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,

View file

@ -1,5 +1,5 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -57,7 +57,9 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
mock_chunk,
)
mock_embed = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts])
mock_embed = MagicMock(
side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]
)
mock_embed.__name__ = "embed_texts"
monkeypatch.setattr(
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",

View file

@ -0,0 +1,110 @@
"""Unit tests for the duplicate-content safety logic in prepare_for_indexing.
Verifies that when an existing document's updated content matches another
document's content_hash, the system marks it as failed (for placeholders)
or leaves it untouched (for ready documents) never deletes.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.db import Document, DocumentStatus, DocumentType
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.document_hashing import (
compute_unique_identifier_hash,
)
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
pytestmark = pytest.mark.unit
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_connector_doc(**overrides) -> ConnectorDocument:
defaults = {
"title": "Test Doc",
"source_markdown": "## Some new content",
"unique_id": "file-001",
"document_type": DocumentType.GOOGLE_DRIVE_FILE,
"search_space_id": 1,
"connector_id": 42,
"created_by_id": "00000000-0000-0000-0000-000000000001",
}
defaults.update(overrides)
return ConnectorDocument(**defaults)
def _make_existing_doc(connector_doc: ConnectorDocument, *, status: dict) -> MagicMock:
"""Build a MagicMock that looks like an ORM Document with given status."""
doc = MagicMock(spec=Document)
doc.id = 999
doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
doc.content_hash = "old-placeholder-content-hash"
doc.title = connector_doc.title
doc.status = status
return doc
def _mock_session_for_dedup(existing_doc, *, has_duplicate: bool):
"""Build a session whose sequential execute() calls return:
1. The *existing_doc* for the unique_identifier_hash lookup.
2. A row (or None) for the duplicate content_hash check.
"""
session = AsyncMock()
existing_result = MagicMock()
existing_result.scalars.return_value.first.return_value = existing_doc
dup_result = MagicMock()
dup_result.scalars.return_value.first.return_value = 42 if has_duplicate else None
session.execute = AsyncMock(side_effect=[existing_result, dup_result])
session.add = MagicMock()
return session
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
async def test_pending_placeholder_with_duplicate_content_is_marked_failed():
"""A placeholder (pending) whose updated content duplicates another doc
must be marked as FAILED never deleted."""
cdoc = _make_connector_doc(source_markdown="## Shared content")
existing = _make_existing_doc(cdoc, status=DocumentStatus.pending())
session = _mock_session_for_dedup(existing, has_duplicate=True)
pipeline = IndexingPipelineService(session)
results = await pipeline.prepare_for_indexing([cdoc])
assert results == [], "duplicate should not be returned for indexing"
assert DocumentStatus.is_state(existing.status, DocumentStatus.FAILED)
assert "Duplicate content" in existing.status.get("reason", "")
session.delete.assert_not_called()
async def test_ready_document_with_duplicate_content_is_left_untouched():
"""A READY document whose updated content duplicates another doc
must be left completely untouched not failed, not deleted."""
cdoc = _make_connector_doc(source_markdown="## Shared content")
existing = _make_existing_doc(cdoc, status=DocumentStatus.ready())
session = _mock_session_for_dedup(existing, has_duplicate=True)
pipeline = IndexingPipelineService(session)
results = await pipeline.prepare_for_indexing([cdoc])
assert results == [], "duplicate should not be returned for indexing"
assert DocumentStatus.is_state(existing.status, DocumentStatus.READY)
session.delete.assert_not_called()

View file

@ -0,0 +1,133 @@
"""Unit tests for knowledge_search middleware helpers.
These test pure functions that don't require a database.
"""
import pytest
from app.agents.new_chat.middleware.knowledge_search import (
_build_document_xml,
_resolve_search_types,
)
pytestmark = pytest.mark.unit
# ── _resolve_search_types ──────────────────────────────────────────────
class TestResolveSearchTypes:
def test_returns_none_when_no_inputs(self):
assert _resolve_search_types(None, None) is None
def test_returns_none_when_both_empty(self):
assert _resolve_search_types([], []) is None
def test_includes_legacy_type_for_google_gmail(self):
result = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
assert "GOOGLE_GMAIL_CONNECTOR" in result
assert "COMPOSIO_GMAIL_CONNECTOR" in result
def test_includes_legacy_type_for_google_drive(self):
result = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
assert "GOOGLE_DRIVE_FILE" in result
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in result
def test_includes_legacy_type_for_google_calendar(self):
result = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
assert "GOOGLE_CALENDAR_CONNECTOR" in result
assert "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" in result
def test_no_legacy_expansion_for_unrelated_types(self):
result = _resolve_search_types(["FILE", "NOTE"], None)
assert set(result) == {"FILE", "NOTE"}
def test_combines_connectors_and_document_types(self):
result = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
assert {"FILE", "NOTE", "CRAWLED_URL"}.issubset(set(result))
def test_deduplicates(self):
result = _resolve_search_types(["FILE", "FILE"], ["FILE"])
assert result.count("FILE") == 1
# ── _build_document_xml ────────────────────────────────────────────────
class TestBuildDocumentXml:
@pytest.fixture
def sample_document(self):
return {
"document_id": 42,
"document": {
"id": 42,
"document_type": "FILE",
"title": "Test Doc",
"metadata": {"url": "https://example.com"},
},
"chunks": [
{"chunk_id": 101, "content": "First chunk content"},
{"chunk_id": 102, "content": "Second chunk content"},
{"chunk_id": 103, "content": "Third chunk content"},
],
}
def test_contains_document_metadata(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<document_id>42</document_id>" in xml
assert "<document_type>FILE</document_type>" in xml
assert "Test Doc" in xml
def test_contains_chunk_index(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<chunk_index>" in xml
assert "</chunk_index>" in xml
assert 'chunk_id="101"' in xml
assert 'chunk_id="102"' in xml
assert 'chunk_id="103"' in xml
def test_matched_chunks_flagged_in_index(self, sample_document):
xml = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
lines = xml.split("\n")
for line in lines:
if 'chunk_id="101"' in line:
assert 'matched="true"' in line
if 'chunk_id="102"' in line:
assert 'matched="true"' not in line
if 'chunk_id="103"' in line:
assert 'matched="true"' in line
def test_chunk_content_in_document_content_section(self, sample_document):
xml = _build_document_xml(sample_document)
assert "<document_content>" in xml
assert "First chunk content" in xml
assert "Second chunk content" in xml
assert "Third chunk content" in xml
def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
"""Verify that the line ranges in chunk_index actually point to the right content."""
xml = _build_document_xml(sample_document, matched_chunk_ids={101})
xml_lines = xml.split("\n")
for line in xml_lines:
if 'chunk_id="101"' in line and "lines=" in line:
import re
m = re.search(r'lines="(\d+)-(\d+)"', line)
assert m, f"No lines= attribute found in: {line}"
start, _end = int(m.group(1)), int(m.group(2))
target_line = xml_lines[start - 1]
assert "101" in target_line
assert "First chunk content" in target_line
break
else:
pytest.fail("chunk_id=101 entry not found in chunk_index")
def test_splits_into_lines_correctly(self, sample_document):
"""Each chunk occupies exactly one line (no embedded newlines)."""
xml = _build_document_xml(sample_document)
lines = xml.split("\n")
chunk_lines = [
line for line in lines if "<![CDATA[" in line and "<chunk" in line
]
assert len(chunk_lines) == 3

8549
surfsense_backend/uv.lock generated

File diff suppressed because it is too large Load diff

View file

@ -30,6 +30,7 @@ import {
// extractWriteTodosFromContent,
} from "@/atoms/chat/plan-state.atom";
import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom";
import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms";
import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
import { membersAtom } from "@/atoms/members/members-query.atoms";
import { updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
@ -191,6 +192,7 @@ export default function NewChatPage() {
const closeReportPanel = useSetAtom(closeReportPanelAtom);
const closeEditorPanel = useSetAtom(closeEditorPanelAtom);
const updateChatTabTitle = useSetAtom(updateChatTabTitleAtom);
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
// Get current user for author info in shared chats
const { data: currentUser } = useAtomValue(currentUserAtom);
@ -740,6 +742,20 @@ export default function NewChatPage() {
break;
}
case "data-documents-updated": {
const docEvent = parsed.data as {
action: string;
document: AgentCreatedDocument;
};
if (docEvent?.document?.id) {
setAgentCreatedDocuments((prev) => {
if (prev.some((d) => d.id === docEvent.document.id)) return prev;
return [...prev, docEvent.document];
});
}
break;
}
case "data-interrupt-request": {
wasInterrupted = true;
const interruptData = parsed.data as Record<string, unknown>;
@ -1534,7 +1550,7 @@ export default function NewChatPage() {
// For new chats (urlChatId === 0), threadId being null is expected (lazy creation)
if (!threadId && urlChatId > 0) {
return (
<div className="flex h-[calc(100dvh-64px)] flex-col items-center justify-center gap-4">
<div className="flex h-full flex-col items-center justify-center gap-4">
<div className="text-destructive">Failed to load chat</div>
<button
type="button"
@ -1553,7 +1569,7 @@ export default function NewChatPage() {
return (
<AssistantRuntimeProvider runtime={runtime}>
<ThinkingStepsDataUI />
<div key={searchSpaceId} className="flex h-[calc(100dvh-64px)] overflow-hidden">
<div key={searchSpaceId} className="flex h-full overflow-hidden">
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
<Thread />
</div>

View file

@ -2,7 +2,7 @@ import { Skeleton } from "@/components/ui/skeleton";
export default function Loading() {
return (
<div className="flex h-[calc(100dvh-64px)] flex-col bg-main-panel px-4">
<div className="flex h-full flex-col bg-main-panel px-4">
<div className="mx-auto w-full max-w-[44rem] flex flex-1 flex-col gap-6 py-8">
{/* User message */}
<div className="flex justify-end">

View file

@ -7,3 +7,14 @@ export const globalDocumentsQueryParamsAtom = atom<GetDocumentsRequest["queryPar
});
export const documentsSidebarOpenAtom = atom(false);
export interface AgentCreatedDocument {
id: number;
title: string;
documentType: string;
searchSpaceId: number;
folderId: number | null;
createdById: string | null;
}
export const agentCreatedDocumentsAtom = atom<AgentCreatedDocument[]>([]);

View file

@ -4,7 +4,6 @@ import { useAtomValue, useSetAtom } from "jotai";
import { AlertTriangle, Cable, Settings } from "lucide-react";
import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react";
import { createPortal } from "react-dom";
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
import { statusInboxItemsAtom } from "@/atoms/inbox/status-inbox.atom";
import {
globalNewLLMConfigsAtom,
@ -22,6 +21,7 @@ import { Tabs, TabsContent } from "@/components/ui/tabs";
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
import { useConnectorsSync } from "@/hooks/use-connectors-sync";
import { PICKER_CLOSE_EVENT, PICKER_OPEN_EVENT } from "@/hooks/use-google-picker";
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
import { cn } from "@/lib/utils";
import { ConnectorDialogHeader } from "./connector-popup/components/connector-dialog-header";
import { ConnectorConnectView } from "./connector-popup/connector-configs/views/connector-connect-view";

View file

@ -421,7 +421,9 @@ const defaultComponents = memoizeMarkdownComponents({
<code
className={cn("aui-md-inline-code rounded border bg-muted font-semibold", className)}
{...props}
/>
>
{children}
</code>
);
}
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";

View file

@ -109,7 +109,7 @@ const ThreadContent: FC = () => {
>
<ThreadPrimitive.Viewport
turnAnchor="top"
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-scroll px-4 pt-4"
>
<AuiIf condition={({ thread }) => thread.isEmpty}>
<ThreadWelcome />
@ -1062,7 +1062,7 @@ interface ToolGroup {
const TOOL_GROUPS: ToolGroup[] = [
{
label: "Research",
tools: ["search_knowledge_base", "search_surfsense_docs", "scrape_webpage"],
tools: ["search_surfsense_docs", "scrape_webpage"],
},
{
label: "Generate",

View file

@ -69,7 +69,9 @@ export function CreateFolderDialog({
<form onSubmit={handleSubmit} className="flex flex-col gap-3 sm:gap-4">
<div className="flex flex-col gap-2">
<Label htmlFor="folder-name" className="text-sm">Folder name</Label>
<Label htmlFor="folder-name" className="text-sm">
Folder name
</Label>
<Input
ref={inputRef}
id="folder-name"
@ -91,11 +93,7 @@ export function CreateFolderDialog({
>
Cancel
</Button>
<Button
type="submit"
disabled={!name.trim()}
className="h-8 sm:h-9 text-xs sm:text-sm"
>
<Button type="submit" disabled={!name.trim()} className="h-8 sm:h-9 text-xs sm:text-sm">
Create
</Button>
</DialogFooter>

View file

@ -1,6 +1,15 @@
"use client";
import { AlertCircle, Clock, Download, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react";
import {
AlertCircle,
Clock,
Download,
Eye,
MoreHorizontal,
Move,
PenLine,
Trash2,
} from "lucide-react";
import React, { useCallback, useRef, useState } from "react";
import { useDrag } from "react-dnd";
import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
@ -112,14 +121,15 @@ export const DocumentNode = React.memo(function DocumentNode({
return (
<ContextMenu onOpenChange={onContextMenuOpenChange}>
<ContextMenuTrigger asChild>
{/* biome-ignore lint/a11y/useSemanticElements: contains nested interactive children (Checkbox) that render as <button>, making a semantic <button> wrapper invalid */}
<div
role="button"
tabIndex={0}
ref={attachRef}
className={cn(
"group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left",
isMentioned && "bg-accent/30",
isDragging && "opacity-40"
"group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left",
isMentioned && "bg-accent/30",
isDragging && "opacity-40"
)}
style={{ paddingLeft: `${depth * 16 + 4}px` }}
onClick={handleCheckChange}
@ -130,54 +140,54 @@ export const DocumentNode = React.memo(function DocumentNode({
}
}}
>
{(() => {
if (statusState === "pending") {
{(() => {
if (statusState === "pending") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<Clock className="h-3.5 w-3.5 text-muted-foreground/60" />
</span>
</TooltipTrigger>
<TooltipContent side="top">Pending - waiting to be synced</TooltipContent>
</Tooltip>
);
}
if (statusState === "processing") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<Spinner size="xs" className="text-primary" />
</span>
</TooltipTrigger>
<TooltipContent side="top">Syncing</TooltipContent>
</Tooltip>
);
}
if (statusState === "failed") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<AlertCircle className="h-3.5 w-3.5 text-destructive" />
</span>
</TooltipTrigger>
<TooltipContent side="top" className="max-w-xs">
{doc.status?.reason || "Processing failed"}
</TooltipContent>
</Tooltip>
);
}
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<Clock className="h-3.5 w-3.5 text-muted-foreground/60" />
</span>
</TooltipTrigger>
<TooltipContent side="top">Pending - waiting to be synced</TooltipContent>
</Tooltip>
<Checkbox
checked={isMentioned}
onCheckedChange={handleCheckChange}
onClick={(e) => e.stopPropagation()}
className="h-3.5 w-3.5 shrink-0"
/>
);
}
if (statusState === "processing") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<Spinner size="xs" className="text-primary" />
</span>
</TooltipTrigger>
<TooltipContent side="top">Syncing</TooltipContent>
</Tooltip>
);
}
if (statusState === "failed") {
return (
<Tooltip>
<TooltipTrigger asChild>
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
<AlertCircle className="h-3.5 w-3.5 text-destructive" />
</span>
</TooltipTrigger>
<TooltipContent side="top" className="max-w-xs">
{doc.status?.reason || "Processing failed"}
</TooltipContent>
</Tooltip>
);
}
return (
<Checkbox
checked={isMentioned}
onCheckedChange={handleCheckChange}
onClick={(e) => e.stopPropagation()}
className="h-3.5 w-3.5 shrink-0"
/>
);
})()}
})()}
<span className="flex-1 min-w-0 truncate">{doc.title}</span>
@ -188,17 +198,19 @@ export const DocumentNode = React.memo(function DocumentNode({
)}
</span>
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
<DropdownMenuTrigger asChild>
<Button
variant="ghost"
size="icon"
className={cn(
"hidden sm:inline-flex h-6 w-6 shrink-0 hover:bg-transparent",
dropdownOpen ? "opacity-100 bg-accent hover:bg-accent" : "opacity-0 group-hover:opacity-100"
)}
onClick={(e) => e.stopPropagation()}
>
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
<DropdownMenuTrigger asChild>
<Button
variant="ghost"
size="icon"
className={cn(
"hidden sm:inline-flex h-6 w-6 shrink-0 hover:bg-transparent",
dropdownOpen
? "opacity-100 bg-accent hover:bg-accent"
: "opacity-0 group-hover:opacity-100"
)}
onClick={(e) => e.stopPropagation()}
>
<MoreHorizontal className="h-3.5 w-3.5" />
</Button>
</DropdownMenuTrigger>

View file

@ -15,7 +15,6 @@ import React, { useCallback, useEffect, useRef, useState } from "react";
import { useDrag, useDrop } from "react-dnd";
import { Button } from "@/components/ui/button";
import { Checkbox } from "@/components/ui/checkbox";
import type { FolderSelectionState } from "./FolderTreeView";
import {
ContextMenu,
ContextMenuContent,
@ -29,6 +28,7 @@ import {
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { cn } from "@/lib/utils";
import type { FolderSelectionState } from "./FolderTreeView";
export const DND_TYPES = {
FOLDER: "FOLDER",
@ -263,7 +263,9 @@ export const FolderNode = React.memo(function FolderNode({
</span>
<Checkbox
checked={selectionState === "all" ? true : selectionState === "some" ? "indeterminate" : false}
checked={
selectionState === "all" ? true : selectionState === "some" ? "indeterminate" : false
}
onCheckedChange={handleCheckChange}
onClick={(e) => e.stopPropagation()}
className="h-3.5 w-3.5 shrink-0"

View file

@ -33,6 +33,7 @@ interface FolderTreeViewProps {
onMoveDocument: (doc: DocumentNodeDoc) => void;
onExportDocument?: (doc: DocumentNodeDoc, format: string) => void;
activeTypes: DocumentTypeEnum[];
searchQuery?: string;
onDropIntoFolder?: (
itemType: "folder" | "document",
itemId: number,
@ -69,6 +70,7 @@ export function FolderTreeView({
onMoveDocument,
onExportDocument,
activeTypes,
searchQuery,
onDropIntoFolder,
onReorderFolder,
}: FolderTreeViewProps) {
@ -97,13 +99,13 @@ export function FolderTreeView({
const handleCancelRename = useCallback(() => setRenamingFolderId(null), [setRenamingFolderId]);
const hasDescendantMatch = useMemo(() => {
if (activeTypes.length === 0) return null;
if (activeTypes.length === 0 && !searchQuery) return null;
const match: Record<number, boolean> = {};
function check(folderId: number): boolean {
if (match[folderId] !== undefined) return match[folderId];
const childDocs = (docsByFolder[folderId] ?? []).some((d) =>
activeTypes.includes(d.document_type as DocumentTypeEnum)
const childDocs = (docsByFolder[folderId] ?? []).some(
(d) => activeTypes.length === 0 || activeTypes.includes(d.document_type as DocumentTypeEnum)
);
if (childDocs) {
match[folderId] = true;
@ -124,7 +126,7 @@ export function FolderTreeView({
check(f.id);
}
return match;
}, [folders, docsByFolder, foldersByParent, activeTypes]);
}, [folders, docsByFolder, foldersByParent, activeTypes, searchQuery]);
const folderSelectionStates = useMemo(() => {
const states: Record<number, FolderSelectionState> = {};
@ -177,12 +179,15 @@ export function FolderTreeView({
after: i < visibleFolders.length - 1 ? visibleFolders[i + 1].position : null,
};
const isAutoExpanded = !!searchQuery && !!hasDescendantMatch?.[f.id];
const isExpanded = expandedIds.has(f.id) || isAutoExpanded;
nodes.push(
<FolderNode
key={`folder-${f.id}`}
folder={f}
depth={depth}
isExpanded={expandedIds.has(f.id)}
isExpanded={isExpanded}
isRenaming={renamingFolderId === f.id}
childCount={folderChildCounts[f.id] ?? 0}
selectionState={folderSelectionStates[f.id] ?? "none"}
@ -202,7 +207,7 @@ export function FolderTreeView({
/>
);
if (expandedIds.has(f.id)) {
if (isExpanded) {
nodes.push(...renderLevel(f.id, depth + 1));
}
}
@ -240,7 +245,7 @@ export function FolderTreeView({
);
}
if (treeNodes.length === 0 && activeTypes.length > 0) {
if (treeNodes.length === 0 && (activeTypes.length > 0 || searchQuery)) {
return (
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
<CirclePlus className="h-10 w-10 rotate-45" />

View file

@ -2,42 +2,50 @@
import { useQuery } from "@rocicorp/zero/react";
import { useAtom, useAtomValue, useSetAtom } from "jotai";
import { ChevronLeft, ChevronRight, Unplug } from "lucide-react";
import { ChevronLeft, ChevronRight, Trash2, Unplug } from "lucide-react";
import { useParams } from "next/navigation";
import { useTranslations } from "next-intl";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { toast } from "sonner";
import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
import { DocumentsFilters } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsFilters";
import {
DocumentsTableShell,
type SortKey,
} from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell";
import { sidebarSelectedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { deleteDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms";
import { expandedFolderIdsAtom } from "@/atoms/documents/folder.atoms";
import { agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms";
import { openDocumentTabAtom } from "@/atoms/tabs/tabs.atom";
import { CreateFolderDialog } from "@/components/documents/CreateFolderDialog";
import type { DocumentNodeDoc } from "@/components/documents/DocumentNode";
import type { FolderDisplay } from "@/components/documents/FolderNode";
import { FolderPickerDialog } from "@/components/documents/FolderPickerDialog";
import { FolderTreeView } from "@/components/documents/FolderTreeView";
import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
} from "@/components/ui/alert-dialog";
import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar";
import { Button } from "@/components/ui/button";
import { Spinner } from "@/components/ui/spinner";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
import { useDebouncedValue } from "@/hooks/use-debounced-value";
import { useDocumentSearch } from "@/hooks/use-document-search";
import { useDocuments } from "@/hooks/use-documents";
import { useMediaQuery } from "@/hooks/use-media-query";
import { foldersApiService } from "@/lib/apis/folders-api.service";
import { authenticatedFetch } from "@/lib/auth-utils";
import { queries } from "@/zero/queries/index";
import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel";
const NON_DELETABLE_DOCUMENT_TYPES: readonly string[] = ["SURFSENSE_DOCS"];
const SHOWCASE_CONNECTORS = [
{ type: "GOOGLE_DRIVE_CONNECTOR", label: "Google Drive" },
{ type: "GOOGLE_GMAIL_CONNECTOR", label: "Gmail" },
@ -82,8 +90,6 @@ export function DocumentsSidebar({
const [search, setSearch] = useState("");
const debouncedSearch = useDebouncedValue(search, 250);
const [activeTypes, setActiveTypes] = useState<DocumentTypeEnum[]>([]);
const [sortKey, setSortKey] = useState<SortKey>("created_at");
const [sortDesc, setSortDesc] = useState(true);
const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom);
const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom);
@ -110,6 +116,7 @@ export function DocumentsSidebar({
// Zero queries for tree data
const [zeroFolders] = useQuery(queries.folders.bySpace({ searchSpaceId }));
const [zeroAllDocs] = useQuery(queries.documents.bySpace({ searchSpaceId }));
const [agentCreatedDocs, setAgentCreatedDocs] = useAtom(agentCreatedDocumentsAtom);
const treeFolders: FolderDisplay[] = useMemo(
() =>
@ -123,19 +130,41 @@ export function DocumentsSidebar({
[zeroFolders]
);
const treeDocuments: DocumentNodeDoc[] = useMemo(
() =>
(zeroAllDocs ?? [])
.filter((d) => d.title && d.title.trim() !== "")
.map((d) => ({
id: d.id,
title: d.title,
document_type: d.documentType,
folderId: (d as { folderId?: number | null }).folderId ?? null,
status: d.status as { state: string; reason?: string | null } | undefined,
})),
[zeroAllDocs]
);
const treeDocuments: DocumentNodeDoc[] = useMemo(() => {
const zeroDocs = (zeroAllDocs ?? [])
.filter((d) => d.title && d.title.trim() !== "")
.map((d) => ({
id: d.id,
title: d.title,
document_type: d.documentType,
folderId: (d as { folderId?: number | null }).folderId ?? null,
status: d.status as { state: string; reason?: string | null } | undefined,
}));
const zeroIds = new Set(zeroDocs.map((d) => d.id));
const pendingAgentDocs = agentCreatedDocs
.filter((d) => d.searchSpaceId === searchSpaceId && !zeroIds.has(d.id))
.map((d) => ({
id: d.id,
title: d.title,
document_type: d.documentType,
folderId: d.folderId ?? null,
status: { state: "ready" } as { state: string; reason?: string | null },
}));
return [...pendingAgentDocs, ...zeroDocs];
}, [zeroAllDocs, agentCreatedDocs, searchSpaceId]);
// Prune agent-created docs once Zero has caught up
useEffect(() => {
if (!zeroAllDocs?.length || !agentCreatedDocs.length) return;
const zeroIds = new Set(zeroAllDocs.map((d) => d.id));
const remaining = agentCreatedDocs.filter((d) => !zeroIds.has(d.id));
if (remaining.length < agentCreatedDocs.length) {
setAgentCreatedDocs(remaining);
}
}, [zeroAllDocs, agentCreatedDocs, setAgentCreatedDocs]);
const foldersByParent = useMemo(() => {
const map: Record<string, FolderDisplay[]> = {};
@ -355,7 +384,7 @@ export function DocumentsSidebar({
(d) =>
d.folderId === parentId &&
d.status?.state !== "pending" &&
d.status?.state !== "processing",
d.status?.state !== "processing"
);
const childFolders = foldersByParent[String(parentId)] ?? [];
const descendantDocs = childFolders.flatMap((cf) => collectSubtreeDocs(cf.id));
@ -382,38 +411,72 @@ export function DocumentsSidebar({
setSidebarDocs((prev) => prev.filter((d) => !idsToRemove.has(d.id)));
}
},
[treeDocuments, foldersByParent, setSidebarDocs],
[treeDocuments, foldersByParent, setSidebarDocs]
);
const isSearchMode = !!debouncedSearch.trim();
const searchFilteredDocuments = useMemo(() => {
const query = debouncedSearch.trim().toLowerCase();
if (!query) return treeDocuments;
return treeDocuments.filter((d) => d.title.toLowerCase().includes(query));
}, [treeDocuments, debouncedSearch]);
const {
documents: realtimeDocuments,
typeCounts: realtimeTypeCounts,
loading: realtimeLoading,
loadingMore: realtimeLoadingMore,
hasMore: realtimeHasMore,
loadMore: realtimeLoadMore,
removeItems: realtimeRemoveItems,
error: realtimeError,
} = useDocuments(searchSpaceId, activeTypes, sortKey, sortDesc ? "desc" : "asc");
const typeCounts = useMemo(() => {
const counts: Partial<Record<string, number>> = {};
for (const d of treeDocuments) {
counts[d.document_type] = (counts[d.document_type] || 0) + 1;
}
return counts;
}, [treeDocuments]);
const {
documents: searchDocuments,
loading: searchLoading,
loadingMore: searchLoadingMore,
hasMore: searchHasMore,
loadMore: searchLoadMore,
error: searchError,
removeItems: searchRemoveItems,
} = useDocumentSearch(searchSpaceId, debouncedSearch, activeTypes, isSearchMode && open);
const deletableSelectedIds = useMemo(() => {
const treeDocMap = new Map(treeDocuments.map((d) => [d.id, d]));
return sidebarDocs
.filter((doc) => {
const fullDoc = treeDocMap.get(doc.id);
if (!fullDoc) return false;
const state = fullDoc.status?.state ?? "ready";
return (
state !== "pending" &&
state !== "processing" &&
!NON_DELETABLE_DOCUMENT_TYPES.includes(doc.document_type)
);
})
.map((doc) => doc.id);
}, [sidebarDocs, treeDocuments]);
const displayDocs = isSearchMode ? searchDocuments : realtimeDocuments;
const loading = isSearchMode ? searchLoading : realtimeLoading;
const error = isSearchMode ? searchError : !!realtimeError;
const hasMore = isSearchMode ? searchHasMore : realtimeHasMore;
const loadingMore = isSearchMode ? searchLoadingMore : realtimeLoadingMore;
const onLoadMore = isSearchMode ? searchLoadMore : realtimeLoadMore;
const [bulkDeleteConfirmOpen, setBulkDeleteConfirmOpen] = useState(false);
const [isBulkDeleting, setIsBulkDeleting] = useState(false);
const handleBulkDeleteSelected = useCallback(async () => {
if (deletableSelectedIds.length === 0) return;
setIsBulkDeleting(true);
try {
const results = await Promise.allSettled(
deletableSelectedIds.map(async (id) => {
await deleteDocumentMutation({ id });
return id;
})
);
const successIds = results
.filter((r): r is PromiseFulfilledResult<number> => r.status === "fulfilled")
.map((r) => r.value);
const failed = results.length - successIds.length;
if (successIds.length > 0) {
setSidebarDocs((prev) => {
const idSet = new Set(successIds);
return prev.filter((d) => !idSet.has(d.id));
});
toast.success(`Deleted ${successIds.length} document${successIds.length !== 1 ? "s" : ""}`);
}
if (failed > 0) {
toast.error(`Failed to delete ${failed} document${failed !== 1 ? "s" : ""}`);
}
} catch {
toast.error("Failed to delete documents");
}
setIsBulkDeleting(false);
setBulkDeleteConfirmOpen(false);
}, [deletableSelectedIds, deleteDocumentMutation, setSidebarDocs]);
const onToggleType = useCallback((type: DocumentTypeEnum, checked: boolean) => {
setActiveTypes((prev) => {
@ -430,69 +493,15 @@ export function DocumentsSidebar({
await deleteDocumentMutation({ id });
toast.success(t("delete_success") || "Document deleted");
setSidebarDocs((prev) => prev.filter((d) => d.id !== id));
realtimeRemoveItems([id]);
if (isSearchMode) {
searchRemoveItems([id]);
}
return true;
} catch (e) {
console.error("Error deleting document:", e);
return false;
}
},
[
deleteDocumentMutation,
isSearchMode,
t,
searchRemoveItems,
realtimeRemoveItems,
setSidebarDocs,
]
[deleteDocumentMutation, t, setSidebarDocs]
);
const handleBulkDeleteDocuments = useCallback(
async (ids: number[]): Promise<{ success: number; failed: number }> => {
const successIds: number[] = [];
const results = await Promise.allSettled(
ids.map(async (id) => {
await deleteDocumentMutation({ id });
successIds.push(id);
})
);
if (successIds.length > 0) {
setSidebarDocs((prev) => prev.filter((d) => !successIds.includes(d.id)));
realtimeRemoveItems(successIds);
if (isSearchMode) {
searchRemoveItems(successIds);
}
}
const success = results.filter((r) => r.status === "fulfilled").length;
const failed = results.filter((r) => r.status === "rejected").length;
return { success, failed };
},
[deleteDocumentMutation, isSearchMode, searchRemoveItems, realtimeRemoveItems, setSidebarDocs]
);
const sortKeyRef = useRef(sortKey);
const sortDescRef = useRef(sortDesc);
sortKeyRef.current = sortKey;
sortDescRef.current = sortDesc;
const handleSortChange = useCallback((key: SortKey) => {
const currentKey = sortKeyRef.current;
const currentDesc = sortDescRef.current;
if (currentKey === key && currentDesc) {
setSortKey("created_at");
setSortDesc(true);
} else if (currentKey === key) {
setSortDesc(true);
} else {
setSortKey(key);
setSortDesc(false);
}
}, []);
useEffect(() => {
const handleEscape = (e: KeyboardEvent) => {
if (e.key === "Escape" && open) {
@ -627,7 +636,7 @@ export function DocumentsSidebar({
<div className="flex-1 min-h-0 overflow-x-hidden pt-0 flex flex-col">
<div className="px-4 pb-2">
<DocumentsFilters
typeCounts={realtimeTypeCounts}
typeCounts={typeCounts}
onSearch={setSearch}
searchValue={search}
onToggleType={onToggleType}
@ -636,59 +645,54 @@ export function DocumentsSidebar({
/>
</div>
{isSearchMode ? (
<DocumentsTableShell
documents={displayDocs}
loading={!!loading}
error={!!error}
sortKey={sortKey}
sortDesc={sortDesc}
onSortChange={handleSortChange}
deleteDocument={handleDeleteDocument}
bulkDeleteDocuments={handleBulkDeleteDocuments}
searchSpaceId={String(searchSpaceId)}
hasMore={hasMore}
loadingMore={loadingMore}
onLoadMore={onLoadMore}
mentionedDocIds={mentionedDocIds}
onToggleChatMention={handleToggleChatMention}
isSearchMode={isSearchMode || activeTypes.length > 0}
/>
) : (
<FolderTreeView
folders={treeFolders}
documents={treeDocuments}
expandedIds={expandedIds}
onToggleExpand={toggleFolderExpand}
mentionedDocIds={mentionedDocIds}
onToggleChatMention={handleToggleChatMention}
onToggleFolderSelect={handleToggleFolderSelect}
onRenameFolder={handleRenameFolder}
onDeleteFolder={handleDeleteFolder}
onMoveFolder={handleMoveFolder}
onCreateFolder={handleCreateFolder}
onPreviewDocument={(doc) => {
openDocumentTab({
documentId: doc.id,
searchSpaceId,
title: doc.title,
});
}}
onEditDocument={(doc) => {
openDocumentTab({
documentId: doc.id,
searchSpaceId,
title: doc.title,
});
}}
onDeleteDocument={(doc) => handleDeleteDocument(doc.id)}
onMoveDocument={handleMoveDocument}
onExportDocument={handleExportDocument}
activeTypes={activeTypes}
onDropIntoFolder={handleDropIntoFolder}
onReorderFolder={handleReorderFolder}
/>
{deletableSelectedIds.length > 0 && (
<div className="shrink-0 flex items-center justify-center px-4 py-1.5 animate-in fade-in duration-150">
<button
type="button"
onClick={() => setBulkDeleteConfirmOpen(true)}
className="flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-sm text-xs font-medium hover:bg-destructive/90 transition-colors"
>
<Trash2 size={12} />
Delete {deletableSelectedIds.length}{" "}
{deletableSelectedIds.length === 1 ? "item" : "items"}
</button>
</div>
)}
<FolderTreeView
folders={treeFolders}
documents={searchFilteredDocuments}
expandedIds={expandedIds}
onToggleExpand={toggleFolderExpand}
mentionedDocIds={mentionedDocIds}
onToggleChatMention={handleToggleChatMention}
onToggleFolderSelect={handleToggleFolderSelect}
onRenameFolder={handleRenameFolder}
onDeleteFolder={handleDeleteFolder}
onMoveFolder={handleMoveFolder}
onCreateFolder={handleCreateFolder}
searchQuery={debouncedSearch.trim() || undefined}
onPreviewDocument={(doc) => {
openDocumentTab({
documentId: doc.id,
searchSpaceId,
title: doc.title,
});
}}
onEditDocument={(doc) => {
openDocumentTab({
documentId: doc.id,
searchSpaceId,
title: doc.title,
});
}}
onDeleteDocument={(doc) => handleDeleteDocument(doc.id)}
onMoveDocument={handleMoveDocument}
onExportDocument={handleExportDocument}
activeTypes={activeTypes}
onDropIntoFolder={handleDropIntoFolder}
onReorderFolder={handleReorderFolder}
/>
</div>
<FolderPickerDialog
@ -707,6 +711,40 @@ export function DocumentsSidebar({
parentFolderName={createFolderParentName}
onConfirm={handleCreateFolderConfirm}
/>
<AlertDialog
open={bulkDeleteConfirmOpen}
onOpenChange={(open) => !open && !isBulkDeleting && setBulkDeleteConfirmOpen(false)}
>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle>
Delete {deletableSelectedIds.length} document
{deletableSelectedIds.length !== 1 ? "s" : ""}?
</AlertDialogTitle>
<AlertDialogDescription>
This action cannot be undone.{" "}
{deletableSelectedIds.length === 1
? "This document"
: `These ${deletableSelectedIds.length} documents`}{" "}
will be permanently deleted from your search space.
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel disabled={isBulkDeleting}>Cancel</AlertDialogCancel>
<AlertDialogAction
onClick={(e) => {
e.preventDefault();
handleBulkDeleteSelected();
}}
disabled={isBulkDeleting}
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
>
{isBulkDeleting ? <Spinner size="sm" /> : "Delete"}
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</>
);

View file

@ -7,10 +7,10 @@ import { useTheme } from "next-themes";
import { useCallback, useEffect, useRef, useState } from "react";
import { createPortal } from "react-dom";
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { useIsMobile } from "@/hooks/use-mobile";
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
import { fetchThreads } from "@/lib/chat/thread-persistence";
interface TourStep {

View file

@ -9,6 +9,7 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
import { closeReportPanelAtom, reportPanelAtom } from "@/atoms/chat/report-panel.atom";
import { PlateEditor } from "@/components/editor/plate-editor";
import { MarkdownViewer } from "@/components/markdown-viewer";
import { EXPORT_FILE_EXTENSIONS, ExportDropdownItems } from "@/components/shared/ExportMenuItems";
import { Button } from "@/components/ui/button";
import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer";
import {
@ -17,7 +18,6 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { ExportDropdownItems, EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
import { useMediaQuery } from "@/hooks/use-media-query";
import { baseApiService } from "@/lib/apis/base-api.service";
import { authenticatedFetch } from "@/lib/auth-utils";

View file

@ -1,8 +1,12 @@
"use client";
import { Loader2 } from "lucide-react";
import { DropdownMenuItem, DropdownMenuLabel, DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
import { ContextMenuItem } from "@/components/ui/context-menu";
import {
DropdownMenuItem,
DropdownMenuLabel,
DropdownMenuSeparator,
} from "@/components/ui/dropdown-menu";
export const EXPORT_FILE_EXTENSIONS: Record<string, string> = {
pdf: "pdf",
@ -36,9 +40,7 @@ export function ExportDropdownItems({
<>
{showAllFormats && (
<>
<DropdownMenuLabel className="text-xs text-muted-foreground">
Documents
</DropdownMenuLabel>
<DropdownMenuLabel className="text-xs text-muted-foreground">Documents</DropdownMenuLabel>
<DropdownMenuItem onClick={handle("pdf")} disabled={exporting !== null}>
{exporting === "pdf" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
PDF (.pdf)

View file

@ -287,13 +287,9 @@ function ApprovalCard({
? pendingEdits.end_datetime
: null,
new_location:
pendingEdits.location !== (event?.location ?? "")
? pendingEdits.location || null
: null,
pendingEdits.location !== (event?.location ?? "") ? pendingEdits.location || null : null,
new_attendees:
attendeesArr && attendeesArr.join(",") !== origAttendees.join(",")
? attendeesArr
: null,
attendeesArr && attendeesArr.join(",") !== origAttendees.join(",") ? attendeesArr : null,
};
}
return {

View file

@ -1,7 +1,6 @@
import {
BookOpen,
Brain,
Database,
FileText,
Film,
Globe,
@ -13,7 +12,6 @@ import {
} from "lucide-react";
const TOOL_ICONS: Record<string, LucideIcon> = {
search_knowledge_base: Database,
generate_podcast: Podcast,
generate_video_presentation: Film,
generate_report: FileText,

View file

@ -13,9 +13,7 @@ export function useZeroDocumentTypeCounts(
): Record<string, number> | undefined {
const numericId = searchSpaceId != null ? Number(searchSpaceId) : null;
const [zeroDocuments] = useQuery(
queries.documents.bySpace({ searchSpaceId: numericId ?? -1 })
);
const [zeroDocuments] = useQuery(queries.documents.bySpace({ searchSpaceId: numericId ?? -1 }));
return useMemo(() => {
if (!zeroDocuments || numericId == null) return undefined;