mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: made agent file sytem optimized
This commit is contained in:
parent
ee0b59c0fa
commit
2cc2d339e6
67 changed files with 8011 additions and 5591 deletions
|
|
@ -1 +0,0 @@
|
||||||
Subproject commit a32ce7ff6b2112cf48170d2279a1953eded61987
|
|
||||||
|
|
@ -169,13 +169,3 @@ LANGSMITH_TRACING=true
|
||||||
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
LANGSMITH_ENDPOINT=https://api.smith.langchain.com
|
||||||
LANGSMITH_API_KEY=lsv2_pt_.....
|
LANGSMITH_API_KEY=lsv2_pt_.....
|
||||||
LANGSMITH_PROJECT=surfsense
|
LANGSMITH_PROJECT=surfsense
|
||||||
|
|
||||||
# Agent Specific Configuration
|
|
||||||
# Daytona Sandbox (secure cloud code execution for deep agent)
|
|
||||||
# Set DAYTONA_SANDBOX_ENABLED=TRUE to give the agent an isolated execute tool
|
|
||||||
DAYTONA_SANDBOX_ENABLED=TRUE
|
|
||||||
DAYTONA_API_KEY=dtn_asdasfasfafas
|
|
||||||
DAYTONA_API_URL=https://app.daytona.io/api
|
|
||||||
DAYTONA_TARGET=us
|
|
||||||
# Directory for locally-persisted sandbox files (after sandbox deletion)
|
|
||||||
SANDBOX_FILES_DIR=sandbox_files
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
"""
|
"""
|
||||||
SurfSense New Chat Agent Module.
|
SurfSense New Chat Agent Module.
|
||||||
|
|
||||||
This module provides the SurfSense deep agent with configurable tools
|
This module provides the SurfSense deep agent with configurable tools,
|
||||||
for knowledge base search, podcast generation, and more.
|
middleware, and preloaded knowledge-base filesystem behavior.
|
||||||
|
|
||||||
Directory Structure:
|
Directory Structure:
|
||||||
- tools/: All agent tools (knowledge_base, podcast, generate_image, etc.)
|
- tools/: All agent tools (podcast, generate_image, web, memory, etc.)
|
||||||
|
- middleware/: Custom middleware (knowledge search, filesystem, dedup, etc.)
|
||||||
- chat_deepagent.py: Main agent factory
|
- chat_deepagent.py: Main agent factory
|
||||||
- system_prompt.py: System prompts and instructions
|
- system_prompt.py: System prompts and instructions
|
||||||
- context.py: Context schema for the agent
|
- context.py: Context schema for the agent
|
||||||
|
|
@ -23,6 +24,13 @@ from .context import SurfSenseContextSchema
|
||||||
# LLM config
|
# LLM config
|
||||||
from .llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
|
from .llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
|
||||||
|
|
||||||
|
# Middleware
|
||||||
|
from .middleware import (
|
||||||
|
DedupHITLToolCallsMiddleware,
|
||||||
|
KnowledgeBaseSearchMiddleware,
|
||||||
|
SurfSenseFilesystemMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
# System prompt
|
# System prompt
|
||||||
from .system_prompt import (
|
from .system_prompt import (
|
||||||
SURFSENSE_CITATION_INSTRUCTIONS,
|
SURFSENSE_CITATION_INSTRUCTIONS,
|
||||||
|
|
@ -39,7 +47,6 @@ from .tools import (
|
||||||
build_tools,
|
build_tools,
|
||||||
create_generate_podcast_tool,
|
create_generate_podcast_tool,
|
||||||
create_scrape_webpage_tool,
|
create_scrape_webpage_tool,
|
||||||
create_search_knowledge_base_tool,
|
|
||||||
format_documents_for_context,
|
format_documents_for_context,
|
||||||
get_all_tool_names,
|
get_all_tool_names,
|
||||||
get_default_enabled_tools,
|
get_default_enabled_tools,
|
||||||
|
|
@ -53,8 +60,12 @@ __all__ = [
|
||||||
# System prompt
|
# System prompt
|
||||||
"SURFSENSE_CITATION_INSTRUCTIONS",
|
"SURFSENSE_CITATION_INSTRUCTIONS",
|
||||||
"SURFSENSE_SYSTEM_PROMPT",
|
"SURFSENSE_SYSTEM_PROMPT",
|
||||||
|
# Middleware
|
||||||
|
"DedupHITLToolCallsMiddleware",
|
||||||
|
"KnowledgeBaseSearchMiddleware",
|
||||||
# Context
|
# Context
|
||||||
"SurfSenseContextSchema",
|
"SurfSenseContextSchema",
|
||||||
|
"SurfSenseFilesystemMiddleware",
|
||||||
"ToolDefinition",
|
"ToolDefinition",
|
||||||
"build_surfsense_system_prompt",
|
"build_surfsense_system_prompt",
|
||||||
"build_tools",
|
"build_tools",
|
||||||
|
|
@ -63,7 +74,6 @@ __all__ = [
|
||||||
# Tool factories
|
# Tool factories
|
||||||
"create_generate_podcast_tool",
|
"create_generate_podcast_tool",
|
||||||
"create_scrape_webpage_tool",
|
"create_scrape_webpage_tool",
|
||||||
"create_search_knowledge_base_tool",
|
|
||||||
# Agent factory
|
# Agent factory
|
||||||
"create_surfsense_deep_agent",
|
"create_surfsense_deep_agent",
|
||||||
# Knowledge base utilities
|
# Knowledge base utilities
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,13 @@ SurfSense deep agent implementation.
|
||||||
This module provides the factory function for creating SurfSense deep agents
|
This module provides the factory function for creating SurfSense deep agents
|
||||||
with configurable tools via the tools registry and configurable prompts
|
with configurable tools via the tools registry and configurable prompts
|
||||||
via NewLLMConfig.
|
via NewLLMConfig.
|
||||||
|
|
||||||
|
We use ``create_agent`` (from langchain) rather than ``create_deep_agent``
|
||||||
|
(from deepagents) so that the middleware stack is fully under our control.
|
||||||
|
This lets us swap in ``SurfSenseFilesystemMiddleware`` — a customisable
|
||||||
|
subclass of the default ``FilesystemMiddleware`` — while preserving every
|
||||||
|
other behaviour that ``create_deep_agent`` provides (todo-list, subagents,
|
||||||
|
summarisation, prompt-caching, etc.).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
@ -12,8 +19,15 @@ import time
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from deepagents import create_deep_agent
|
from deepagents import SubAgent, SubAgentMiddleware, __version__ as deepagents_version
|
||||||
from deepagents.backends.protocol import SandboxBackendProtocol
|
from deepagents.backends import StateBackend
|
||||||
|
from deepagents.graph import BASE_AGENT_PROMPT
|
||||||
|
from deepagents.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||||
|
from deepagents.middleware.subagents import GENERAL_PURPOSE_SUBAGENT
|
||||||
|
from deepagents.middleware.summarization import create_summarization_middleware
|
||||||
|
from langchain.agents import create_agent
|
||||||
|
from langchain.agents.middleware import TodoListMiddleware
|
||||||
|
from langchain_anthropic.middleware import AnthropicPromptCachingMiddleware
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
@ -21,8 +35,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.new_chat.context import SurfSenseContextSchema
|
from app.agents.new_chat.context import SurfSenseContextSchema
|
||||||
from app.agents.new_chat.llm_config import AgentConfig
|
from app.agents.new_chat.llm_config import AgentConfig
|
||||||
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
from app.agents.new_chat.middleware import (
|
||||||
DedupHITLToolCallsMiddleware,
|
DedupHITLToolCallsMiddleware,
|
||||||
|
KnowledgeBaseSearchMiddleware,
|
||||||
|
SurfSenseFilesystemMiddleware,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.system_prompt import (
|
from app.agents.new_chat.system_prompt import (
|
||||||
build_configurable_system_prompt,
|
build_configurable_system_prompt,
|
||||||
|
|
@ -40,15 +56,15 @@ _perf_log = get_perf_logger()
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
# Maps SearchSourceConnectorType enum values to the searchable document/connector types
|
# Maps SearchSourceConnectorType enum values to the searchable document/connector types
|
||||||
# used by the knowledge_base and web_search tools.
|
# used by pre-search middleware and web_search.
|
||||||
# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to
|
# Live search connectors (TAVILY_API, LINKUP_API, BAIDU_SEARCH_API) are routed to
|
||||||
# the web_search tool; all others go to search_knowledge_base.
|
# the web_search tool; all others are considered local/indexed data.
|
||||||
_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
|
_CONNECTOR_TYPE_TO_SEARCHABLE: dict[str, str] = {
|
||||||
# Live search connectors (handled by web_search tool)
|
# Live search connectors (handled by web_search tool)
|
||||||
"TAVILY_API": "TAVILY_API",
|
"TAVILY_API": "TAVILY_API",
|
||||||
"LINKUP_API": "LINKUP_API",
|
"LINKUP_API": "LINKUP_API",
|
||||||
"BAIDU_SEARCH_API": "BAIDU_SEARCH_API",
|
"BAIDU_SEARCH_API": "BAIDU_SEARCH_API",
|
||||||
# Local/indexed connectors (handled by search_knowledge_base tool)
|
# Local/indexed connectors (handled by KB pre-search middleware)
|
||||||
"SLACK_CONNECTOR": "SLACK_CONNECTOR",
|
"SLACK_CONNECTOR": "SLACK_CONNECTOR",
|
||||||
"TEAMS_CONNECTOR": "TEAMS_CONNECTOR",
|
"TEAMS_CONNECTOR": "TEAMS_CONNECTOR",
|
||||||
"NOTION_CONNECTOR": "NOTION_CONNECTOR",
|
"NOTION_CONNECTOR": "NOTION_CONNECTOR",
|
||||||
|
|
@ -141,13 +157,11 @@ async def create_surfsense_deep_agent(
|
||||||
additional_tools: Sequence[BaseTool] | None = None,
|
additional_tools: Sequence[BaseTool] | None = None,
|
||||||
firecrawl_api_key: str | None = None,
|
firecrawl_api_key: str | None = None,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
sandbox_backend: SandboxBackendProtocol | None = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a SurfSense deep agent with configurable tools and prompts.
|
Create a SurfSense deep agent with configurable tools and prompts.
|
||||||
|
|
||||||
The agent comes with built-in tools that can be configured:
|
The agent comes with built-in tools that can be configured:
|
||||||
- search_knowledge_base: Search the user's personal knowledge base
|
|
||||||
- generate_podcast: Generate audio podcasts from content
|
- generate_podcast: Generate audio podcasts from content
|
||||||
- generate_image: Generate images from text descriptions using AI models
|
- generate_image: Generate images from text descriptions using AI models
|
||||||
- scrape_webpage: Extract content from webpages
|
- scrape_webpage: Extract content from webpages
|
||||||
|
|
@ -179,9 +193,6 @@ async def create_surfsense_deep_agent(
|
||||||
These are always added regardless of enabled/disabled settings.
|
These are always added regardless of enabled/disabled settings.
|
||||||
firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
|
firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
|
||||||
Falls back to Chromium/Trafilatura if not provided.
|
Falls back to Chromium/Trafilatura if not provided.
|
||||||
sandbox_backend: Optional sandbox backend (e.g. DaytonaSandbox) for
|
|
||||||
secure code execution. When provided, the agent gets an
|
|
||||||
isolated ``execute`` tool for running shell commands.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
CompiledStateGraph: The configured deep agent
|
CompiledStateGraph: The configured deep agent
|
||||||
|
|
@ -205,7 +216,7 @@ async def create_surfsense_deep_agent(
|
||||||
# Create agent with only specific tools
|
# Create agent with only specific tools
|
||||||
agent = create_surfsense_deep_agent(
|
agent = create_surfsense_deep_agent(
|
||||||
llm, search_space_id, db_session, ...,
|
llm, search_space_id, db_session, ...,
|
||||||
enabled_tools=["search_knowledge_base", "scrape_webpage"]
|
enabled_tools=["scrape_webpage"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create agent without podcast generation
|
# Create agent without podcast generation
|
||||||
|
|
@ -357,6 +368,10 @@ async def create_surfsense_deep_agent(
|
||||||
]
|
]
|
||||||
modified_disabled_tools.extend(confluence_tools)
|
modified_disabled_tools.extend(confluence_tools)
|
||||||
|
|
||||||
|
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
|
||||||
|
if "search_knowledge_base" not in modified_disabled_tools:
|
||||||
|
modified_disabled_tools.append("search_knowledge_base")
|
||||||
|
|
||||||
# Build tools using the async registry (includes MCP tools)
|
# Build tools using the async registry (includes MCP tools)
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
tools = await build_tools_async(
|
tools = await build_tools_async(
|
||||||
|
|
@ -373,7 +388,6 @@ async def create_surfsense_deep_agent(
|
||||||
|
|
||||||
# Build system prompt based on agent_config, scoped to the tools actually enabled
|
# Build system prompt based on agent_config, scoped to the tools actually enabled
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
_sandbox_enabled = sandbox_backend is not None
|
|
||||||
_enabled_tool_names = {t.name for t in tools}
|
_enabled_tool_names = {t.name for t in tools}
|
||||||
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
_user_disabled_tool_names = set(disabled_tools) if disabled_tools else set()
|
||||||
if agent_config is not None:
|
if agent_config is not None:
|
||||||
|
|
@ -382,14 +396,12 @@ async def create_surfsense_deep_agent(
|
||||||
use_default_system_instructions=agent_config.use_default_system_instructions,
|
use_default_system_instructions=agent_config.use_default_system_instructions,
|
||||||
citations_enabled=agent_config.citations_enabled,
|
citations_enabled=agent_config.citations_enabled,
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
sandbox_enabled=_sandbox_enabled,
|
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
system_prompt = build_surfsense_system_prompt(
|
system_prompt = build_surfsense_system_prompt(
|
||||||
thread_visibility=thread_visibility,
|
thread_visibility=thread_visibility,
|
||||||
sandbox_enabled=_sandbox_enabled,
|
|
||||||
enabled_tool_names=_enabled_tool_names,
|
enabled_tool_names=_enabled_tool_names,
|
||||||
disabled_tool_names=_user_disabled_tool_names,
|
disabled_tool_names=_user_disabled_tool_names,
|
||||||
)
|
)
|
||||||
|
|
@ -397,24 +409,69 @@ async def create_surfsense_deep_agent(
|
||||||
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
"[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build optional kwargs for the deep agent
|
# -- Build the middleware stack (mirrors create_deep_agent internals) ------
|
||||||
deep_agent_kwargs: dict[str, Any] = {}
|
# General-purpose subagent middleware
|
||||||
if sandbox_backend is not None:
|
gp_middleware = [
|
||||||
deep_agent_kwargs["backend"] = sandbox_backend
|
TodoListMiddleware(),
|
||||||
|
SurfSenseFilesystemMiddleware(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
),
|
||||||
|
create_summarization_middleware(llm, StateBackend),
|
||||||
|
PatchToolCallsMiddleware(),
|
||||||
|
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||||
|
]
|
||||||
|
|
||||||
|
general_purpose_spec: SubAgent = { # type: ignore[typeddict-unknown-key]
|
||||||
|
**GENERAL_PURPOSE_SUBAGENT,
|
||||||
|
"model": llm,
|
||||||
|
"tools": tools,
|
||||||
|
"middleware": gp_middleware,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Main agent middleware
|
||||||
|
deepagent_middleware = [
|
||||||
|
TodoListMiddleware(),
|
||||||
|
KnowledgeBaseSearchMiddleware(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
),
|
||||||
|
SurfSenseFilesystemMiddleware(
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
),
|
||||||
|
SubAgentMiddleware(backend=StateBackend, subagents=[general_purpose_spec]),
|
||||||
|
create_summarization_middleware(llm, StateBackend),
|
||||||
|
PatchToolCallsMiddleware(),
|
||||||
|
DedupHITLToolCallsMiddleware(),
|
||||||
|
AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Combine system_prompt with BASE_AGENT_PROMPT (same as create_deep_agent)
|
||||||
|
final_system_prompt = system_prompt + "\n\n" + BASE_AGENT_PROMPT
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
agent = await asyncio.to_thread(
|
agent = await asyncio.to_thread(
|
||||||
create_deep_agent,
|
create_agent,
|
||||||
model=llm,
|
llm,
|
||||||
|
system_prompt=final_system_prompt,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
system_prompt=system_prompt,
|
middleware=deepagent_middleware,
|
||||||
context_schema=SurfSenseContextSchema,
|
context_schema=SurfSenseContextSchema,
|
||||||
checkpointer=checkpointer,
|
checkpointer=checkpointer,
|
||||||
middleware=[DedupHITLToolCallsMiddleware()],
|
)
|
||||||
**deep_agent_kwargs,
|
agent = agent.with_config(
|
||||||
|
{
|
||||||
|
"recursion_limit": 10_000,
|
||||||
|
"metadata": {
|
||||||
|
"ls_integration": "deepagents",
|
||||||
|
"versions": {"deepagents": deepagents_version},
|
||||||
|
},
|
||||||
|
}
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[create_agent] Graph compiled (create_deep_agent) in %.3fs",
|
"[create_agent] Graph compiled (create_agent) in %.3fs",
|
||||||
time.perf_counter() - _t0,
|
time.perf_counter() - _t0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""Middleware components for the SurfSense new chat agent."""
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.dedup_tool_calls import (
|
||||||
|
DedupHITLToolCallsMiddleware,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.middleware.filesystem import (
|
||||||
|
SurfSenseFilesystemMiddleware,
|
||||||
|
)
|
||||||
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
|
KnowledgeBaseSearchMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DedupHITLToolCallsMiddleware",
|
||||||
|
"KnowledgeBaseSearchMiddleware",
|
||||||
|
"SurfSenseFilesystemMiddleware",
|
||||||
|
]
|
||||||
694
surfsense_backend/app/agents/new_chat/middleware/filesystem.py
Normal file
694
surfsense_backend/app/agents/new_chat/middleware/filesystem.py
Normal file
|
|
@ -0,0 +1,694 @@
|
||||||
|
"""Custom filesystem middleware for the SurfSense agent.
|
||||||
|
|
||||||
|
This middleware customizes prompts and persists write/edit operations for
|
||||||
|
`/documents/*` files into SurfSense's `Document`/`Chunk` tables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from deepagents import FilesystemMiddleware
|
||||||
|
from deepagents.backends.protocol import EditResult, WriteResult
|
||||||
|
from deepagents.backends.utils import validate_path
|
||||||
|
from deepagents.middleware.filesystem import FilesystemState
|
||||||
|
from fractional_indexing import generate_key_between
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
from langchain_core.callbacks import dispatch_custom_event
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langchain_core.tools import BaseTool, StructuredTool
|
||||||
|
from langgraph.types import Command
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
|
||||||
|
from app.db import Chunk, Document, DocumentType, Folder, shielded_async_session
|
||||||
|
from app.indexing_pipeline.document_chunker import chunk_text
|
||||||
|
from app.utils.document_converters import (
|
||||||
|
embed_texts,
|
||||||
|
generate_content_hash,
|
||||||
|
generate_unique_identifier_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# System Prompt (injected into every model call by wrap_model_call)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
SURFSENSE_FILESYSTEM_SYSTEM_PROMPT = """## Following Conventions
|
||||||
|
|
||||||
|
- Read files before editing — understand existing content before making changes.
|
||||||
|
- Mimic existing style, naming conventions, and patterns.
|
||||||
|
|
||||||
|
## Filesystem Tools `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`, `save_document`
|
||||||
|
|
||||||
|
All file paths must start with a `/`.
|
||||||
|
- ls: list files and directories at a given path.
|
||||||
|
- read_file: read a file from the filesystem.
|
||||||
|
- write_file: create a temporary file in the session (not persisted).
|
||||||
|
- edit_file: edit a file in the session (not persisted for /documents/ files).
|
||||||
|
- glob: find files matching a pattern (e.g., "**/*.xml").
|
||||||
|
- grep: search for text within files.
|
||||||
|
- save_document: **permanently** save a new document to the user's knowledge
|
||||||
|
base. Use only when the user explicitly asks to save/create a document.
|
||||||
|
|
||||||
|
## Reading Documents Efficiently
|
||||||
|
|
||||||
|
Documents are formatted as XML. Each document contains:
|
||||||
|
- `<document_metadata>` — title, type, URL, etc.
|
||||||
|
- `<chunk_index>` — a table of every chunk with its **line range** and a
|
||||||
|
`matched="true"` flag for chunks that matched the search query.
|
||||||
|
- `<document_content>` — the actual chunks in original document order.
|
||||||
|
|
||||||
|
**Workflow**: when reading a large document, read the first ~20 lines to see
|
||||||
|
the `<chunk_index>`, identify chunks marked `matched="true"`, then use
|
||||||
|
`read_file(path, offset=<start_line>, limit=<lines>)` to jump directly to
|
||||||
|
those sections instead of reading the entire file sequentially.
|
||||||
|
|
||||||
|
Use `<chunk id='...'>` values as citation IDs in your answers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Per-Tool Descriptions (shown to the LLM as the tool's docstring)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
SURFSENSE_LIST_FILES_TOOL_DESCRIPTION = """Lists files and directories at the given path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SURFSENSE_READ_FILE_TOOL_DESCRIPTION = """Reads a file from the filesystem.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- By default, reads up to 100 lines from the beginning.
|
||||||
|
- Use `offset` and `limit` for pagination when files are large.
|
||||||
|
- Results include line numbers.
|
||||||
|
- Documents contain a `<chunk_index>` near the top listing every chunk with
|
||||||
|
its line range and a `matched="true"` flag for search-relevant chunks.
|
||||||
|
Read the index first, then jump to matched chunks with
|
||||||
|
`read_file(path, offset=<start_line>, limit=<num_lines>)`.
|
||||||
|
- Use chunk IDs (`<chunk id='...'>`) as citations in answers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION = """Writes a new file to the in-memory filesystem (session-only).
|
||||||
|
|
||||||
|
Use this to create scratch/working files during the conversation. Files created
|
||||||
|
here are ephemeral and will not be saved to the user's knowledge base.
|
||||||
|
|
||||||
|
To permanently save a document to the user's knowledge base, use the
|
||||||
|
`save_document` tool instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION = """Performs exact string replacements in files.
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- Read the file before editing.
|
||||||
|
- Preserve exact indentation and formatting.
|
||||||
|
- Edits to documents under `/documents/` are session-only (not persisted to the
|
||||||
|
database) because those files use an XML citation wrapper around the original
|
||||||
|
content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SURFSENSE_GLOB_TOOL_DESCRIPTION = """Find files matching a glob pattern.
|
||||||
|
|
||||||
|
Supports standard glob patterns: `*`, `**`, `?`.
|
||||||
|
Returns absolute file paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SURFSENSE_GREP_TOOL_DESCRIPTION = """Search for a literal text pattern across files.
|
||||||
|
|
||||||
|
Use this to locate relevant document files/chunks before reading full files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION = """Permanently saves a document to the user's knowledge base.
|
||||||
|
|
||||||
|
This is an expensive operation — it creates a new Document record in the
|
||||||
|
database, chunks the content, and generates embeddings for search.
|
||||||
|
|
||||||
|
Use ONLY when the user explicitly asks to save/create/store a document.
|
||||||
|
Do NOT use this for scratch work; use `write_file` for temporary files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title: The document title (e.g., "Meeting Notes 2025-06-01").
|
||||||
|
content: The plain-text or markdown content to save. Do NOT include XML
|
||||||
|
citation wrappers — pass only the actual document text.
|
||||||
|
folder_path: Optional folder path under /documents/ (e.g., "Work/Notes").
|
||||||
|
Folders are created automatically if they don't exist.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SurfSenseFilesystemMiddleware(FilesystemMiddleware):
|
||||||
|
"""SurfSense-specific filesystem middleware with DB persistence for docs."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
search_space_id: int | None = None,
|
||||||
|
created_by_id: str | None = None,
|
||||||
|
tool_token_limit_before_evict: int | None = 20000,
|
||||||
|
) -> None:
|
||||||
|
self._search_space_id = search_space_id
|
||||||
|
self._created_by_id = created_by_id
|
||||||
|
super().__init__(
|
||||||
|
system_prompt=SURFSENSE_FILESYSTEM_SYSTEM_PROMPT,
|
||||||
|
custom_tool_descriptions={
|
||||||
|
"ls": SURFSENSE_LIST_FILES_TOOL_DESCRIPTION,
|
||||||
|
"read_file": SURFSENSE_READ_FILE_TOOL_DESCRIPTION,
|
||||||
|
"write_file": SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION,
|
||||||
|
"edit_file": SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION,
|
||||||
|
"glob": SURFSENSE_GLOB_TOOL_DESCRIPTION,
|
||||||
|
"grep": SURFSENSE_GREP_TOOL_DESCRIPTION,
|
||||||
|
},
|
||||||
|
tool_token_limit_before_evict=tool_token_limit_before_evict,
|
||||||
|
)
|
||||||
|
# Remove the execute tool (no sandbox backend)
|
||||||
|
self.tools = [t for t in self.tools if t.name != "execute"]
|
||||||
|
self.tools.append(self._create_save_document_tool())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _run_async_blocking(coro: Any) -> Any:
|
||||||
|
"""Run async coroutine from sync code path when no event loop is running."""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
return "Error: sync filesystem persistence not supported inside an active event loop."
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_virtual_path(file_path: str) -> tuple[list[str], str]:
|
||||||
|
"""Parse /documents/... path into folder parts and a document title."""
|
||||||
|
if not file_path.startswith("/documents/"):
|
||||||
|
return [], ""
|
||||||
|
rel = file_path[len("/documents/") :].strip("/")
|
||||||
|
if not rel:
|
||||||
|
return [], ""
|
||||||
|
parts = [part for part in rel.split("/") if part]
|
||||||
|
file_name = parts[-1]
|
||||||
|
title = file_name[:-4] if file_name.lower().endswith(".xml") else file_name
|
||||||
|
return parts[:-1], title
|
||||||
|
|
||||||
|
async def _ensure_folder_hierarchy(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
folder_parts: list[str],
|
||||||
|
search_space_id: int,
|
||||||
|
) -> int | None:
|
||||||
|
"""Ensure folder hierarchy exists and return leaf folder ID."""
|
||||||
|
if not folder_parts:
|
||||||
|
return None
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
parent_id: int | None = None
|
||||||
|
for name in folder_parts:
|
||||||
|
result = await session.execute(
|
||||||
|
select(Folder).where(
|
||||||
|
Folder.search_space_id == search_space_id,
|
||||||
|
Folder.parent_id == parent_id
|
||||||
|
if parent_id is not None
|
||||||
|
else Folder.parent_id.is_(None),
|
||||||
|
Folder.name == name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
folder = result.scalar_one_or_none()
|
||||||
|
if folder is None:
|
||||||
|
sibling_result = await session.execute(
|
||||||
|
select(Folder.position)
|
||||||
|
.where(
|
||||||
|
Folder.search_space_id == search_space_id,
|
||||||
|
Folder.parent_id == parent_id
|
||||||
|
if parent_id is not None
|
||||||
|
else Folder.parent_id.is_(None),
|
||||||
|
)
|
||||||
|
.order_by(Folder.position.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
last_position = sibling_result.scalar_one_or_none()
|
||||||
|
folder = Folder(
|
||||||
|
name=name,
|
||||||
|
position=generate_key_between(last_position, None),
|
||||||
|
parent_id=parent_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=self._created_by_id,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(folder)
|
||||||
|
await session.flush()
|
||||||
|
parent_id = folder.id
|
||||||
|
await session.commit()
|
||||||
|
return parent_id
|
||||||
|
|
||||||
|
async def _persist_new_document(
|
||||||
|
self, *, file_path: str, content: str
|
||||||
|
) -> dict[str, Any] | str:
|
||||||
|
"""Persist a new NOTE document from a newly written file.
|
||||||
|
|
||||||
|
Returns a dict with document metadata on success, or an error string.
|
||||||
|
"""
|
||||||
|
if self._search_space_id is None:
|
||||||
|
return {}
|
||||||
|
folder_parts, title = self._parse_virtual_path(file_path)
|
||||||
|
if not title:
|
||||||
|
return "Error: write_file for document persistence requires path under /documents/<name>.xml"
|
||||||
|
folder_id = await self._ensure_folder_hierarchy(
|
||||||
|
folder_parts=folder_parts,
|
||||||
|
search_space_id=self._search_space_id,
|
||||||
|
)
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
content_hash = generate_content_hash(content, self._search_space_id)
|
||||||
|
existing = await session.execute(
|
||||||
|
select(Document.id).where(Document.content_hash == content_hash)
|
||||||
|
)
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
|
return "Error: A document with identical content already exists."
|
||||||
|
unique_identifier_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
file_path,
|
||||||
|
self._search_space_id,
|
||||||
|
)
|
||||||
|
doc = Document(
|
||||||
|
title=title,
|
||||||
|
document_type=DocumentType.NOTE,
|
||||||
|
document_metadata={"virtual_path": file_path},
|
||||||
|
content=content,
|
||||||
|
content_hash=content_hash,
|
||||||
|
unique_identifier_hash=unique_identifier_hash,
|
||||||
|
source_markdown=content,
|
||||||
|
search_space_id=self._search_space_id,
|
||||||
|
folder_id=folder_id,
|
||||||
|
created_by_id=self._created_by_id,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(doc)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
summary_embedding = embed_texts([content])[0]
|
||||||
|
doc.embedding = summary_embedding
|
||||||
|
chunk_texts = chunk_text(content)
|
||||||
|
if chunk_texts:
|
||||||
|
chunk_embeddings = embed_texts(chunk_texts)
|
||||||
|
chunks = [
|
||||||
|
Chunk(document_id=doc.id, content=text, embedding=embedding)
|
||||||
|
for text, embedding in zip(
|
||||||
|
chunk_texts, chunk_embeddings, strict=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
session.add_all(chunks)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": doc.id,
|
||||||
|
"title": title,
|
||||||
|
"documentType": DocumentType.NOTE.value,
|
||||||
|
"searchSpaceId": self._search_space_id,
|
||||||
|
"folderId": folder_id,
|
||||||
|
"createdById": str(self._created_by_id)
|
||||||
|
if self._created_by_id
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _persist_edited_document(
|
||||||
|
self, *, file_path: str, updated_content: str
|
||||||
|
) -> str | None:
|
||||||
|
"""Persist edits for an existing NOTE document and recreate chunks."""
|
||||||
|
if self._search_space_id is None:
|
||||||
|
return None
|
||||||
|
unique_identifier_hash = generate_unique_identifier_hash(
|
||||||
|
DocumentType.NOTE,
|
||||||
|
file_path,
|
||||||
|
self._search_space_id,
|
||||||
|
)
|
||||||
|
doc_id_from_xml: int | None = None
|
||||||
|
match = re.search(r"<document_id>\s*(\d+)\s*</document_id>", updated_content)
|
||||||
|
if match:
|
||||||
|
doc_id_from_xml = int(match.group(1))
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
doc_result = await session.execute(
|
||||||
|
select(Document).where(
|
||||||
|
Document.search_space_id == self._search_space_id,
|
||||||
|
Document.unique_identifier_hash == unique_identifier_hash,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
document = doc_result.scalar_one_or_none()
|
||||||
|
if document is None and doc_id_from_xml is not None:
|
||||||
|
by_id_result = await session.execute(
|
||||||
|
select(Document).where(
|
||||||
|
Document.search_space_id == self._search_space_id,
|
||||||
|
Document.id == doc_id_from_xml,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
document = by_id_result.scalar_one_or_none()
|
||||||
|
if document is None:
|
||||||
|
return "Error: Could not map edited file to an existing document."
|
||||||
|
|
||||||
|
document.content = updated_content
|
||||||
|
document.source_markdown = updated_content
|
||||||
|
document.content_hash = generate_content_hash(
|
||||||
|
updated_content, self._search_space_id
|
||||||
|
)
|
||||||
|
document.updated_at = datetime.now(UTC)
|
||||||
|
if not document.document_metadata:
|
||||||
|
document.document_metadata = {}
|
||||||
|
document.document_metadata["virtual_path"] = file_path
|
||||||
|
|
||||||
|
summary_embedding = embed_texts([updated_content])[0]
|
||||||
|
document.embedding = summary_embedding
|
||||||
|
|
||||||
|
await session.execute(delete(Chunk).where(Chunk.document_id == document.id))
|
||||||
|
chunk_texts = chunk_text(updated_content)
|
||||||
|
if chunk_texts:
|
||||||
|
chunk_embeddings = embed_texts(chunk_texts)
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
Chunk(
|
||||||
|
document_id=document.id, content=text, embedding=embedding
|
||||||
|
)
|
||||||
|
for text, embedding in zip(
|
||||||
|
chunk_texts, chunk_embeddings, strict=True
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_save_document_tool(self) -> BaseTool:
|
||||||
|
"""Create save_document tool that persists a new document to the KB."""
|
||||||
|
|
||||||
|
def sync_save_document(
|
||||||
|
title: Annotated[str, "Title for the new document."],
|
||||||
|
content: Annotated[
|
||||||
|
str,
|
||||||
|
"Plain-text or markdown content to save. Do NOT include XML wrappers.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
|
folder_path: Annotated[
|
||||||
|
str,
|
||||||
|
"Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
|
||||||
|
] = "",
|
||||||
|
) -> Command | str:
|
||||||
|
if not content.strip():
|
||||||
|
return "Error: content cannot be empty."
|
||||||
|
file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
|
||||||
|
if not file_name.lower().endswith(".xml"):
|
||||||
|
file_name = f"{file_name}.xml"
|
||||||
|
folder = folder_path.strip().strip("/") if folder_path else ""
|
||||||
|
virtual_path = (
|
||||||
|
f"/documents/{folder}/{file_name}"
|
||||||
|
if folder
|
||||||
|
else f"/documents/{file_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
persist_result = self._run_async_blocking(
|
||||||
|
self._persist_new_document(file_path=virtual_path, content=content)
|
||||||
|
)
|
||||||
|
if isinstance(persist_result, str):
|
||||||
|
return persist_result
|
||||||
|
if isinstance(persist_result, dict) and persist_result.get("id"):
|
||||||
|
dispatch_custom_event("document_created", persist_result)
|
||||||
|
return f"Document '{title}' saved to knowledge base (path: {virtual_path})."
|
||||||
|
|
||||||
|
async def async_save_document(
|
||||||
|
title: Annotated[str, "Title for the new document."],
|
||||||
|
content: Annotated[
|
||||||
|
str,
|
||||||
|
"Plain-text or markdown content to save. Do NOT include XML wrappers.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
|
folder_path: Annotated[
|
||||||
|
str,
|
||||||
|
"Optional folder path under /documents/ (e.g. 'Work/Notes'). Created automatically.",
|
||||||
|
] = "",
|
||||||
|
) -> Command | str:
|
||||||
|
if not content.strip():
|
||||||
|
return "Error: content cannot be empty."
|
||||||
|
file_name = re.sub(r'[\\/:*?"<>|]+', "_", title).strip() or "untitled"
|
||||||
|
if not file_name.lower().endswith(".xml"):
|
||||||
|
file_name = f"{file_name}.xml"
|
||||||
|
folder = folder_path.strip().strip("/") if folder_path else ""
|
||||||
|
virtual_path = (
|
||||||
|
f"/documents/{folder}/{file_name}"
|
||||||
|
if folder
|
||||||
|
else f"/documents/{file_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
persist_result = await self._persist_new_document(
|
||||||
|
file_path=virtual_path, content=content
|
||||||
|
)
|
||||||
|
if isinstance(persist_result, str):
|
||||||
|
return persist_result
|
||||||
|
if isinstance(persist_result, dict) and persist_result.get("id"):
|
||||||
|
dispatch_custom_event("document_created", persist_result)
|
||||||
|
return f"Document '{title}' saved to knowledge base (path: {virtual_path})."
|
||||||
|
|
||||||
|
return StructuredTool.from_function(
|
||||||
|
name="save_document",
|
||||||
|
description=SURFSENSE_SAVE_DOCUMENT_TOOL_DESCRIPTION,
|
||||||
|
func=sync_save_document,
|
||||||
|
coroutine=async_save_document,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_write_file_tool(self) -> BaseTool:
|
||||||
|
"""Create write_file — ephemeral for /documents/*, persisted otherwise."""
|
||||||
|
tool_description = (
|
||||||
|
self._custom_tool_descriptions.get("write_file")
|
||||||
|
or SURFSENSE_WRITE_FILE_TOOL_DESCRIPTION
|
||||||
|
)
|
||||||
|
|
||||||
|
def sync_write_file(
|
||||||
|
file_path: Annotated[
|
||||||
|
str,
|
||||||
|
"Absolute path where the file should be created. Must be absolute, not relative.",
|
||||||
|
],
|
||||||
|
content: Annotated[
|
||||||
|
str,
|
||||||
|
"The text content to write to the file. This parameter is required.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
|
) -> Command | str:
|
||||||
|
resolved_backend = self._get_backend(runtime)
|
||||||
|
try:
|
||||||
|
validated_path = validate_path(file_path)
|
||||||
|
except ValueError as exc:
|
||||||
|
return f"Error: {exc}"
|
||||||
|
res: WriteResult = resolved_backend.write(validated_path, content)
|
||||||
|
if res.error:
|
||||||
|
return res.error
|
||||||
|
|
||||||
|
if not self._is_kb_document(validated_path):
|
||||||
|
persist_result = self._run_async_blocking(
|
||||||
|
self._persist_new_document(
|
||||||
|
file_path=validated_path, content=content
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if isinstance(persist_result, str):
|
||||||
|
return persist_result
|
||||||
|
if isinstance(persist_result, dict) and persist_result.get("id"):
|
||||||
|
dispatch_custom_event("document_created", persist_result)
|
||||||
|
|
||||||
|
if res.files_update is not None:
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"files": res.files_update,
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=f"Updated file {res.path}",
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return f"Updated file {res.path}"
|
||||||
|
|
||||||
|
async def async_write_file(
|
||||||
|
file_path: Annotated[
|
||||||
|
str,
|
||||||
|
"Absolute path where the file should be created. Must be absolute, not relative.",
|
||||||
|
],
|
||||||
|
content: Annotated[
|
||||||
|
str,
|
||||||
|
"The text content to write to the file. This parameter is required.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
|
) -> Command | str:
|
||||||
|
resolved_backend = self._get_backend(runtime)
|
||||||
|
try:
|
||||||
|
validated_path = validate_path(file_path)
|
||||||
|
except ValueError as exc:
|
||||||
|
return f"Error: {exc}"
|
||||||
|
res: WriteResult = await resolved_backend.awrite(validated_path, content)
|
||||||
|
if res.error:
|
||||||
|
return res.error
|
||||||
|
|
||||||
|
if not self._is_kb_document(validated_path):
|
||||||
|
persist_result = await self._persist_new_document(
|
||||||
|
file_path=validated_path,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
if isinstance(persist_result, str):
|
||||||
|
return persist_result
|
||||||
|
if isinstance(persist_result, dict) and persist_result.get("id"):
|
||||||
|
dispatch_custom_event("document_created", persist_result)
|
||||||
|
|
||||||
|
if res.files_update is not None:
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"files": res.files_update,
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=f"Updated file {res.path}",
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return f"Updated file {res.path}"
|
||||||
|
|
||||||
|
return StructuredTool.from_function(
|
||||||
|
name="write_file",
|
||||||
|
description=tool_description,
|
||||||
|
func=sync_write_file,
|
||||||
|
coroutine=async_write_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_kb_document(path: str) -> bool:
|
||||||
|
"""Return True for paths under /documents/ (KB-sourced, XML-wrapped)."""
|
||||||
|
return path.startswith("/documents/")
|
||||||
|
|
||||||
|
def _create_edit_file_tool(self) -> BaseTool:
|
||||||
|
"""Create edit_file with DB persistence (skipped for KB documents)."""
|
||||||
|
tool_description = (
|
||||||
|
self._custom_tool_descriptions.get("edit_file")
|
||||||
|
or SURFSENSE_EDIT_FILE_TOOL_DESCRIPTION
|
||||||
|
)
|
||||||
|
|
||||||
|
def sync_edit_file(
|
||||||
|
file_path: Annotated[
|
||||||
|
str,
|
||||||
|
"Absolute path to the file to edit. Must be absolute, not relative.",
|
||||||
|
],
|
||||||
|
old_string: Annotated[
|
||||||
|
str,
|
||||||
|
"The exact text to find and replace. Must be unique in the file unless replace_all is True.",
|
||||||
|
],
|
||||||
|
new_string: Annotated[
|
||||||
|
str,
|
||||||
|
"The text to replace old_string with. Must be different from old_string.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
|
*,
|
||||||
|
replace_all: Annotated[
|
||||||
|
bool,
|
||||||
|
"If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
|
||||||
|
] = False,
|
||||||
|
) -> Command | str:
|
||||||
|
resolved_backend = self._get_backend(runtime)
|
||||||
|
try:
|
||||||
|
validated_path = validate_path(file_path)
|
||||||
|
except ValueError as exc:
|
||||||
|
return f"Error: {exc}"
|
||||||
|
res: EditResult = resolved_backend.edit(
|
||||||
|
validated_path,
|
||||||
|
old_string,
|
||||||
|
new_string,
|
||||||
|
replace_all=replace_all,
|
||||||
|
)
|
||||||
|
if res.error:
|
||||||
|
return res.error
|
||||||
|
|
||||||
|
if not self._is_kb_document(validated_path):
|
||||||
|
read_result = resolved_backend.read(
|
||||||
|
validated_path, offset=0, limit=200000
|
||||||
|
)
|
||||||
|
if read_result.error or read_result.file_data is None:
|
||||||
|
return f"Error: could not reload edited file '{validated_path}' for persistence."
|
||||||
|
updated_content = read_result.file_data["content"]
|
||||||
|
persist_result = self._run_async_blocking(
|
||||||
|
self._persist_edited_document(
|
||||||
|
file_path=validated_path,
|
||||||
|
updated_content=updated_content,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if isinstance(persist_result, str):
|
||||||
|
return persist_result
|
||||||
|
|
||||||
|
if res.files_update is not None:
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"files": res.files_update,
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"
|
||||||
|
|
||||||
|
async def async_edit_file(
|
||||||
|
file_path: Annotated[
|
||||||
|
str,
|
||||||
|
"Absolute path to the file to edit. Must be absolute, not relative.",
|
||||||
|
],
|
||||||
|
old_string: Annotated[
|
||||||
|
str,
|
||||||
|
"The exact text to find and replace. Must be unique in the file unless replace_all is True.",
|
||||||
|
],
|
||||||
|
new_string: Annotated[
|
||||||
|
str,
|
||||||
|
"The text to replace old_string with. Must be different from old_string.",
|
||||||
|
],
|
||||||
|
runtime: ToolRuntime[None, FilesystemState],
|
||||||
|
*,
|
||||||
|
replace_all: Annotated[
|
||||||
|
bool,
|
||||||
|
"If True, replace all occurrences of old_string. If False (default), old_string must be unique.",
|
||||||
|
] = False,
|
||||||
|
) -> Command | str:
|
||||||
|
resolved_backend = self._get_backend(runtime)
|
||||||
|
try:
|
||||||
|
validated_path = validate_path(file_path)
|
||||||
|
except ValueError as exc:
|
||||||
|
return f"Error: {exc}"
|
||||||
|
res: EditResult = await resolved_backend.aedit(
|
||||||
|
validated_path,
|
||||||
|
old_string,
|
||||||
|
new_string,
|
||||||
|
replace_all=replace_all,
|
||||||
|
)
|
||||||
|
if res.error:
|
||||||
|
return res.error
|
||||||
|
|
||||||
|
if not self._is_kb_document(validated_path):
|
||||||
|
read_result = await resolved_backend.aread(
|
||||||
|
validated_path, offset=0, limit=200000
|
||||||
|
)
|
||||||
|
if read_result.error or read_result.file_data is None:
|
||||||
|
return f"Error: could not reload edited file '{validated_path}' for persistence."
|
||||||
|
updated_content = read_result.file_data["content"]
|
||||||
|
persist_error = await self._persist_edited_document(
|
||||||
|
file_path=validated_path,
|
||||||
|
updated_content=updated_content,
|
||||||
|
)
|
||||||
|
if persist_error:
|
||||||
|
return persist_error
|
||||||
|
|
||||||
|
if res.files_update is not None:
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"files": res.files_update,
|
||||||
|
"messages": [
|
||||||
|
ToolMessage(
|
||||||
|
content=f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'",
|
||||||
|
tool_call_id=runtime.tool_call_id,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return f"Successfully replaced {res.occurrences} instance(s) of the string in '{res.path}'"
|
||||||
|
|
||||||
|
return StructuredTool.from_function(
|
||||||
|
name="edit_file",
|
||||||
|
description=tool_description,
|
||||||
|
func=sync_edit_file,
|
||||||
|
coroutine=async_edit_file,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,414 @@
|
||||||
|
"""Knowledge-base pre-search middleware for the SurfSense new chat agent.
|
||||||
|
|
||||||
|
This middleware runs before the main agent loop and seeds a virtual filesystem
|
||||||
|
(`files` state) with relevant documents retrieved via hybrid search. On each
|
||||||
|
turn the filesystem is *expanded* — new results merge with documents loaded
|
||||||
|
during prior turns — and a synthetic ``ls`` result is injected into the message
|
||||||
|
history so the LLM is immediately aware of the current filesystem structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
|
||||||
|
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||||
|
from app.utils.document_converters import embed_texts
|
||||||
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
_perf_log = get_perf_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_from_message(message: BaseMessage) -> str:
|
||||||
|
"""Extract plain text from a message content."""
|
||||||
|
content = getattr(message, "content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict) and item.get("type") == "text":
|
||||||
|
parts.append(str(item.get("text", "")))
|
||||||
|
return "\n".join(p for p in parts if p)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_filename(value: str, *, fallback: str = "untitled.xml") -> str:
|
||||||
|
"""Convert arbitrary text into a filesystem-safe filename."""
|
||||||
|
name = re.sub(r"[\\/:*?\"<>|]+", "_", value).strip()
|
||||||
|
name = re.sub(r"\s+", " ", name)
|
||||||
|
if not name:
|
||||||
|
name = fallback
|
||||||
|
if len(name) > 180:
|
||||||
|
name = name[:180].rstrip()
|
||||||
|
if not name.lower().endswith(".xml"):
|
||||||
|
name = f"{name}.xml"
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _build_document_xml(
|
||||||
|
document: dict[str, Any],
|
||||||
|
matched_chunk_ids: set[int] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build citation-friendly XML with a ``<chunk_index>`` for smart seeking.
|
||||||
|
|
||||||
|
The ``<chunk_index>`` at the top of each document lists every chunk with its
|
||||||
|
line range inside ``<document_content>`` and flags chunks that directly
|
||||||
|
matched the search query (``matched="true"``). This lets the LLM jump
|
||||||
|
straight to the most relevant section via ``read_file(offset=…, limit=…)``
|
||||||
|
instead of reading sequentially from the start.
|
||||||
|
"""
|
||||||
|
matched = matched_chunk_ids or set()
|
||||||
|
|
||||||
|
doc_meta = document.get("document") or {}
|
||||||
|
metadata = (doc_meta.get("metadata") or {}) if isinstance(doc_meta, dict) else {}
|
||||||
|
document_id = doc_meta.get("id", document.get("document_id", "unknown"))
|
||||||
|
document_type = doc_meta.get("document_type", document.get("source", "UNKNOWN"))
|
||||||
|
title = doc_meta.get("title") or metadata.get("title") or "Untitled Document"
|
||||||
|
url = (
|
||||||
|
metadata.get("url") or metadata.get("source") or metadata.get("page_url") or ""
|
||||||
|
)
|
||||||
|
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||||
|
|
||||||
|
# --- 1. Metadata header (fixed structure) ---
|
||||||
|
metadata_lines: list[str] = [
|
||||||
|
"<document>",
|
||||||
|
"<document_metadata>",
|
||||||
|
f" <document_id>{document_id}</document_id>",
|
||||||
|
f" <document_type>{document_type}</document_type>",
|
||||||
|
f" <title><![CDATA[{title}]]></title>",
|
||||||
|
f" <url><![CDATA[{url}]]></url>",
|
||||||
|
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
|
||||||
|
"</document_metadata>",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
# --- 2. Pre-build chunk XML strings to compute line counts ---
|
||||||
|
chunks = document.get("chunks") or []
|
||||||
|
chunk_entries: list[tuple[int | None, str]] = [] # (chunk_id, xml_string)
|
||||||
|
if isinstance(chunks, list):
|
||||||
|
for chunk in chunks:
|
||||||
|
if not isinstance(chunk, dict):
|
||||||
|
continue
|
||||||
|
chunk_id = chunk.get("chunk_id") or chunk.get("id")
|
||||||
|
chunk_content = str(chunk.get("content", "")).strip()
|
||||||
|
if not chunk_content:
|
||||||
|
continue
|
||||||
|
if chunk_id is None:
|
||||||
|
xml = f" <chunk><![CDATA[{chunk_content}]]></chunk>"
|
||||||
|
else:
|
||||||
|
xml = f" <chunk id='{chunk_id}'><![CDATA[{chunk_content}]]></chunk>"
|
||||||
|
chunk_entries.append((chunk_id, xml))
|
||||||
|
|
||||||
|
# --- 3. Compute line numbers for every chunk ---
|
||||||
|
# Layout (1-indexed lines for read_file):
|
||||||
|
# metadata_lines -> len(metadata_lines) lines
|
||||||
|
# <chunk_index> -> 1 line
|
||||||
|
# index entries -> len(chunk_entries) lines
|
||||||
|
# </chunk_index> -> 1 line
|
||||||
|
# (empty line) -> 1 line
|
||||||
|
# <document_content> -> 1 line
|
||||||
|
# chunk xml lines…
|
||||||
|
# </document_content> -> 1 line
|
||||||
|
# </document> -> 1 line
|
||||||
|
index_overhead = (
|
||||||
|
1 + len(chunk_entries) + 1 + 1 + 1
|
||||||
|
) # tags + empty + <document_content>
|
||||||
|
first_chunk_line = len(metadata_lines) + index_overhead + 1 # 1-indexed
|
||||||
|
|
||||||
|
current_line = first_chunk_line
|
||||||
|
index_entry_lines: list[str] = []
|
||||||
|
for cid, xml_str in chunk_entries:
|
||||||
|
num_lines = xml_str.count("\n") + 1
|
||||||
|
end_line = current_line + num_lines - 1
|
||||||
|
matched_attr = ' matched="true"' if cid is not None and cid in matched else ""
|
||||||
|
if cid is not None:
|
||||||
|
index_entry_lines.append(
|
||||||
|
f' <entry chunk_id="{cid}" lines="{current_line}-{end_line}"{matched_attr}/>'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
index_entry_lines.append(
|
||||||
|
f' <entry lines="{current_line}-{end_line}"{matched_attr}/>'
|
||||||
|
)
|
||||||
|
current_line = end_line + 1
|
||||||
|
|
||||||
|
# --- 4. Assemble final XML ---
|
||||||
|
lines = metadata_lines.copy()
|
||||||
|
lines.append("<chunk_index>")
|
||||||
|
lines.extend(index_entry_lines)
|
||||||
|
lines.append("</chunk_index>")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("<document_content>")
|
||||||
|
for _, xml_str in chunk_entries:
|
||||||
|
lines.append(xml_str)
|
||||||
|
lines.extend(["</document_content>", "</document>"])
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_folder_paths(
|
||||||
|
session: AsyncSession, search_space_id: int
|
||||||
|
) -> dict[int, str]:
|
||||||
|
"""Return a map of folder_id -> virtual folder path under /documents."""
|
||||||
|
result = await session.execute(
|
||||||
|
select(Folder.id, Folder.name, Folder.parent_id).where(
|
||||||
|
Folder.search_space_id == search_space_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
rows = result.all()
|
||||||
|
by_id = {row.id: {"name": row.name, "parent_id": row.parent_id} for row in rows}
|
||||||
|
|
||||||
|
cache: dict[int, str] = {}
|
||||||
|
|
||||||
|
def resolve_path(folder_id: int) -> str:
|
||||||
|
if folder_id in cache:
|
||||||
|
return cache[folder_id]
|
||||||
|
parts: list[str] = []
|
||||||
|
cursor: int | None = folder_id
|
||||||
|
visited: set[int] = set()
|
||||||
|
while cursor is not None and cursor in by_id and cursor not in visited:
|
||||||
|
visited.add(cursor)
|
||||||
|
entry = by_id[cursor]
|
||||||
|
parts.append(
|
||||||
|
_safe_filename(str(entry["name"]), fallback="folder").removesuffix(
|
||||||
|
".xml"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cursor = entry["parent_id"]
|
||||||
|
parts.reverse()
|
||||||
|
path = "/documents/" + "/".join(parts) if parts else "/documents"
|
||||||
|
cache[folder_id] = path
|
||||||
|
return path
|
||||||
|
|
||||||
|
for folder_id in by_id:
|
||||||
|
resolve_path(folder_id)
|
||||||
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
def _build_synthetic_ls(
|
||||||
|
existing_files: dict[str, Any] | None,
|
||||||
|
new_files: dict[str, Any],
|
||||||
|
) -> tuple[AIMessage, ToolMessage]:
|
||||||
|
"""Build a synthetic ls("/documents") tool-call + result for the LLM context.
|
||||||
|
|
||||||
|
Paths are listed with *new* (rank-ordered) files first, then existing files
|
||||||
|
that were already in state from prior turns.
|
||||||
|
"""
|
||||||
|
merged: dict[str, Any] = {**(existing_files or {}), **new_files}
|
||||||
|
doc_paths = [
|
||||||
|
p for p, v in merged.items() if p.startswith("/documents/") and v is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
new_set = set(new_files)
|
||||||
|
new_paths = [p for p in doc_paths if p in new_set]
|
||||||
|
old_paths = [p for p in doc_paths if p not in new_set]
|
||||||
|
ordered = new_paths + old_paths
|
||||||
|
|
||||||
|
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
|
||||||
|
ai_msg = AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
|
||||||
|
)
|
||||||
|
tool_msg = ToolMessage(
|
||||||
|
content=str(ordered) if ordered else "No documents found.",
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
return ai_msg, tool_msg
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_search_types(
|
||||||
|
available_connectors: list[str] | None,
|
||||||
|
available_document_types: list[str] | None,
|
||||||
|
) -> list[str] | None:
|
||||||
|
"""Build a flat list of document-type strings for the chunk retriever.
|
||||||
|
|
||||||
|
Includes legacy equivalents from ``NATIVE_TO_LEGACY_DOCTYPE`` so that
|
||||||
|
old documents indexed under Composio names are still found.
|
||||||
|
|
||||||
|
Returns ``None`` when no filtering is desired (search all types).
|
||||||
|
"""
|
||||||
|
types: set[str] = set()
|
||||||
|
if available_document_types:
|
||||||
|
types.update(available_document_types)
|
||||||
|
if available_connectors:
|
||||||
|
types.update(available_connectors)
|
||||||
|
if not types:
|
||||||
|
return None
|
||||||
|
|
||||||
|
expanded: set[str] = set(types)
|
||||||
|
for t in types:
|
||||||
|
legacy = NATIVE_TO_LEGACY_DOCTYPE.get(t)
|
||||||
|
if legacy:
|
||||||
|
expanded.add(legacy)
|
||||||
|
return list(expanded) if expanded else None
|
||||||
|
|
||||||
|
|
||||||
|
async def search_knowledge_base(
|
||||||
|
*,
|
||||||
|
query: str,
|
||||||
|
search_space_id: int,
|
||||||
|
available_connectors: list[str] | None = None,
|
||||||
|
available_document_types: list[str] | None = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Run a single unified hybrid search against the knowledge base.
|
||||||
|
|
||||||
|
Uses one ``ChucksHybridSearchRetriever`` call across all document types
|
||||||
|
instead of fanning out per-connector. This reduces the number of DB
|
||||||
|
queries from ~10 to 2 (one RRF query + one chunk fetch).
|
||||||
|
"""
|
||||||
|
if not query:
|
||||||
|
return []
|
||||||
|
|
||||||
|
[embedding] = embed_texts([query])
|
||||||
|
|
||||||
|
doc_types = _resolve_search_types(available_connectors, available_document_types)
|
||||||
|
retriever_top_k = min(top_k * 3, 30)
|
||||||
|
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
retriever = ChucksHybridSearchRetriever(session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text=query,
|
||||||
|
top_k=retriever_top_k,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
document_type=doc_types,
|
||||||
|
query_embedding=embedding.tolist(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results[:top_k]
|
||||||
|
|
||||||
|
|
||||||
|
async def build_scoped_filesystem(
|
||||||
|
*,
|
||||||
|
documents: Sequence[dict[str, Any]],
|
||||||
|
search_space_id: int,
|
||||||
|
) -> dict[str, dict[str, str]]:
|
||||||
|
"""Build a StateBackend-compatible files dict from search results."""
|
||||||
|
async with shielded_async_session() as session:
|
||||||
|
folder_paths = await _get_folder_paths(session, search_space_id)
|
||||||
|
doc_ids = [
|
||||||
|
(doc.get("document") or {}).get("id")
|
||||||
|
for doc in documents
|
||||||
|
if isinstance(doc, dict)
|
||||||
|
]
|
||||||
|
doc_ids = [doc_id for doc_id in doc_ids if isinstance(doc_id, int)]
|
||||||
|
folder_by_doc_id: dict[int, int | None] = {}
|
||||||
|
if doc_ids:
|
||||||
|
doc_rows = await session.execute(
|
||||||
|
select(Document.id, Document.folder_id).where(
|
||||||
|
Document.search_space_id == search_space_id,
|
||||||
|
Document.id.in_(doc_ids),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
folder_by_doc_id = {
|
||||||
|
row.id: row.folder_id for row in doc_rows.all() if row.id is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
files: dict[str, dict[str, str]] = {}
|
||||||
|
for document in documents:
|
||||||
|
doc_meta = document.get("document") or {}
|
||||||
|
title = str(doc_meta.get("title") or "untitled")
|
||||||
|
doc_id = doc_meta.get("id")
|
||||||
|
folder_id = folder_by_doc_id.get(doc_id) if isinstance(doc_id, int) else None
|
||||||
|
base_folder = folder_paths.get(folder_id, "/documents")
|
||||||
|
file_name = _safe_filename(title)
|
||||||
|
path = f"{base_folder}/{file_name}"
|
||||||
|
matched_ids = set(document.get("matched_chunk_ids") or [])
|
||||||
|
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
|
||||||
|
files[path] = {
|
||||||
|
"content": xml_content.split("\n"),
|
||||||
|
"encoding": "utf-8",
|
||||||
|
"created_at": "",
|
||||||
|
"modified_at": "",
|
||||||
|
}
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||||
|
"""Pre-agent middleware that always searches the KB and seeds a scoped filesystem."""
|
||||||
|
|
||||||
|
tools = ()
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
search_space_id: int,
|
||||||
|
available_connectors: list[str] | None = None,
|
||||||
|
available_document_types: list[str] | None = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> None:
|
||||||
|
self.search_space_id = search_space_id
|
||||||
|
self.available_connectors = available_connectors
|
||||||
|
self.available_document_types = available_document_types
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
|
def before_agent( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
if loop.is_running():
|
||||||
|
return None
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
return asyncio.run(self.abefore_agent(state, runtime))
|
||||||
|
|
||||||
|
async def abefore_agent( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
state: AgentState,
|
||||||
|
runtime: Runtime[Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
del runtime
|
||||||
|
messages = state.get("messages") or []
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
last_message = messages[-1]
|
||||||
|
if not isinstance(last_message, HumanMessage):
|
||||||
|
return None
|
||||||
|
|
||||||
|
user_text = _extract_text_from_message(last_message).strip()
|
||||||
|
if not user_text:
|
||||||
|
return None
|
||||||
|
|
||||||
|
t0 = _perf_log and asyncio.get_event_loop().time()
|
||||||
|
existing_files = state.get("files")
|
||||||
|
|
||||||
|
search_results = await search_knowledge_base(
|
||||||
|
query=user_text,
|
||||||
|
search_space_id=self.search_space_id,
|
||||||
|
available_connectors=self.available_connectors,
|
||||||
|
available_document_types=self.available_document_types,
|
||||||
|
top_k=self.top_k,
|
||||||
|
)
|
||||||
|
new_files = await build_scoped_filesystem(
|
||||||
|
documents=search_results,
|
||||||
|
search_space_id=self.search_space_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_msg, tool_msg = _build_synthetic_ls(existing_files, new_files)
|
||||||
|
|
||||||
|
if t0 is not None:
|
||||||
|
_perf_log.info(
|
||||||
|
"[kb_fs_middleware] completed in %.3fs query=%r new_files=%d total=%d",
|
||||||
|
asyncio.get_event_loop().time() - t0,
|
||||||
|
user_text[:80],
|
||||||
|
len(new_files),
|
||||||
|
len(new_files) + len(existing_files or {}),
|
||||||
|
)
|
||||||
|
return {"files": new_files, "messages": [ai_msg, tool_msg]}
|
||||||
|
|
@ -25,6 +25,21 @@ When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVE
|
||||||
|
|
||||||
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
|
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
|
||||||
|
|
||||||
|
<knowledge_base_only_policy>
|
||||||
|
CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||||
|
- You MUST answer questions ONLY using information retrieved from the user's knowledge base, web search results, scraped webpages, or other tool outputs.
|
||||||
|
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless the user explicitly grants permission.
|
||||||
|
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
|
||||||
|
1. Inform the user that you could not find relevant information in their knowledge base.
|
||||||
|
2. Ask the user: "Would you like me to answer from my general knowledge instead?"
|
||||||
|
3. ONLY provide a general-knowledge answer AFTER the user explicitly says yes.
|
||||||
|
- This policy does NOT apply to:
|
||||||
|
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
|
||||||
|
* Formatting, summarization, or analysis of content already present in the conversation
|
||||||
|
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||||
|
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||||
|
</knowledge_base_only_policy>
|
||||||
|
|
||||||
</system_instruction>
|
</system_instruction>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -41,6 +56,21 @@ When writing mathematical formulas or equations, ALWAYS use LaTeX notation. NEVE
|
||||||
|
|
||||||
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
|
NEVER expose internal tool parameter names, backend IDs, or implementation details to the user. Always use natural, user-friendly language instead.
|
||||||
|
|
||||||
|
<knowledge_base_only_policy>
|
||||||
|
CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
||||||
|
- You MUST answer questions ONLY using information retrieved from the team's shared knowledge base, web search results, scraped webpages, or other tool outputs.
|
||||||
|
- You MUST NOT answer factual or informational questions from your own training data or general knowledge unless a team member explicitly grants permission.
|
||||||
|
- If the knowledge base search returns no relevant results AND no other tool provides the answer, you MUST:
|
||||||
|
1. Inform the team that you could not find relevant information in the shared knowledge base.
|
||||||
|
2. Ask: "Would you like me to answer from my general knowledge instead?"
|
||||||
|
3. ONLY provide a general-knowledge answer AFTER a team member explicitly says yes.
|
||||||
|
- This policy does NOT apply to:
|
||||||
|
* Casual conversation, greetings, or meta-questions about SurfSense itself (e.g., "what can you do?")
|
||||||
|
* Formatting, summarization, or analysis of content already present in the conversation
|
||||||
|
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||||
|
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||||
|
</knowledge_base_only_policy>
|
||||||
|
|
||||||
</system_instruction>
|
</system_instruction>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -67,15 +97,6 @@ _TOOLS_PREAMBLE = """
|
||||||
<tools>
|
<tools>
|
||||||
You have access to the following tools:
|
You have access to the following tools:
|
||||||
|
|
||||||
CRITICAL BEHAVIORAL RULE — SEARCH FIRST, ANSWER LATER:
|
|
||||||
For ANY user query that is ambiguous, open-ended, or could potentially have relevant context in the
|
|
||||||
knowledge base, you MUST call `search_knowledge_base` BEFORE attempting to answer from your own
|
|
||||||
general knowledge. This includes (but is not limited to) questions about concepts, topics, projects,
|
|
||||||
people, events, recommendations, or anything the user might have stored notes/documents about.
|
|
||||||
Only fall back to your own general knowledge if the search returns NO relevant results.
|
|
||||||
Do NOT skip the search and answer directly — the user's knowledge base may contain personalized,
|
|
||||||
up-to-date, or domain-specific information that is more relevant than your general training data.
|
|
||||||
|
|
||||||
IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it.
|
IMPORTANT: You can ONLY use the tools listed below. If a capability is not listed here, you do NOT have it.
|
||||||
Do NOT claim you can do something if the corresponding tool is not listed.
|
Do NOT claim you can do something if the corresponding tool is not listed.
|
||||||
|
|
||||||
|
|
@ -92,29 +113,6 @@ _TOOL_INSTRUCTIONS["search_surfsense_docs"] = """
|
||||||
- Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123])
|
- Returns: Documentation content with chunk IDs for citations (prefixed with 'doc-', e.g., [citation:doc-123])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_TOOL_INSTRUCTIONS["search_knowledge_base"] = """
|
|
||||||
- search_knowledge_base: Search the user's personal knowledge base for relevant information.
|
|
||||||
- DEFAULT ACTION: For any user question or ambiguous query, ALWAYS call this tool first to check
|
|
||||||
for relevant context before answering from general knowledge. When in doubt, search.
|
|
||||||
- IMPORTANT: When searching for information (meetings, schedules, notes, tasks, etc.), ALWAYS search broadly
|
|
||||||
across ALL sources first by omitting connectors_to_search. The user may store information in various places
|
|
||||||
including calendar apps, note-taking apps (Obsidian, Notion), chat apps (Slack, Discord), and more.
|
|
||||||
- This tool searches ONLY local/indexed data (uploaded files, Notion, Slack, browser extension captures, etc.).
|
|
||||||
For real-time web search (current events, news, live data), use the `web_search` tool instead.
|
|
||||||
- FALLBACK BEHAVIOR: If the search returns no relevant results, you MAY then answer using your own
|
|
||||||
general knowledge, but clearly indicate that no matching information was found in the knowledge base.
|
|
||||||
- Only narrow to specific connectors if the user explicitly asks (e.g., "check my Slack" or "in my calendar").
|
|
||||||
- Personal notes in Obsidian, Notion, or NOTE often contain schedules, meeting times, reminders, and other
|
|
||||||
important information that may not be in calendars.
|
|
||||||
- Args:
|
|
||||||
- query: The search query - be specific and include key terms
|
|
||||||
- top_k: Number of results to retrieve (default: 10)
|
|
||||||
- start_date: Optional ISO date/datetime (e.g. "2025-12-12" or "2025-12-12T00:00:00+00:00")
|
|
||||||
- end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
|
|
||||||
- connectors_to_search: Optional list of connector enums to search. If omitted, searches all.
|
|
||||||
- Returns: Formatted string with relevant documents and their content
|
|
||||||
"""
|
|
||||||
|
|
||||||
_TOOL_INSTRUCTIONS["generate_podcast"] = """
|
_TOOL_INSTRUCTIONS["generate_podcast"] = """
|
||||||
- generate_podcast: Generate an audio podcast from provided content.
|
- generate_podcast: Generate an audio podcast from provided content.
|
||||||
- Use this when the user asks to create, generate, or make a podcast.
|
- Use this when the user asks to create, generate, or make a podcast.
|
||||||
|
|
@ -163,8 +161,8 @@ _TOOL_INSTRUCTIONS["generate_report"] = """
|
||||||
* For source_strategy="kb_search": Can be empty or minimal — the tool handles searching internally.
|
* For source_strategy="kb_search": Can be empty or minimal — the tool handles searching internally.
|
||||||
* For source_strategy="auto": Include what you have; the tool searches KB if it's not enough.
|
* For source_strategy="auto": Include what you have; the tool searches KB if it's not enough.
|
||||||
- source_strategy: Controls how the tool collects source material. One of:
|
- source_strategy: Controls how the tool collects source material. One of:
|
||||||
* "conversation" — The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content. Do NOT call search_knowledge_base separately.
|
* "conversation" — The conversation already contains enough context (prior Q&A, discussion, pasted text, scraped pages). Pass a thorough summary as source_content.
|
||||||
* "kb_search" — The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries. Do NOT call search_knowledge_base separately.
|
* "kb_search" — The tool will search the knowledge base internally. Provide search_queries with 1-5 targeted queries.
|
||||||
* "auto" — Use source_content if sufficient, otherwise fall back to internal KB search using search_queries.
|
* "auto" — Use source_content if sufficient, otherwise fall back to internal KB search using search_queries.
|
||||||
* "provided" — Use only what is in source_content (default, backward-compatible).
|
* "provided" — Use only what is in source_content (default, backward-compatible).
|
||||||
- search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated.
|
- search_queries: When source_strategy is "kb_search" or "auto", provide 1-5 specific search queries for the knowledge base. These should be precise, not just the topic name repeated.
|
||||||
|
|
@ -176,11 +174,11 @@ _TOOL_INSTRUCTIONS["generate_report"] = """
|
||||||
- The report is generated immediately in Markdown and displayed inline in the chat.
|
- The report is generated immediately in Markdown and displayed inline in the chat.
|
||||||
- Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report.
|
- Export/download formats (PDF, DOCX, HTML, LaTeX, EPUB, ODT, plain text) are produced from the generated Markdown report.
|
||||||
- SOURCE STRATEGY DECISION (HIGH PRIORITY — follow this exactly):
|
- SOURCE STRATEGY DECISION (HIGH PRIORITY — follow this exactly):
|
||||||
* If the conversation already has substantive Q&A / discussion on the topic → use source_strategy="conversation" with a comprehensive summary as source_content. Do NOT call search_knowledge_base first.
|
* If the conversation already has substantive Q&A / discussion on the topic → use source_strategy="conversation" with a comprehensive summary as source_content.
|
||||||
* If the user wants a report on a topic not yet discussed → use source_strategy="kb_search" with targeted search_queries. Do NOT call search_knowledge_base first.
|
* If the user wants a report on a topic not yet discussed → use source_strategy="kb_search" with targeted search_queries.
|
||||||
* If you have some content but might need more → use source_strategy="auto" with both source_content and search_queries.
|
* If you have some content but might need more → use source_strategy="auto" with both source_content and search_queries.
|
||||||
* When revising an existing report (parent_report_id set) and the conversation has relevant context → use source_strategy="conversation". The revision will use the previous report content plus your source_content.
|
* When revising an existing report (parent_report_id set) and the conversation has relevant context → use source_strategy="conversation". The revision will use the previous report content plus your source_content.
|
||||||
* NEVER call search_knowledge_base and then pass its results to generate_report. The tool handles KB search internally.
|
* NEVER run a separate KB lookup step and then pass those results to generate_report. The tool handles KB search internally.
|
||||||
- AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat.
|
- AFTER CALLING THIS TOOL: Do NOT repeat, summarize, or reproduce the report content in the chat. The report is already displayed as an interactive card that the user can open, read, copy, and export. Simply confirm that the report was generated (e.g., "I've generated your report on [topic]. You can view the Markdown report now, and export it in various formats from the card."). NEVER write out the report text in the chat.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -204,7 +202,7 @@ _TOOL_INSTRUCTIONS["scrape_webpage"] = """
|
||||||
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
|
* When a user asks to "get", "fetch", "pull", "grab", "scrape", or "read" content from a URL
|
||||||
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
|
* When the user wants live/dynamic data from a specific webpage (e.g., tables, scores, stats, prices)
|
||||||
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
|
* When a URL was mentioned earlier in the conversation and the user asks for its actual content
|
||||||
* When search_knowledge_base returned insufficient data and the user wants more
|
* When preloaded `/documents/` data is insufficient and the user wants more
|
||||||
- Trigger scenarios:
|
- Trigger scenarios:
|
||||||
* "Read this article and summarize it"
|
* "Read this article and summarize it"
|
||||||
* "What does this page say about X?"
|
* "What does this page say about X?"
|
||||||
|
|
@ -366,23 +364,6 @@ _MEMORY_TOOL_EXAMPLES: dict[str, dict[str, str]] = {
|
||||||
# Per-tool examples keyed by tool name. Only examples for enabled tools are included.
|
# Per-tool examples keyed by tool name. Only examples for enabled tools are included.
|
||||||
_TOOL_EXAMPLES: dict[str, str] = {}
|
_TOOL_EXAMPLES: dict[str, str] = {}
|
||||||
|
|
||||||
_TOOL_EXAMPLES["search_knowledge_base"] = """
|
|
||||||
- User: "What time is the team meeting today?"
|
|
||||||
- Call: `search_knowledge_base(query="team meeting time today")` (searches ALL sources - calendar, notes, Obsidian, etc.)
|
|
||||||
- DO NOT limit to just calendar - the info might be in notes!
|
|
||||||
- User: "When is my gym session?"
|
|
||||||
- Call: `search_knowledge_base(query="gym session time schedule")` (searches ALL sources)
|
|
||||||
- User: "Fetch all my notes and what's in them?"
|
|
||||||
- Call: `search_knowledge_base(query="*", top_k=50, connectors_to_search=["NOTE"])`
|
|
||||||
- User: "What did I discuss on Slack last week about the React migration?"
|
|
||||||
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
|
|
||||||
- User: "Check my Obsidian notes for meeting notes"
|
|
||||||
- Call: `search_knowledge_base(query="meeting notes", connectors_to_search=["OBSIDIAN_CONNECTOR"])`
|
|
||||||
- User: "search me current usd to inr rate"
|
|
||||||
- Call: `web_search(query="current USD to INR exchange rate")`
|
|
||||||
- Then answer using the returned live web results with citations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_TOOL_EXAMPLES["search_surfsense_docs"] = """
|
_TOOL_EXAMPLES["search_surfsense_docs"] = """
|
||||||
- User: "How do I install SurfSense?"
|
- User: "How do I install SurfSense?"
|
||||||
- Call: `search_surfsense_docs(query="installation setup")`
|
- Call: `search_surfsense_docs(query="installation setup")`
|
||||||
|
|
@ -400,8 +381,7 @@ _TOOL_EXAMPLES["generate_podcast"] = """
|
||||||
- User: "Create a podcast summary of this conversation"
|
- User: "Create a podcast summary of this conversation"
|
||||||
- Call: `generate_podcast(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
|
- Call: `generate_podcast(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
|
||||||
- User: "Make a podcast about quantum computing"
|
- User: "Make a podcast about quantum computing"
|
||||||
- First search: `search_knowledge_base(query="quantum computing")`
|
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_podcast(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", podcast_title="Quantum Computing Explained")`
|
||||||
- Then: `generate_podcast(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", podcast_title="Quantum Computing Explained")`
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_TOOL_EXAMPLES["generate_video_presentation"] = """
|
_TOOL_EXAMPLES["generate_video_presentation"] = """
|
||||||
|
|
@ -410,8 +390,7 @@ _TOOL_EXAMPLES["generate_video_presentation"] = """
|
||||||
- User: "Create slides summarizing this conversation"
|
- User: "Create slides summarizing this conversation"
|
||||||
- Call: `generate_video_presentation(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")`
|
- Call: `generate_video_presentation(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", video_title="Conversation Summary")`
|
||||||
- User: "Make a video presentation about quantum computing"
|
- User: "Make a video presentation about quantum computing"
|
||||||
- First search: `search_knowledge_base(query="quantum computing")`
|
- First explore `/documents/` (ls/glob/grep/read_file), then: `generate_video_presentation(source_content="Key insights about quantum computing from retrieved files:\\n\\n[Comprehensive summary of findings]", video_title="Quantum Computing Explained")`
|
||||||
- Then: `generate_video_presentation(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", video_title="Quantum Computing Explained")`
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_TOOL_EXAMPLES["generate_report"] = """
|
_TOOL_EXAMPLES["generate_report"] = """
|
||||||
|
|
@ -471,7 +450,6 @@ _TOOL_EXAMPLES["web_search"] = """
|
||||||
# All tool names that have prompt instructions (order matters for prompt readability)
|
# All tool names that have prompt instructions (order matters for prompt readability)
|
||||||
_ALL_TOOL_NAMES_ORDERED = [
|
_ALL_TOOL_NAMES_ORDERED = [
|
||||||
"search_surfsense_docs",
|
"search_surfsense_docs",
|
||||||
"search_knowledge_base",
|
|
||||||
"web_search",
|
"web_search",
|
||||||
"generate_podcast",
|
"generate_podcast",
|
||||||
"generate_video_presentation",
|
"generate_video_presentation",
|
||||||
|
|
@ -650,87 +628,6 @@ However, from your video learning, it's important to note that asyncio is not su
|
||||||
</citation_instructions>
|
</citation_instructions>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Sandbox / code execution instructions — appended when sandbox backend is enabled.
|
|
||||||
# Inspired by Claude's computer-use prompt, scoped to code execution & data analytics.
|
|
||||||
SANDBOX_EXECUTION_INSTRUCTIONS = """
|
|
||||||
<code_execution>
|
|
||||||
You have access to a secure, isolated Linux sandbox environment for running code and shell commands.
|
|
||||||
This gives you the `execute` tool alongside the standard filesystem tools (`ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep`).
|
|
||||||
|
|
||||||
## CRITICAL — CODE-FIRST RULE
|
|
||||||
|
|
||||||
ALWAYS prefer executing code over giving a text-only response when the user's request involves ANY of the following:
|
|
||||||
- **Creating a chart, plot, graph, or visualization** → Write Python code and generate the actual file. NEVER describe percentages or data in text and offer to "paste into Excel". Just produce the chart.
|
|
||||||
- **Data analysis, statistics, or computation** → Write code to compute the answer. Do not do math by hand in text.
|
|
||||||
- **Generating or transforming files** (CSV, PDF, images, etc.) → Write code to create the file.
|
|
||||||
- **Running, testing, or debugging code** → Execute it in the sandbox.
|
|
||||||
|
|
||||||
This applies even when you first retrieve data from the knowledge base. After `search_knowledge_base` returns relevant data, **immediately proceed to write and execute code** if the user's request matches any of the categories above. Do NOT stop at a text summary and wait for the user to ask you to "use Python" — that extra round-trip is a poor experience.
|
|
||||||
|
|
||||||
Example (CORRECT):
|
|
||||||
User: "Create a pie chart of my benefits"
|
|
||||||
→ 1. search_knowledge_base → retrieve benefits data
|
|
||||||
→ 2. Immediately execute Python code (matplotlib) to generate the pie chart
|
|
||||||
→ 3. Return the downloadable file + brief description
|
|
||||||
|
|
||||||
Example (WRONG):
|
|
||||||
User: "Create a pie chart of my benefits"
|
|
||||||
→ 1. search_knowledge_base → retrieve benefits data
|
|
||||||
→ 2. Print a text table with percentages and ask the user if they want a chart ← NEVER do this
|
|
||||||
|
|
||||||
## When to Use Code Execution
|
|
||||||
|
|
||||||
Use the sandbox when the task benefits from actually running code rather than just describing it:
|
|
||||||
- **Data analysis**: Load CSVs/JSON, compute statistics, filter/aggregate data, pivot tables
|
|
||||||
- **Visualization**: Generate charts and plots (matplotlib, plotly, seaborn)
|
|
||||||
- **Calculations**: Math, financial modeling, unit conversions, simulations
|
|
||||||
- **Code validation**: Run and test code snippets the user provides or asks about
|
|
||||||
- **File processing**: Parse, transform, or convert data files
|
|
||||||
- **Quick prototyping**: Demonstrate working code for the user's problem
|
|
||||||
- **Package exploration**: Install and test libraries the user is evaluating
|
|
||||||
|
|
||||||
## When NOT to Use Code Execution
|
|
||||||
|
|
||||||
Do not use the sandbox for:
|
|
||||||
- Answering factual questions from your own knowledge
|
|
||||||
- Summarizing or explaining concepts
|
|
||||||
- Simple formatting or text generation tasks
|
|
||||||
- Tasks that don't require running code to answer
|
|
||||||
|
|
||||||
## Package Management
|
|
||||||
|
|
||||||
- Use `pip install <package>` to install Python packages as needed
|
|
||||||
- Common data/analytics packages (pandas, numpy, matplotlib, scipy, scikit-learn) may need to be installed on first use
|
|
||||||
- Always verify a package installed successfully before using it
|
|
||||||
|
|
||||||
## Working Guidelines
|
|
||||||
|
|
||||||
- **Working directory**: The shell starts in the sandbox user's home directory (e.g. `/home/daytona`). Use **relative paths** or `/tmp/` for all files you create. NEVER write directly to `/home/` — that is the parent directory and is not writable. Use `pwd` if you need to discover the current working directory.
|
|
||||||
- **Iterative approach**: For complex tasks, break work into steps — write code, run it, check output, refine
|
|
||||||
- **Error handling**: If code fails, read the error, fix the issue, and retry. Don't just report the error without attempting a fix.
|
|
||||||
- **Show results**: When generating plots or outputs, present the key findings directly in your response. For plots, save to a file and describe the results.
|
|
||||||
- **Be efficient**: Install packages once per session. Combine related commands when possible.
|
|
||||||
- **Large outputs**: If command output is very large, use `head`, `tail`, or save to a file and read selectively.
|
|
||||||
|
|
||||||
## Sharing Generated Files
|
|
||||||
|
|
||||||
When your code creates output files (images, CSVs, PDFs, etc.) in the sandbox:
|
|
||||||
- **Print the absolute path** at the end of your script so the user can download the file. Example: `print("SANDBOX_FILE: /tmp/chart.png")`
|
|
||||||
- **DO NOT use markdown image syntax** for files created inside the sandbox. Sandbox files are not accessible via public URLs and will show "Image not available". The frontend automatically renders a download button from the `SANDBOX_FILE:` marker.
|
|
||||||
- You can output multiple files, one per line: `print("SANDBOX_FILE: /tmp/report.csv")`, `print("SANDBOX_FILE: /tmp/chart.png")`
|
|
||||||
- Always describe what the file contains in your response text so the user knows what they are downloading.
|
|
||||||
- IMPORTANT: Every `execute` call that saves a file MUST print the `SANDBOX_FILE: <path>` marker. Without it the user cannot download the file.
|
|
||||||
|
|
||||||
## Data Analytics Best Practices
|
|
||||||
|
|
||||||
When the user asks you to analyze data:
|
|
||||||
1. First, inspect the data structure (`head`, `shape`, `dtypes`, `describe()`)
|
|
||||||
2. Clean and validate before computing (handle nulls, check types)
|
|
||||||
3. Perform the analysis and present results clearly
|
|
||||||
4. Offer follow-up insights or visualizations when appropriate
|
|
||||||
</code_execution>
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Anti-citation prompt - used when citations are disabled
|
# Anti-citation prompt - used when citations are disabled
|
||||||
# This explicitly tells the model NOT to include citations
|
# This explicitly tells the model NOT to include citations
|
||||||
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
|
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
|
||||||
|
|
@ -756,7 +653,6 @@ Your goal is to provide helpful, informative answers in a clean, readable format
|
||||||
def build_surfsense_system_prompt(
|
def build_surfsense_system_prompt(
|
||||||
today: datetime | None = None,
|
today: datetime | None = None,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
sandbox_enabled: bool = False,
|
|
||||||
enabled_tool_names: set[str] | None = None,
|
enabled_tool_names: set[str] | None = None,
|
||||||
disabled_tool_names: set[str] | None = None,
|
disabled_tool_names: set[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
@ -767,12 +663,10 @@ def build_surfsense_system_prompt(
|
||||||
- Default system instructions
|
- Default system instructions
|
||||||
- Tools instructions (only for enabled tools)
|
- Tools instructions (only for enabled tools)
|
||||||
- Citation instructions enabled
|
- Citation instructions enabled
|
||||||
- Sandbox execution instructions (when sandbox_enabled=True)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
today: Optional datetime for today's date (defaults to current UTC date)
|
today: Optional datetime for today's date (defaults to current UTC date)
|
||||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
|
|
||||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||||
|
|
||||||
|
|
@ -786,13 +680,7 @@ def build_surfsense_system_prompt(
|
||||||
visibility, enabled_tool_names, disabled_tool_names
|
visibility, enabled_tool_names, disabled_tool_names
|
||||||
)
|
)
|
||||||
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
|
citation_instructions = SURFSENSE_CITATION_INSTRUCTIONS
|
||||||
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
|
return system_instructions + tools_instructions + citation_instructions
|
||||||
return (
|
|
||||||
system_instructions
|
|
||||||
+ tools_instructions
|
|
||||||
+ citation_instructions
|
|
||||||
+ sandbox_instructions
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_configurable_system_prompt(
|
def build_configurable_system_prompt(
|
||||||
|
|
@ -801,18 +689,16 @@ def build_configurable_system_prompt(
|
||||||
citations_enabled: bool = True,
|
citations_enabled: bool = True,
|
||||||
today: datetime | None = None,
|
today: datetime | None = None,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
sandbox_enabled: bool = False,
|
|
||||||
enabled_tool_names: set[str] | None = None,
|
enabled_tool_names: set[str] | None = None,
|
||||||
disabled_tool_names: set[str] | None = None,
|
disabled_tool_names: set[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
Build a configurable SurfSense system prompt based on NewLLMConfig settings.
|
||||||
|
|
||||||
The prompt is composed of up to four parts:
|
The prompt is composed of three parts:
|
||||||
1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
|
1. System Instructions - either custom or default SURFSENSE_SYSTEM_INSTRUCTIONS
|
||||||
2. Tools Instructions - only for enabled tools, with a note about disabled ones
|
2. Tools Instructions - only for enabled tools, with a note about disabled ones
|
||||||
3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
|
3. Citation Instructions - either SURFSENSE_CITATION_INSTRUCTIONS or SURFSENSE_NO_CITATION_INSTRUCTIONS
|
||||||
4. Sandbox Execution Instructions - when sandbox_enabled=True
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
custom_system_instructions: Custom system instructions to use. If empty/None and
|
custom_system_instructions: Custom system instructions to use. If empty/None and
|
||||||
|
|
@ -824,7 +710,6 @@ def build_configurable_system_prompt(
|
||||||
anti-citation instructions (False).
|
anti-citation instructions (False).
|
||||||
today: Optional datetime for today's date (defaults to current UTC date)
|
today: Optional datetime for today's date (defaults to current UTC date)
|
||||||
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
thread_visibility: Optional; when provided, used for conditional prompt (e.g. private vs shared memory wording). Defaults to private behavior when None.
|
||||||
sandbox_enabled: Whether the sandbox backend is active (adds code execution instructions).
|
|
||||||
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
enabled_tool_names: Set of tool names actually bound to the agent. When None all tools are included.
|
||||||
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
disabled_tool_names: Set of tool names the user explicitly disabled. Included as a note so the model can inform the user.
|
||||||
|
|
||||||
|
|
@ -856,14 +741,7 @@ def build_configurable_system_prompt(
|
||||||
else SURFSENSE_NO_CITATION_INSTRUCTIONS
|
else SURFSENSE_NO_CITATION_INSTRUCTIONS
|
||||||
)
|
)
|
||||||
|
|
||||||
sandbox_instructions = SANDBOX_EXECUTION_INSTRUCTIONS if sandbox_enabled else ""
|
return system_instructions + tools_instructions + citation_instructions
|
||||||
|
|
||||||
return (
|
|
||||||
system_instructions
|
|
||||||
+ tools_instructions
|
|
||||||
+ citation_instructions
|
|
||||||
+ sandbox_instructions
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_system_instructions() -> str:
|
def get_default_system_instructions() -> str:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ This module contains all the tools available to the SurfSense agent.
|
||||||
To add a new tool, see the documentation in registry.py.
|
To add a new tool, see the documentation in registry.py.
|
||||||
|
|
||||||
Available tools:
|
Available tools:
|
||||||
- search_knowledge_base: Search the user's personal knowledge base
|
|
||||||
- search_surfsense_docs: Search Surfsense documentation for usage help
|
- search_surfsense_docs: Search Surfsense documentation for usage help
|
||||||
- generate_podcast: Generate audio podcasts from content
|
- generate_podcast: Generate audio podcasts from content
|
||||||
- generate_video_presentation: Generate video presentations with slides and narration
|
- generate_video_presentation: Generate video presentations with slides and narration
|
||||||
|
|
@ -20,7 +19,6 @@ Available tools:
|
||||||
from .generate_image import create_generate_image_tool
|
from .generate_image import create_generate_image_tool
|
||||||
from .knowledge_base import (
|
from .knowledge_base import (
|
||||||
CONNECTOR_DESCRIPTIONS,
|
CONNECTOR_DESCRIPTIONS,
|
||||||
create_search_knowledge_base_tool,
|
|
||||||
format_documents_for_context,
|
format_documents_for_context,
|
||||||
search_knowledge_base_async,
|
search_knowledge_base_async,
|
||||||
)
|
)
|
||||||
|
|
@ -52,7 +50,6 @@ __all__ = [
|
||||||
"create_recall_memory_tool",
|
"create_recall_memory_tool",
|
||||||
"create_save_memory_tool",
|
"create_save_memory_tool",
|
||||||
"create_scrape_webpage_tool",
|
"create_scrape_webpage_tool",
|
||||||
"create_search_knowledge_base_tool",
|
|
||||||
"create_search_surfsense_docs_tool",
|
"create_search_surfsense_docs_tool",
|
||||||
"format_documents_for_context",
|
"format_documents_for_context",
|
||||||
"get_all_tool_names",
|
"get_all_tool_names",
|
||||||
|
|
|
||||||
|
|
@ -273,9 +273,7 @@ def create_update_calendar_event_tool(
|
||||||
final_new_start_datetime, context
|
final_new_start_datetime, context
|
||||||
)
|
)
|
||||||
if final_new_end_datetime is not None:
|
if final_new_end_datetime is not None:
|
||||||
update_body["end"] = _build_time_body(
|
update_body["end"] = _build_time_body(final_new_end_datetime, context)
|
||||||
final_new_end_datetime, context
|
|
||||||
)
|
|
||||||
if final_new_description is not None:
|
if final_new_description is not None:
|
||||||
update_body["description"] = final_new_description
|
update_body["description"] = final_new_description
|
||||||
if final_new_location is not None:
|
if final_new_location is not None:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ This module provides:
|
||||||
- Connector constants and normalization
|
- Connector constants and normalization
|
||||||
- Async knowledge base search across multiple connectors
|
- Async knowledge base search across multiple connectors
|
||||||
- Document formatting for LLM context
|
- Document formatting for LLM context
|
||||||
- Tool factory for creating search_knowledge_base tools
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
@ -16,8 +15,6 @@ import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import StructuredTool
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
|
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
|
||||||
|
|
@ -619,9 +616,76 @@ async def search_knowledge_base_async(
|
||||||
perf = get_perf_logger()
|
perf = get_perf_logger()
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
|
deduplicated = await search_knowledge_base_raw_async(
|
||||||
|
query=query,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
db_session=db_session,
|
||||||
|
connector_service=connector_service,
|
||||||
|
connectors_to_search=connectors_to_search,
|
||||||
|
top_k=top_k,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
available_connectors=available_connectors,
|
||||||
|
available_document_types=available_document_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not deduplicated:
|
||||||
|
return "No documents found in the knowledge base. The search space has no indexed content yet."
|
||||||
|
|
||||||
|
# Use browse chunk cap for degenerate queries, otherwise adaptive chunking.
|
||||||
|
max_chunks_per_doc = (
|
||||||
|
_BROWSE_MAX_CHUNKS_PER_DOC if _is_degenerate_query(query) else 0
|
||||||
|
)
|
||||||
|
output_budget = _compute_tool_output_budget(max_input_tokens)
|
||||||
|
result = format_documents_for_context(
|
||||||
|
deduplicated,
|
||||||
|
max_chars=output_budget,
|
||||||
|
max_chunks_per_doc=max_chunks_per_doc,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(result) > output_budget:
|
||||||
|
perf.warning(
|
||||||
|
"[kb_search] output STILL exceeds budget after format (%d > %d), "
|
||||||
|
"hard truncation should have fired",
|
||||||
|
len(result),
|
||||||
|
output_budget,
|
||||||
|
)
|
||||||
|
|
||||||
|
perf.info(
|
||||||
|
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
|
||||||
|
"budget=%d max_input_tokens=%s space=%d",
|
||||||
|
time.perf_counter() - t0,
|
||||||
|
len(deduplicated),
|
||||||
|
len(deduplicated),
|
||||||
|
len(result),
|
||||||
|
output_budget,
|
||||||
|
max_input_tokens,
|
||||||
|
search_space_id,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def search_knowledge_base_raw_async(
|
||||||
|
query: str,
|
||||||
|
search_space_id: int,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
connector_service: ConnectorService,
|
||||||
|
connectors_to_search: list[str] | None = None,
|
||||||
|
top_k: int = 10,
|
||||||
|
start_date: datetime | None = None,
|
||||||
|
end_date: datetime | None = None,
|
||||||
|
available_connectors: list[str] | None = None,
|
||||||
|
available_document_types: list[str] | None = None,
|
||||||
|
query_embedding: list[float] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Search knowledge base and return raw document dicts (no XML formatting)."""
|
||||||
|
perf = get_perf_logger()
|
||||||
|
t0 = time.perf_counter()
|
||||||
all_documents: list[dict[str, Any]] = []
|
all_documents: list[dict[str, Any]] = []
|
||||||
|
|
||||||
# Resolve date range (default last 2 years)
|
# Preserve the public signature for compatibility even if values are unused.
|
||||||
|
_ = (db_session, connector_service)
|
||||||
|
|
||||||
from app.agents.new_chat.utils import resolve_date_range
|
from app.agents.new_chat.utils import resolve_date_range
|
||||||
|
|
||||||
resolved_start_date, resolved_end_date = resolve_date_range(
|
resolved_start_date, resolved_end_date = resolve_date_range(
|
||||||
|
|
@ -631,144 +695,76 @@ async def search_knowledge_base_async(
|
||||||
|
|
||||||
connectors = _normalize_connectors(connectors_to_search, available_connectors)
|
connectors = _normalize_connectors(connectors_to_search, available_connectors)
|
||||||
|
|
||||||
# --- Optimization 1: skip connectors that have zero indexed documents ---
|
|
||||||
if available_document_types:
|
if available_document_types:
|
||||||
doc_types_set = set(available_document_types)
|
doc_types_set = set(available_document_types)
|
||||||
before_count = len(connectors)
|
|
||||||
connectors = [
|
connectors = [
|
||||||
c
|
c
|
||||||
for c in connectors
|
for c in connectors
|
||||||
if c in doc_types_set
|
if c in doc_types_set
|
||||||
or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
|
or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
|
||||||
]
|
]
|
||||||
skipped = before_count - len(connectors)
|
|
||||||
if skipped:
|
|
||||||
perf.info(
|
|
||||||
"[kb_search] skipped %d empty connectors (had %d, now %d)",
|
|
||||||
skipped,
|
|
||||||
before_count,
|
|
||||||
len(connectors),
|
|
||||||
)
|
|
||||||
|
|
||||||
perf.info(
|
|
||||||
"[kb_search] searching %d connectors: %s (space=%d, top_k=%d)",
|
|
||||||
len(connectors),
|
|
||||||
connectors[:5],
|
|
||||||
search_space_id,
|
|
||||||
top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Fast-path: no connectors left after filtering ---
|
|
||||||
if not connectors:
|
if not connectors:
|
||||||
perf.info(
|
return []
|
||||||
"[kb_search] TOTAL in %.3fs — no connectors to search, returning empty",
|
|
||||||
time.perf_counter() - t0,
|
|
||||||
)
|
|
||||||
return "No documents found in the knowledge base. The search space has no indexed content yet."
|
|
||||||
|
|
||||||
# --- Fast-path: degenerate queries (*, **, empty, etc.) ---
|
|
||||||
# Semantic embedding of '*' is noise and plainto_tsquery('english', '*')
|
|
||||||
# yields an empty tsquery, so both retrieval signals are useless.
|
|
||||||
# Fall back to a recency-ordered browse that returns diverse results.
|
|
||||||
if _is_degenerate_query(query):
|
if _is_degenerate_query(query):
|
||||||
perf.info(
|
perf.info(
|
||||||
"[kb_search] degenerate query %r detected - falling back to recency browse",
|
"[kb_search_raw] degenerate query %r detected - recency browse",
|
||||||
query,
|
query,
|
||||||
)
|
)
|
||||||
browse_connectors = connectors if connectors else [None] # type: ignore[list-item]
|
browse_connectors = connectors if connectors else [None] # type: ignore[list-item]
|
||||||
|
|
||||||
expanded_browse = []
|
expanded_browse = []
|
||||||
for c in browse_connectors:
|
for connector in browse_connectors:
|
||||||
if c is not None and c in NATIVE_TO_LEGACY_DOCTYPE:
|
if connector is not None and connector in NATIVE_TO_LEGACY_DOCTYPE:
|
||||||
expanded_browse.append([c, NATIVE_TO_LEGACY_DOCTYPE[c]])
|
expanded_browse.append([connector, NATIVE_TO_LEGACY_DOCTYPE[connector]])
|
||||||
else:
|
else:
|
||||||
expanded_browse.append(c)
|
expanded_browse.append(connector)
|
||||||
|
|
||||||
browse_results = await asyncio.gather(
|
browse_results = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
_browse_recent_documents(
|
_browse_recent_documents(
|
||||||
search_space_id=search_space_id,
|
|
||||||
document_type=c,
|
|
||||||
top_k=top_k,
|
|
||||||
start_date=resolved_start_date,
|
|
||||||
end_date=resolved_end_date,
|
|
||||||
)
|
|
||||||
for c in expanded_browse
|
|
||||||
]
|
|
||||||
)
|
|
||||||
for docs in browse_results:
|
|
||||||
all_documents.extend(docs)
|
|
||||||
|
|
||||||
# Skip dedup + formatting below (browse already returns unique docs)
|
|
||||||
# but still cap output budget.
|
|
||||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
|
||||||
result = format_documents_for_context(
|
|
||||||
all_documents,
|
|
||||||
max_chars=output_budget,
|
|
||||||
max_chunks_per_doc=_BROWSE_MAX_CHUNKS_PER_DOC,
|
|
||||||
)
|
|
||||||
perf.info(
|
|
||||||
"[kb_search] TOTAL (browse) in %.3fs total_docs=%d output_chars=%d "
|
|
||||||
"budget=%d space=%d",
|
|
||||||
time.perf_counter() - t0,
|
|
||||||
len(all_documents),
|
|
||||||
len(result),
|
|
||||||
output_budget,
|
|
||||||
search_space_id,
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
# --- Optimization 2: compute the query embedding once, share across all local searches ---
|
|
||||||
from app.config import config as app_config
|
|
||||||
|
|
||||||
t_embed = time.perf_counter()
|
|
||||||
precomputed_embedding = app_config.embedding_model_instance.embed(query)
|
|
||||||
perf.info(
|
|
||||||
"[kb_search] shared embedding computed in %.3fs",
|
|
||||||
time.perf_counter() - t_embed,
|
|
||||||
)
|
|
||||||
|
|
||||||
max_parallel_searches = 4
|
|
||||||
semaphore = asyncio.Semaphore(max_parallel_searches)
|
|
||||||
|
|
||||||
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
|
|
||||||
try:
|
|
||||||
t_conn = time.perf_counter()
|
|
||||||
async with semaphore, shielded_async_session() as isolated_session:
|
|
||||||
svc = ConnectorService(isolated_session, search_space_id)
|
|
||||||
chunks = await svc._combined_rrf_search(
|
|
||||||
query_text=query,
|
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
document_type=connector,
|
document_type=connector,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
start_date=resolved_start_date,
|
start_date=resolved_start_date,
|
||||||
end_date=resolved_end_date,
|
end_date=resolved_end_date,
|
||||||
query_embedding=precomputed_embedding,
|
|
||||||
)
|
)
|
||||||
perf.info(
|
for connector in expanded_browse
|
||||||
"[kb_search] connector=%s results=%d in %.3fs",
|
]
|
||||||
connector,
|
)
|
||||||
len(chunks),
|
for docs in browse_results:
|
||||||
time.perf_counter() - t_conn,
|
all_documents.extend(docs)
|
||||||
)
|
else:
|
||||||
return chunks
|
if query_embedding is None:
|
||||||
except Exception as e:
|
from app.config import config as app_config
|
||||||
perf.warning("[kb_search] connector=%s FAILED: %s", connector, e)
|
|
||||||
return []
|
|
||||||
|
|
||||||
t_gather = time.perf_counter()
|
query_embedding = app_config.embedding_model_instance.embed(query)
|
||||||
connector_results = await asyncio.gather(
|
|
||||||
*[_search_one_connector(connector) for connector in connectors]
|
max_parallel_searches = 4
|
||||||
)
|
semaphore = asyncio.Semaphore(max_parallel_searches)
|
||||||
perf.info(
|
|
||||||
"[kb_search] all connectors gathered in %.3fs",
|
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
|
||||||
time.perf_counter() - t_gather,
|
try:
|
||||||
)
|
async with semaphore, shielded_async_session() as isolated_session:
|
||||||
for chunks in connector_results:
|
svc = ConnectorService(isolated_session, search_space_id)
|
||||||
all_documents.extend(chunks)
|
return await svc._combined_rrf_search(
|
||||||
|
query_text=query,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
document_type=connector,
|
||||||
|
top_k=top_k,
|
||||||
|
start_date=resolved_start_date,
|
||||||
|
end_date=resolved_end_date,
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
perf.warning("[kb_search_raw] connector=%s FAILED: %s", connector, exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
connector_results = await asyncio.gather(
|
||||||
|
*[_search_one_connector(connector) for connector in connectors]
|
||||||
|
)
|
||||||
|
for docs in connector_results:
|
||||||
|
all_documents.extend(docs)
|
||||||
|
|
||||||
# Deduplicate primarily by document ID. Only fall back to content hashing
|
|
||||||
# when a document has no ID.
|
|
||||||
seen_doc_ids: set[Any] = set()
|
seen_doc_ids: set[Any] = set()
|
||||||
seen_content_hashes: set[int] = set()
|
seen_content_hashes: set[int] = set()
|
||||||
deduplicated: list[dict[str, Any]] = []
|
deduplicated: list[dict[str, Any]] = []
|
||||||
|
|
@ -785,7 +781,6 @@ async def search_knowledge_base_async(
|
||||||
chunk_texts.append(chunk_content)
|
chunk_texts.append(chunk_content)
|
||||||
if chunk_texts:
|
if chunk_texts:
|
||||||
return hash("||".join(chunk_texts))
|
return hash("||".join(chunk_texts))
|
||||||
|
|
||||||
flat_content = (document.get("content") or "").strip()
|
flat_content = (document.get("content") or "").strip()
|
||||||
if flat_content:
|
if flat_content:
|
||||||
return hash(flat_content)
|
return hash(flat_content)
|
||||||
|
|
@ -793,216 +788,24 @@ async def search_knowledge_base_async(
|
||||||
|
|
||||||
for doc in all_documents:
|
for doc in all_documents:
|
||||||
doc_id = (doc.get("document", {}) or {}).get("id")
|
doc_id = (doc.get("document", {}) or {}).get("id")
|
||||||
|
|
||||||
if doc_id is not None:
|
if doc_id is not None:
|
||||||
if doc_id in seen_doc_ids:
|
if doc_id in seen_doc_ids:
|
||||||
continue
|
continue
|
||||||
seen_doc_ids.add(doc_id)
|
seen_doc_ids.add(doc_id)
|
||||||
deduplicated.append(doc)
|
deduplicated.append(doc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
content_hash = _content_fingerprint(doc)
|
content_hash = _content_fingerprint(doc)
|
||||||
|
if content_hash is not None and content_hash in seen_content_hashes:
|
||||||
|
continue
|
||||||
if content_hash is not None:
|
if content_hash is not None:
|
||||||
if content_hash in seen_content_hashes:
|
|
||||||
continue
|
|
||||||
seen_content_hashes.add(content_hash)
|
seen_content_hashes.add(content_hash)
|
||||||
|
|
||||||
deduplicated.append(doc)
|
deduplicated.append(doc)
|
||||||
|
|
||||||
# Sort by RRF score so the most relevant documents from ANY connector
|
deduplicated.sort(key=lambda doc: doc.get("score", 0), reverse=True)
|
||||||
# appear first, preventing budget truncation from hiding top results.
|
|
||||||
deduplicated.sort(key=lambda d: d.get("score", 0), reverse=True)
|
|
||||||
|
|
||||||
output_budget = _compute_tool_output_budget(max_input_tokens)
|
|
||||||
result = format_documents_for_context(deduplicated, max_chars=output_budget)
|
|
||||||
|
|
||||||
if len(result) > output_budget:
|
|
||||||
perf.warning(
|
|
||||||
"[kb_search] output STILL exceeds budget after format (%d > %d), "
|
|
||||||
"hard truncation should have fired",
|
|
||||||
len(result),
|
|
||||||
output_budget,
|
|
||||||
)
|
|
||||||
|
|
||||||
perf.info(
|
perf.info(
|
||||||
"[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
|
"[kb_search_raw] done in %.3fs total=%d deduped=%d",
|
||||||
"budget=%d max_input_tokens=%s space=%d",
|
|
||||||
time.perf_counter() - t0,
|
time.perf_counter() - t0,
|
||||||
len(all_documents),
|
len(all_documents),
|
||||||
len(deduplicated),
|
len(deduplicated),
|
||||||
len(result),
|
|
||||||
output_budget,
|
|
||||||
max_input_tokens,
|
|
||||||
search_space_id,
|
|
||||||
)
|
)
|
||||||
return result
|
return deduplicated
|
||||||
|
|
||||||
|
|
||||||
def _build_connector_docstring(available_connectors: list[str] | None) -> str:
|
|
||||||
"""
|
|
||||||
Build the connector documentation section for the tool docstring.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
available_connectors: List of available connector types, or None for all
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted docstring section listing available connectors
|
|
||||||
"""
|
|
||||||
connectors = available_connectors if available_connectors else list(_ALL_CONNECTORS)
|
|
||||||
|
|
||||||
lines = []
|
|
||||||
for connector in connectors:
|
|
||||||
# Skip internal names, prefer user-facing aliases
|
|
||||||
if connector == "CRAWLED_URL":
|
|
||||||
# Show as WEBCRAWLER_CONNECTOR for user-facing docs
|
|
||||||
description = CONNECTOR_DESCRIPTIONS.get(connector, connector)
|
|
||||||
lines.append(f"- WEBCRAWLER_CONNECTOR: {description}")
|
|
||||||
else:
|
|
||||||
description = CONNECTOR_DESCRIPTIONS.get(connector, connector)
|
|
||||||
lines.append(f"- {connector}: {description}")
|
|
||||||
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Tool Input Schema
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class SearchKnowledgeBaseInput(BaseModel):
|
|
||||||
"""Input schema for the search_knowledge_base tool."""
|
|
||||||
|
|
||||||
query: str = Field(
|
|
||||||
description=(
|
|
||||||
"The search query - use specific natural language terms. "
|
|
||||||
"NEVER use wildcards like '*' or '**'; instead describe what you want "
|
|
||||||
"(e.g. 'recent meeting notes' or 'project architecture overview')."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
top_k: int = Field(
|
|
||||||
default=10,
|
|
||||||
description="Number of results to retrieve (default: 10). Keep ≤20 for focused searches.",
|
|
||||||
)
|
|
||||||
start_date: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Optional ISO date/datetime (e.g. '2025-12-12' or '2025-12-12T00:00:00+00:00')",
|
|
||||||
)
|
|
||||||
end_date: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Optional ISO date/datetime (e.g. '2025-12-19' or '2025-12-19T23:59:59+00:00')",
|
|
||||||
)
|
|
||||||
connectors_to_search: list[str] | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Optional list of connector enums to search. If omitted, searches all available.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_search_knowledge_base_tool(
|
|
||||||
search_space_id: int,
|
|
||||||
db_session: AsyncSession,
|
|
||||||
connector_service: ConnectorService,
|
|
||||||
available_connectors: list[str] | None = None,
|
|
||||||
available_document_types: list[str] | None = None,
|
|
||||||
max_input_tokens: int | None = None,
|
|
||||||
) -> StructuredTool:
|
|
||||||
"""
|
|
||||||
Factory function to create the search_knowledge_base tool with injected dependencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
search_space_id: The user's search space ID
|
|
||||||
db_session: Database session
|
|
||||||
connector_service: Initialized connector service
|
|
||||||
available_connectors: Optional list of connector types available in the search space.
|
|
||||||
Used to dynamically generate the tool docstring.
|
|
||||||
available_document_types: Optional list of document types that have data in the search space.
|
|
||||||
Used to inform the LLM about what data exists.
|
|
||||||
max_input_tokens: Model context window (tokens) from litellm model info.
|
|
||||||
Used to dynamically size tool output.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A configured StructuredTool instance
|
|
||||||
"""
|
|
||||||
# Build connector documentation dynamically
|
|
||||||
connector_docs = _build_connector_docstring(available_connectors)
|
|
||||||
|
|
||||||
# Build context about available document types
|
|
||||||
doc_types_info = ""
|
|
||||||
if available_document_types:
|
|
||||||
doc_types_info = f"""
|
|
||||||
|
|
||||||
## Document types with indexed content in this search space
|
|
||||||
|
|
||||||
The following document types have content available for search:
|
|
||||||
{", ".join(available_document_types)}
|
|
||||||
|
|
||||||
Focus searches on these types for best results."""
|
|
||||||
|
|
||||||
# Build the dynamic description for the tool
|
|
||||||
# This is what the LLM sees when deciding whether/how to use the tool
|
|
||||||
dynamic_description = f"""Search the user's personal knowledge base for relevant information.
|
|
||||||
|
|
||||||
Use this tool to find documents, notes, files, web pages, and other content the user has indexed.
|
|
||||||
This searches ONLY local/indexed data (uploaded files, Notion, Slack, browser extension captures, etc.).
|
|
||||||
For real-time web search (current events, news, live data), use the `web_search` tool instead.
|
|
||||||
|
|
||||||
IMPORTANT:
|
|
||||||
- Always craft specific, descriptive search queries using natural language keywords.
|
|
||||||
Good: "quarterly sales report Q3", "Python API authentication design".
|
|
||||||
Bad: "*", "**", "everything", single characters. Wildcard/empty queries yield poor results.
|
|
||||||
- Prefer multiple focused searches over a single broad one with high top_k.
|
|
||||||
- If the user requests a specific source type (e.g. "my notes", "Slack messages"), pass `connectors_to_search=[...]` using the enums below.
|
|
||||||
- If `connectors_to_search` is omitted/empty, the system will search broadly.
|
|
||||||
- Only connectors that are enabled/configured for this search space are available.{doc_types_info}
|
|
||||||
|
|
||||||
## Available connector enums for `connectors_to_search`
|
|
||||||
|
|
||||||
{connector_docs}
|
|
||||||
|
|
||||||
NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type `CRAWLED_URL`."""
|
|
||||||
|
|
||||||
# Capture for closure
|
|
||||||
_available_connectors = available_connectors
|
|
||||||
_available_document_types = available_document_types
|
|
||||||
|
|
||||||
async def _search_knowledge_base_impl(
|
|
||||||
query: str,
|
|
||||||
top_k: int = 10,
|
|
||||||
start_date: str | None = None,
|
|
||||||
end_date: str | None = None,
|
|
||||||
connectors_to_search: list[str] | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Implementation function for knowledge base search."""
|
|
||||||
from app.agents.new_chat.utils import parse_date_or_datetime
|
|
||||||
|
|
||||||
parsed_start: datetime | None = None
|
|
||||||
parsed_end: datetime | None = None
|
|
||||||
|
|
||||||
if start_date:
|
|
||||||
parsed_start = parse_date_or_datetime(start_date)
|
|
||||||
if end_date:
|
|
||||||
parsed_end = parse_date_or_datetime(end_date)
|
|
||||||
|
|
||||||
return await search_knowledge_base_async(
|
|
||||||
query=query,
|
|
||||||
search_space_id=search_space_id,
|
|
||||||
db_session=db_session,
|
|
||||||
connector_service=connector_service,
|
|
||||||
connectors_to_search=connectors_to_search,
|
|
||||||
top_k=top_k,
|
|
||||||
start_date=parsed_start,
|
|
||||||
end_date=parsed_end,
|
|
||||||
available_connectors=_available_connectors,
|
|
||||||
available_document_types=_available_document_types,
|
|
||||||
max_input_tokens=max_input_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create StructuredTool with dynamic description
|
|
||||||
# This properly sets the description that the LLM sees
|
|
||||||
tool = StructuredTool(
|
|
||||||
name="search_knowledge_base",
|
|
||||||
description=dynamic_description,
|
|
||||||
coroutine=_search_knowledge_base_impl,
|
|
||||||
args_schema=SearchKnowledgeBaseInput,
|
|
||||||
)
|
|
||||||
|
|
||||||
return tool
|
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,6 @@ from .jira import (
|
||||||
create_delete_jira_issue_tool,
|
create_delete_jira_issue_tool,
|
||||||
create_update_jira_issue_tool,
|
create_update_jira_issue_tool,
|
||||||
)
|
)
|
||||||
from .knowledge_base import create_search_knowledge_base_tool
|
|
||||||
from .linear import (
|
from .linear import (
|
||||||
create_create_linear_issue_tool,
|
create_create_linear_issue_tool,
|
||||||
create_delete_linear_issue_tool,
|
create_delete_linear_issue_tool,
|
||||||
|
|
@ -128,23 +127,6 @@ class ToolDefinition:
|
||||||
# Registry of all built-in tools
|
# Registry of all built-in tools
|
||||||
# Contributors: Add your new tools here!
|
# Contributors: Add your new tools here!
|
||||||
BUILTIN_TOOLS: list[ToolDefinition] = [
|
BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
# Core tool - searches the user's knowledge base
|
|
||||||
# Now supports dynamic connector/document type discovery
|
|
||||||
ToolDefinition(
|
|
||||||
name="search_knowledge_base",
|
|
||||||
description="Search the user's personal knowledge base for relevant information",
|
|
||||||
factory=lambda deps: create_search_knowledge_base_tool(
|
|
||||||
search_space_id=deps["search_space_id"],
|
|
||||||
db_session=deps["db_session"],
|
|
||||||
connector_service=deps["connector_service"],
|
|
||||||
# Optional: dynamically discovered connectors/document types
|
|
||||||
available_connectors=deps.get("available_connectors"),
|
|
||||||
available_document_types=deps.get("available_document_types"),
|
|
||||||
max_input_tokens=deps.get("max_input_tokens"),
|
|
||||||
),
|
|
||||||
requires=["search_space_id", "db_session", "connector_service"],
|
|
||||||
# Note: available_connectors and available_document_types are optional
|
|
||||||
),
|
|
||||||
# Podcast generation tool
|
# Podcast generation tool
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="generate_podcast",
|
name="generate_podcast",
|
||||||
|
|
@ -168,8 +150,8 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
requires=["search_space_id", "db_session", "thread_id"],
|
requires=["search_space_id", "db_session", "thread_id"],
|
||||||
),
|
),
|
||||||
# Report generation tool (inline, short-lived sessions for DB ops)
|
# Report generation tool (inline, short-lived sessions for DB ops)
|
||||||
# Supports internal KB search via source_strategy so the agent doesn't
|
# Supports internal KB search via source_strategy so the agent does not
|
||||||
# need to call search_knowledge_base separately before generating.
|
# need a separate search step before generating.
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="generate_report",
|
name="generate_report",
|
||||||
description="Generate a structured report from provided content and export it",
|
description="Generate a structured report from provided content and export it",
|
||||||
|
|
@ -551,7 +533,7 @@ def build_tools(
|
||||||
tools = build_tools(deps)
|
tools = build_tools(deps)
|
||||||
|
|
||||||
# Use only specific tools
|
# Use only specific tools
|
||||||
tools = build_tools(deps, enabled_tools=["search_knowledge_base"])
|
tools = build_tools(deps, enabled_tools=["generate_report"])
|
||||||
|
|
||||||
# Use defaults but disable podcast
|
# Use defaults but disable podcast
|
||||||
tools = build_tools(deps, disabled_tools=["generate_podcast"])
|
tools = build_tools(deps, disabled_tools=["generate_podcast"])
|
||||||
|
|
|
||||||
|
|
@ -584,8 +584,8 @@ def create_generate_report_tool(
|
||||||
search_space_id: The user's search space ID
|
search_space_id: The user's search space ID
|
||||||
thread_id: The chat thread ID for associating the report
|
thread_id: The chat thread ID for associating the report
|
||||||
connector_service: Optional connector service for internal KB search.
|
connector_service: Optional connector service for internal KB search.
|
||||||
When provided, the tool can search the knowledge base without the
|
When provided, the tool can search the knowledge base internally
|
||||||
agent having to call search_knowledge_base separately.
|
(used by the "kb_search" and "auto" source strategies).
|
||||||
available_connectors: Optional list of connector types available in the
|
available_connectors: Optional list of connector types available in the
|
||||||
search space (used to scope internal KB searches).
|
search space (used to scope internal KB searches).
|
||||||
|
|
||||||
|
|
@ -639,12 +639,13 @@ def create_generate_report_tool(
|
||||||
|
|
||||||
SOURCE STRATEGY (how to collect source material):
|
SOURCE STRATEGY (how to collect source material):
|
||||||
- source_strategy="conversation" — The conversation already has
|
- source_strategy="conversation" — The conversation already has
|
||||||
enough context (prior Q&A, pasted text, uploaded files, scraped
|
enough context (prior Q&A, filesystem exploration, pasted text,
|
||||||
webpages). Pass a thorough summary as source_content.
|
uploaded files, scraped webpages). Pass a thorough summary as
|
||||||
NEVER call search_knowledge_base separately first.
|
source_content.
|
||||||
- source_strategy="kb_search" — Search the knowledge base
|
- source_strategy="kb_search" — Search the knowledge base
|
||||||
internally. Provide 1-5 targeted search_queries. The tool
|
internally. Provide 1-5 targeted search_queries. The tool
|
||||||
handles searching — do NOT call search_knowledge_base first.
|
handles searching internally — do NOT manually read and dump
|
||||||
|
/documents/ files into source_content.
|
||||||
- source_strategy="provided" — Use only what is in source_content
|
- source_strategy="provided" — Use only what is in source_content
|
||||||
(default, backward-compatible).
|
(default, backward-compatible).
|
||||||
- source_strategy="auto" — Use source_content if it has enough
|
- source_strategy="auto" — Use source_content if it has enough
|
||||||
|
|
@ -1064,6 +1065,7 @@ def create_generate_report_tool(
|
||||||
"title": topic,
|
"title": topic,
|
||||||
"word_count": metadata.get("word_count", 0),
|
"word_count": metadata.get("word_count", 0),
|
||||||
"is_revision": bool(parent_report_content),
|
"is_revision": bool(parent_report_content),
|
||||||
|
"report_markdown": report_content,
|
||||||
"message": f"Report generated successfully: {topic}",
|
"message": f"Report generated successfully: {topic}",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -137,9 +137,7 @@ async def _filter_changes_by_folder(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parents = file.get("parents", [])
|
parents = file.get("parents", [])
|
||||||
if folder_id in parents:
|
if folder_id in parents or await _is_descendant_of(client, parents, folder_id):
|
||||||
filtered.append(change)
|
|
||||||
elif await _is_descendant_of(client, parents, folder_id):
|
|
||||||
filtered.append(change)
|
filtered.append(change)
|
||||||
|
|
||||||
return filtered
|
return filtered
|
||||||
|
|
|
||||||
|
|
@ -157,7 +157,9 @@ class GoogleDriveClient:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sync_download_file(
|
def _sync_download_file(
|
||||||
service, file_id: str, credentials: Credentials,
|
service,
|
||||||
|
file_id: str,
|
||||||
|
credentials: Credentials,
|
||||||
) -> tuple[bytes | None, str | None]:
|
) -> tuple[bytes | None, str | None]:
|
||||||
"""Blocking download — runs on a worker thread via ``to_thread``."""
|
"""Blocking download — runs on a worker thread via ``to_thread``."""
|
||||||
thread = threading.current_thread().name
|
thread = threading.current_thread().name
|
||||||
|
|
@ -180,7 +182,9 @@ class GoogleDriveClient:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None, f"Error downloading file: {e!s}"
|
return None, f"Error downloading file: {e!s}"
|
||||||
finally:
|
finally:
|
||||||
logger.info(f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
|
logger.info(
|
||||||
|
f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
|
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -194,12 +198,18 @@ class GoogleDriveClient:
|
||||||
"""
|
"""
|
||||||
service = await self.get_service()
|
service = await self.get_service()
|
||||||
return await asyncio.to_thread(
|
return await asyncio.to_thread(
|
||||||
self._sync_download_file, service, file_id, self._resolved_credentials,
|
self._sync_download_file,
|
||||||
|
service,
|
||||||
|
file_id,
|
||||||
|
self._resolved_credentials,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sync_download_file_to_disk(
|
def _sync_download_file_to_disk(
|
||||||
service, file_id: str, dest_path: str, chunksize: int,
|
service,
|
||||||
|
file_id: str,
|
||||||
|
dest_path: str,
|
||||||
|
chunksize: int,
|
||||||
credentials: Credentials,
|
credentials: Credentials,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
|
"""Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
|
||||||
|
|
@ -223,10 +233,15 @@ class GoogleDriveClient:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error downloading file: {e!s}"
|
return f"Error downloading file: {e!s}"
|
||||||
finally:
|
finally:
|
||||||
logger.info(f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
|
logger.info(
|
||||||
|
f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
async def download_file_to_disk(
|
async def download_file_to_disk(
|
||||||
self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024,
|
self,
|
||||||
|
file_id: str,
|
||||||
|
dest_path: str,
|
||||||
|
chunksize: int = 5 * 1024 * 1024,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Stream file directly to disk in chunks, avoiding full in-memory buffering.
|
"""Stream file directly to disk in chunks, avoiding full in-memory buffering.
|
||||||
|
|
||||||
|
|
@ -235,12 +250,19 @@ class GoogleDriveClient:
|
||||||
service = await self.get_service()
|
service = await self.get_service()
|
||||||
return await asyncio.to_thread(
|
return await asyncio.to_thread(
|
||||||
self._sync_download_file_to_disk,
|
self._sync_download_file_to_disk,
|
||||||
service, file_id, dest_path, chunksize, self._resolved_credentials,
|
service,
|
||||||
|
file_id,
|
||||||
|
dest_path,
|
||||||
|
chunksize,
|
||||||
|
self._resolved_credentials,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sync_export_google_file(
|
def _sync_export_google_file(
|
||||||
service, file_id: str, mime_type: str, credentials: Credentials,
|
service,
|
||||||
|
file_id: str,
|
||||||
|
mime_type: str,
|
||||||
|
credentials: Credentials,
|
||||||
) -> tuple[bytes | None, str | None]:
|
) -> tuple[bytes | None, str | None]:
|
||||||
"""Blocking export — runs on a worker thread via ``to_thread``."""
|
"""Blocking export — runs on a worker thread via ``to_thread``."""
|
||||||
thread = threading.current_thread().name
|
thread = threading.current_thread().name
|
||||||
|
|
@ -261,7 +283,9 @@ class GoogleDriveClient:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None, f"Error exporting file: {e!s}"
|
return None, f"Error exporting file: {e!s}"
|
||||||
finally:
|
finally:
|
||||||
logger.info(f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
|
logger.info(
|
||||||
|
f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
async def export_google_file(
|
async def export_google_file(
|
||||||
self, file_id: str, mime_type: str
|
self, file_id: str, mime_type: str
|
||||||
|
|
@ -278,7 +302,10 @@ class GoogleDriveClient:
|
||||||
"""
|
"""
|
||||||
service = await self.get_service()
|
service = await self.get_service()
|
||||||
return await asyncio.to_thread(
|
return await asyncio.to_thread(
|
||||||
self._sync_export_google_file, service, file_id, mime_type,
|
self._sync_export_google_file,
|
||||||
|
service,
|
||||||
|
file_id,
|
||||||
|
mime_type,
|
||||||
self._resolved_credentials,
|
self._resolved_credentials,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""Content extraction for Google Drive files."""
|
"""Content extraction for Google Drive files."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
@ -72,7 +73,11 @@ async def download_and_extract_content(
|
||||||
if is_google_workspace_file(mime_type):
|
if is_google_workspace_file(mime_type):
|
||||||
export_mime = get_export_mime_type(mime_type)
|
export_mime = get_export_mime_type(mime_type)
|
||||||
if not export_mime:
|
if not export_mime:
|
||||||
return None, drive_metadata, f"Cannot export Google Workspace type: {mime_type}"
|
return (
|
||||||
|
None,
|
||||||
|
drive_metadata,
|
||||||
|
f"Cannot export Google Workspace type: {mime_type}",
|
||||||
|
)
|
||||||
content_bytes, error = await client.export_google_file(file_id, export_mime)
|
content_bytes, error = await client.export_google_file(file_id, export_mime)
|
||||||
if error:
|
if error:
|
||||||
return None, drive_metadata, error
|
return None, drive_metadata, error
|
||||||
|
|
@ -83,9 +88,7 @@ async def download_and_extract_content(
|
||||||
temp_file_path = tmp.name
|
temp_file_path = tmp.name
|
||||||
else:
|
else:
|
||||||
extension = (
|
extension = (
|
||||||
Path(file_name).suffix
|
Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
|
||||||
or get_extension_from_mime(mime_type)
|
|
||||||
or ".bin"
|
|
||||||
)
|
)
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
|
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
|
||||||
temp_file_path = tmp.name
|
temp_file_path = tmp.name
|
||||||
|
|
@ -102,10 +105,8 @@ async def download_and_extract_content(
|
||||||
return None, drive_metadata, str(e)
|
return None, drive_metadata, str(e)
|
||||||
finally:
|
finally:
|
||||||
if temp_file_path and os.path.exists(temp_file_path):
|
if temp_file_path and os.path.exists(temp_file_path):
|
||||||
try:
|
with contextlib.suppress(Exception):
|
||||||
os.unlink(temp_file_path)
|
os.unlink(temp_file_path)
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||||
|
|
@ -117,9 +118,10 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
|
if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
|
||||||
from app.config import config as app_config
|
|
||||||
from litellm import atranscription
|
from litellm import atranscription
|
||||||
|
|
||||||
|
from app.config import config as app_config
|
||||||
|
|
||||||
stt_service_type = (
|
stt_service_type = (
|
||||||
"local"
|
"local"
|
||||||
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
|
||||||
|
|
@ -127,10 +129,15 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||||
)
|
)
|
||||||
if stt_service_type == "local":
|
if stt_service_type == "local":
|
||||||
from app.services.stt_service import stt_service
|
from app.services.stt_service import stt_service
|
||||||
|
|
||||||
t0 = time.monotonic()
|
t0 = time.monotonic()
|
||||||
logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}")
|
logger.info(
|
||||||
|
f"[local-stt] START file={filename} thread={threading.current_thread().name}"
|
||||||
|
)
|
||||||
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
||||||
logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
|
logger.info(
|
||||||
|
f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||||
|
)
|
||||||
text = result.get("text", "")
|
text = result.get("text", "")
|
||||||
else:
|
else:
|
||||||
with open(file_path, "rb") as audio_file:
|
with open(file_path, "rb") as audio_file:
|
||||||
|
|
@ -153,6 +160,7 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||||
|
|
||||||
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
if app_config.ETL_SERVICE == "UNSTRUCTURED":
|
||||||
from langchain_unstructured import UnstructuredLoader
|
from langchain_unstructured import UnstructuredLoader
|
||||||
|
|
||||||
from app.utils.document_converters import convert_document_to_markdown
|
from app.utils.document_converters import convert_document_to_markdown
|
||||||
|
|
||||||
loader = UnstructuredLoader(
|
loader = UnstructuredLoader(
|
||||||
|
|
@ -172,7 +180,9 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||||
parse_with_llamacloud_retry,
|
parse_with_llamacloud_retry,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50)
|
result = await parse_with_llamacloud_retry(
|
||||||
|
file_path=file_path, estimated_pages=50
|
||||||
|
)
|
||||||
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
markdown_documents = await result.aget_markdown_documents(split_by_page=False)
|
||||||
if not markdown_documents:
|
if not markdown_documents:
|
||||||
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
|
raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
|
||||||
|
|
@ -183,9 +193,13 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
||||||
|
|
||||||
converter = DocumentConverter()
|
converter = DocumentConverter()
|
||||||
t0 = time.monotonic()
|
t0 = time.monotonic()
|
||||||
logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}")
|
logger.info(
|
||||||
|
f"[docling] START file={filename} thread={threading.current_thread().name}"
|
||||||
|
)
|
||||||
result = await asyncio.to_thread(converter.convert, file_path)
|
result = await asyncio.to_thread(converter.convert, file_path)
|
||||||
logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
|
logger.info(
|
||||||
|
f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s"
|
||||||
|
)
|
||||||
return result.document.export_to_markdown()
|
return result.document.export_to_markdown()
|
||||||
|
|
||||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||||
|
|
@ -249,9 +263,7 @@ async def download_and_process_file(
|
||||||
return None, error
|
return None, error
|
||||||
|
|
||||||
extension = (
|
extension = (
|
||||||
Path(file_name).suffix
|
Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
|
||||||
or get_extension_from_mime(mime_type)
|
|
||||||
or ".bin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
|
with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
|
||||||
|
|
@ -290,9 +302,7 @@ async def download_and_process_file(
|
||||||
connector_info["metadata"]["md5_checksum"] = file["md5Checksum"]
|
connector_info["metadata"]["md5_checksum"] = file["md5Checksum"]
|
||||||
|
|
||||||
if is_google_workspace_file(mime_type):
|
if is_google_workspace_file(mime_type):
|
||||||
export_ext = get_extension_from_mime(
|
export_ext = get_extension_from_mime(get_export_mime_type(mime_type) or "")
|
||||||
get_export_mime_type(mime_type) or ""
|
|
||||||
)
|
|
||||||
connector_info["metadata"]["exported_as"] = (
|
connector_info["metadata"]["exported_as"] = (
|
||||||
export_ext.lstrip(".") if export_ext else "pdf"
|
export_ext.lstrip(".") if export_ext else "pdf"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,9 @@ def compute_identifier_hash(
|
||||||
|
|
||||||
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
|
def compute_unique_identifier_hash(doc: ConnectorDocument) -> str:
|
||||||
"""Return a stable SHA-256 hash identifying a document by its source identity."""
|
"""Return a stable SHA-256 hash identifying a document by its source identity."""
|
||||||
return compute_identifier_hash(doc.document_type.value, doc.unique_id, doc.search_space_id)
|
return compute_identifier_hash(
|
||||||
|
doc.document_type.value, doc.unique_id, doc.search_space_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_content_hash(doc: ConnectorDocument) -> str:
|
def compute_content_hash(doc: ConnectorDocument) -> str:
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,23 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import delete, select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Chunk, Document, DocumentStatus
|
from app.db import (
|
||||||
|
NATIVE_TO_LEGACY_DOCTYPE,
|
||||||
|
Chunk,
|
||||||
|
Document,
|
||||||
|
DocumentStatus,
|
||||||
|
DocumentType,
|
||||||
|
)
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_chunker import chunk_text
|
from app.indexing_pipeline.document_chunker import chunk_text
|
||||||
from app.indexing_pipeline.document_embedder import embed_texts
|
from app.indexing_pipeline.document_embedder import embed_texts
|
||||||
|
|
@ -52,12 +60,114 @@ from app.indexing_pipeline.pipeline_logger import (
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PlaceholderInfo:
|
||||||
|
"""Minimal info to create a placeholder document row for instant UI feedback.
|
||||||
|
|
||||||
|
These are created immediately when items are discovered (before content
|
||||||
|
extraction) so users see them in the UI via Zero sync right away.
|
||||||
|
"""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
document_type: DocumentType
|
||||||
|
unique_id: str
|
||||||
|
search_space_id: int
|
||||||
|
connector_id: int | None
|
||||||
|
created_by_id: str
|
||||||
|
metadata: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class IndexingPipelineService:
|
class IndexingPipelineService:
|
||||||
"""Single pipeline for indexing connector documents. All connectors use this service."""
|
"""Single pipeline for indexing connector documents. All connectors use this service."""
|
||||||
|
|
||||||
def __init__(self, session: AsyncSession) -> None:
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
self.session = session
|
self.session = session
|
||||||
|
|
||||||
|
async def create_placeholder_documents(
|
||||||
|
self, placeholders: list[PlaceholderInfo]
|
||||||
|
) -> int:
|
||||||
|
"""Create placeholder document rows with pending status for instant UI feedback.
|
||||||
|
|
||||||
|
These rows appear immediately in the UI via Zero sync. They are later
|
||||||
|
updated by prepare_for_indexing() when actual content is available.
|
||||||
|
|
||||||
|
Returns the number of placeholders successfully created.
|
||||||
|
Failures are logged but never block the main indexing flow.
|
||||||
|
|
||||||
|
NOTE: This method commits on ``self.session`` so the rows become
|
||||||
|
visible to Zero sync immediately. Any pending ORM mutations on the
|
||||||
|
session are committed together, which is consistent with how other
|
||||||
|
mid-flow commits work in the indexing codebase (e.g. rename-only
|
||||||
|
updates in ``_should_skip_file``, ``migrate_legacy_docs``).
|
||||||
|
"""
|
||||||
|
if not placeholders:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
uid_hashes: dict[str, PlaceholderInfo] = {}
|
||||||
|
for p in placeholders:
|
||||||
|
try:
|
||||||
|
uid_hash = compute_identifier_hash(
|
||||||
|
p.document_type.value, p.unique_id, p.search_space_id
|
||||||
|
)
|
||||||
|
uid_hashes.setdefault(uid_hash, p)
|
||||||
|
except Exception:
|
||||||
|
_logger.debug(
|
||||||
|
"Skipping placeholder hash for %s", p.unique_id, exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not uid_hashes:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
result = await self.session.execute(
|
||||||
|
select(Document.unique_identifier_hash).where(
|
||||||
|
Document.unique_identifier_hash.in_(list(uid_hashes.keys()))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing_hashes: set[str] = set(result.scalars().all())
|
||||||
|
|
||||||
|
created = 0
|
||||||
|
for uid_hash, p in uid_hashes.items():
|
||||||
|
if uid_hash in existing_hashes:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
content_hash = hashlib.sha256(
|
||||||
|
f"placeholder:{uid_hash}".encode()
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
document = Document(
|
||||||
|
title=p.title,
|
||||||
|
document_type=p.document_type,
|
||||||
|
content="Pending...",
|
||||||
|
content_hash=content_hash,
|
||||||
|
unique_identifier_hash=uid_hash,
|
||||||
|
document_metadata=p.metadata or {},
|
||||||
|
search_space_id=p.search_space_id,
|
||||||
|
connector_id=p.connector_id,
|
||||||
|
created_by_id=p.created_by_id,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
status=DocumentStatus.pending(),
|
||||||
|
)
|
||||||
|
self.session.add(document)
|
||||||
|
created += 1
|
||||||
|
except Exception:
|
||||||
|
_logger.debug("Skipping placeholder for %s", p.unique_id, exc_info=True)
|
||||||
|
|
||||||
|
if created > 0:
|
||||||
|
try:
|
||||||
|
await self.session.commit()
|
||||||
|
_logger.info(
|
||||||
|
"Created %d placeholder document(s) for instant UI feedback",
|
||||||
|
created,
|
||||||
|
)
|
||||||
|
except IntegrityError:
|
||||||
|
await self.session.rollback()
|
||||||
|
_logger.debug("Placeholder commit failed (race condition), continuing")
|
||||||
|
created = 0
|
||||||
|
|
||||||
|
return created
|
||||||
|
|
||||||
async def migrate_legacy_docs(
|
async def migrate_legacy_docs(
|
||||||
self, connector_docs: list[ConnectorDocument]
|
self, connector_docs: list[ConnectorDocument]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -77,9 +187,7 @@ class IndexingPipelineService:
|
||||||
legacy_type, doc.unique_id, doc.search_space_id
|
legacy_type, doc.unique_id, doc.search_space_id
|
||||||
)
|
)
|
||||||
result = await self.session.execute(
|
result = await self.session.execute(
|
||||||
select(Document).filter(
|
select(Document).filter(Document.unique_identifier_hash == legacy_hash)
|
||||||
Document.unique_identifier_hash == legacy_hash
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
existing = result.scalars().first()
|
existing = result.scalars().first()
|
||||||
if existing is None:
|
if existing is None:
|
||||||
|
|
@ -101,9 +209,7 @@ class IndexingPipelineService:
|
||||||
Indexers that need heartbeat callbacks or custom per-document logic
|
Indexers that need heartbeat callbacks or custom per-document logic
|
||||||
should call prepare_for_indexing() + index() directly instead.
|
should call prepare_for_indexing() + index() directly instead.
|
||||||
"""
|
"""
|
||||||
doc_map = {
|
doc_map = {compute_unique_identifier_hash(cd): cd for cd in connector_docs}
|
||||||
compute_unique_identifier_hash(cd): cd for cd in connector_docs
|
|
||||||
}
|
|
||||||
documents = await self.prepare_for_indexing(connector_docs)
|
documents = await self.prepare_for_indexing(connector_docs)
|
||||||
results: list[Document] = []
|
results: list[Document] = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
|
|
@ -166,6 +272,21 @@ class IndexingPipelineService:
|
||||||
log_document_requeued(ctx)
|
log_document_requeued(ctx)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
dup_check = await self.session.execute(
|
||||||
|
select(Document.id).filter(
|
||||||
|
Document.content_hash == content_hash,
|
||||||
|
Document.id != existing.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if dup_check.scalars().first() is not None:
|
||||||
|
if not DocumentStatus.is_state(
|
||||||
|
existing.status, DocumentStatus.READY
|
||||||
|
):
|
||||||
|
existing.status = DocumentStatus.failed(
|
||||||
|
"Duplicate content — already indexed by another document"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
existing.title = connector_doc.title
|
existing.title = connector_doc.title
|
||||||
existing.content_hash = content_hash
|
existing.content_hash = content_hash
|
||||||
existing.source_markdown = connector_doc.source_markdown
|
existing.source_markdown = connector_doc.source_markdown
|
||||||
|
|
@ -349,9 +470,7 @@ class IndexingPipelineService:
|
||||||
perf = get_perf_logger()
|
perf = get_perf_logger()
|
||||||
t_total = time.perf_counter()
|
t_total = time.perf_counter()
|
||||||
|
|
||||||
doc_map = {
|
doc_map = {compute_unique_identifier_hash(cd): cd for cd in connector_docs}
|
||||||
compute_unique_identifier_hash(cd): cd for cd in connector_docs
|
|
||||||
}
|
|
||||||
documents = await self.prepare_for_indexing(connector_docs)
|
documents = await self.prepare_for_indexing(connector_docs)
|
||||||
|
|
||||||
if not documents:
|
if not documents:
|
||||||
|
|
@ -383,9 +502,7 @@ class IndexingPipelineService:
|
||||||
session_maker = get_celery_session_maker()
|
session_maker = get_celery_session_maker()
|
||||||
async with session_maker() as isolated_session:
|
async with session_maker() as isolated_session:
|
||||||
try:
|
try:
|
||||||
refetched = await isolated_session.get(
|
refetched = await isolated_session.get(Document, document.id)
|
||||||
Document, document.id
|
|
||||||
)
|
|
||||||
if refetched is None:
|
if refetched is None:
|
||||||
async with lock:
|
async with lock:
|
||||||
failed_count += 1
|
failed_count += 1
|
||||||
|
|
@ -393,9 +510,7 @@ class IndexingPipelineService:
|
||||||
|
|
||||||
llm = await get_llm(isolated_session)
|
llm = await get_llm(isolated_session)
|
||||||
iso_pipeline = IndexingPipelineService(isolated_session)
|
iso_pipeline = IndexingPipelineService(isolated_session)
|
||||||
result = await iso_pipeline.index(
|
result = await iso_pipeline.index(refetched, connector_doc, llm)
|
||||||
refetched, connector_doc, llm
|
|
||||||
)
|
|
||||||
|
|
||||||
async with lock:
|
async with lock:
|
||||||
if DocumentStatus.is_state(
|
if DocumentStatus.is_state(
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from datetime import datetime
|
||||||
|
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
_MAX_FETCH_CHUNKS_PER_DOC = 20
|
||||||
|
|
||||||
|
|
||||||
class ChucksHybridSearchRetriever:
|
class ChucksHybridSearchRetriever:
|
||||||
|
|
@ -185,7 +185,7 @@ class ChucksHybridSearchRetriever:
|
||||||
- chunks: list[{chunk_id, content}] for citation-aware prompting
|
- chunks: list[{chunk_id, content}] for citation-aware prompting
|
||||||
- document: {id, title, document_type, metadata}
|
- document: {id, title, document_type, metadata}
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import func, select, text
|
from sqlalchemy import func, or_, select, text
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
|
|
@ -360,64 +360,81 @@ class ChucksHybridSearchRetriever:
|
||||||
if not doc_ids:
|
if not doc_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Fetch chunks for selected documents. We cap per document to avoid
|
# Collect document metadata from the small RRF result set (already
|
||||||
# loading hundreds of chunks for a single large file while still
|
# loaded via joinedload) so the bulk chunk fetch can skip the expensive
|
||||||
# ensuring the chunks that matched the RRF query are always included.
|
# Document JOIN entirely.
|
||||||
chunk_query = (
|
|
||||||
select(Chunk)
|
|
||||||
.options(joinedload(Chunk.document))
|
|
||||||
.join(Document, Chunk.document_id == Document.id)
|
|
||||||
.where(Document.id.in_(doc_ids))
|
|
||||||
.where(*base_conditions)
|
|
||||||
.order_by(Chunk.document_id, Chunk.id)
|
|
||||||
)
|
|
||||||
chunks_result = await self.db_session.execute(chunk_query)
|
|
||||||
raw_chunks = chunks_result.scalars().all()
|
|
||||||
|
|
||||||
matched_chunk_ids: set[int] = {
|
matched_chunk_ids: set[int] = {
|
||||||
item["chunk_id"] for item in serialized_chunk_results
|
item["chunk_id"] for item in serialized_chunk_results
|
||||||
}
|
}
|
||||||
|
doc_meta_cache: dict[int, dict] = {}
|
||||||
|
for item in serialized_chunk_results:
|
||||||
|
did = item["document"]["id"]
|
||||||
|
if did not in doc_meta_cache:
|
||||||
|
doc_meta_cache[did] = item["document"]
|
||||||
|
|
||||||
doc_chunk_counts: dict[int, int] = {}
|
# SQL-level per-document chunk limit using ROW_NUMBER().
|
||||||
all_chunks: list = []
|
# Avoids loading hundreds of chunks per large document only to
|
||||||
for chunk in raw_chunks:
|
# discard them in Python.
|
||||||
did = chunk.document_id
|
numbered = (
|
||||||
count = doc_chunk_counts.get(did, 0)
|
select(
|
||||||
if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
|
Chunk.id.label("chunk_id"),
|
||||||
all_chunks.append(chunk)
|
func.row_number()
|
||||||
doc_chunk_counts[did] = count + 1
|
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
|
||||||
|
.label("rn"),
|
||||||
|
)
|
||||||
|
.where(Chunk.document_id.in_(doc_ids))
|
||||||
|
.subquery("numbered")
|
||||||
|
)
|
||||||
|
|
||||||
# Assemble final doc-grouped results in the same order as doc_ids
|
matched_list = list(matched_chunk_ids)
|
||||||
|
if matched_list:
|
||||||
|
chunk_filter = or_(
|
||||||
|
numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC,
|
||||||
|
Chunk.id.in_(matched_list),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk_filter = numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC
|
||||||
|
|
||||||
|
# Select only the columns we need (skip Chunk.embedding ~12KB/row).
|
||||||
|
chunk_query = (
|
||||||
|
select(Chunk.id, Chunk.content, Chunk.document_id)
|
||||||
|
.join(numbered, Chunk.id == numbered.c.chunk_id)
|
||||||
|
.where(chunk_filter)
|
||||||
|
.order_by(Chunk.document_id, Chunk.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
t_fetch = time.perf_counter()
|
||||||
|
chunks_result = await self.db_session.execute(chunk_query)
|
||||||
|
fetched_chunks = chunks_result.all()
|
||||||
|
perf.debug(
|
||||||
|
"[chunk_search] chunk fetch in %.3fs rows=%d",
|
||||||
|
time.perf_counter() - t_fetch,
|
||||||
|
len(fetched_chunks),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assemble final doc-grouped results in the same order as doc_ids,
|
||||||
|
# using pre-cached doc metadata instead of joinedload.
|
||||||
doc_map: dict[int, dict] = {
|
doc_map: dict[int, dict] = {
|
||||||
doc_id: {
|
doc_id: {
|
||||||
"document_id": doc_id,
|
"document_id": doc_id,
|
||||||
"content": "",
|
"content": "",
|
||||||
"score": float(doc_scores.get(doc_id, 0.0)),
|
"score": float(doc_scores.get(doc_id, 0.0)),
|
||||||
"chunks": [],
|
"chunks": [],
|
||||||
"document": {},
|
"matched_chunk_ids": [],
|
||||||
"source": None,
|
"document": doc_meta_cache.get(doc_id, {}),
|
||||||
|
"source": (doc_meta_cache.get(doc_id) or {}).get("document_type"),
|
||||||
}
|
}
|
||||||
for doc_id in doc_ids
|
for doc_id in doc_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
for chunk in all_chunks:
|
for row in fetched_chunks:
|
||||||
doc = chunk.document
|
doc_id = row.document_id
|
||||||
doc_id = doc.id
|
|
||||||
if doc_id not in doc_map:
|
if doc_id not in doc_map:
|
||||||
continue
|
continue
|
||||||
doc_entry = doc_map[doc_id]
|
doc_entry = doc_map[doc_id]
|
||||||
doc_entry["document"] = {
|
doc_entry["chunks"].append({"chunk_id": row.id, "content": row.content})
|
||||||
"id": doc.id,
|
if row.id in matched_chunk_ids:
|
||||||
"title": doc.title,
|
doc_entry["matched_chunk_ids"].append(row.id)
|
||||||
"document_type": doc.document_type.value
|
|
||||||
if getattr(doc, "document_type", None)
|
|
||||||
else None,
|
|
||||||
"metadata": doc.document_metadata or {},
|
|
||||||
}
|
|
||||||
doc_entry["source"] = (
|
|
||||||
doc.document_type.value if getattr(doc, "document_type", None) else None
|
|
||||||
)
|
|
||||||
doc_entry["chunks"].append({"chunk_id": chunk.id, "content": chunk.content})
|
|
||||||
|
|
||||||
# Fill concatenated content (useful for reranking)
|
# Fill concatenated content (useful for reranking)
|
||||||
final_docs: list[dict] = []
|
final_docs: list[dict] = []
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
||||||
|
|
||||||
from app.utils.perf import get_perf_logger
|
from app.utils.perf import get_perf_logger
|
||||||
|
|
||||||
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
_MAX_FETCH_CHUNKS_PER_DOC = 20
|
||||||
|
|
||||||
|
|
||||||
class DocumentHybridSearchRetriever:
|
class DocumentHybridSearchRetriever:
|
||||||
|
|
@ -289,57 +289,77 @@ class DocumentHybridSearchRetriever:
|
||||||
if not documents_with_scores:
|
if not documents_with_scores:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Collect document IDs for chunk fetching
|
# Collect document IDs and pre-cache metadata from the small RRF
|
||||||
|
# result set so the bulk chunk fetch can skip joinedload entirely.
|
||||||
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
|
doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]
|
||||||
|
doc_meta_cache: dict[int, dict] = {}
|
||||||
# Fetch chunks for these documents, capped per document to avoid
|
doc_score_cache: dict[int, float] = {}
|
||||||
# loading hundreds of chunks for a single large file.
|
doc_source_cache: dict[int, str | None] = {}
|
||||||
chunks_query = (
|
for doc, score in documents_with_scores:
|
||||||
select(Chunk)
|
doc_meta_cache[doc.id] = {
|
||||||
.options(joinedload(Chunk.document))
|
"id": doc.id,
|
||||||
.where(Chunk.document_id.in_(doc_ids))
|
"title": doc.title,
|
||||||
.order_by(Chunk.document_id, Chunk.id)
|
"document_type": doc.document_type.value
|
||||||
)
|
|
||||||
chunks_result = await self.db_session.execute(chunks_query)
|
|
||||||
raw_chunks = chunks_result.scalars().all()
|
|
||||||
|
|
||||||
doc_chunk_counts: dict[int, int] = {}
|
|
||||||
chunks: list = []
|
|
||||||
for chunk in raw_chunks:
|
|
||||||
did = chunk.document_id
|
|
||||||
count = doc_chunk_counts.get(did, 0)
|
|
||||||
if count < _MAX_FETCH_CHUNKS_PER_DOC:
|
|
||||||
chunks.append(chunk)
|
|
||||||
doc_chunk_counts[did] = count + 1
|
|
||||||
|
|
||||||
# Assemble doc-grouped results
|
|
||||||
doc_map: dict[int, dict] = {
|
|
||||||
doc.id: {
|
|
||||||
"document_id": doc.id,
|
|
||||||
"content": "",
|
|
||||||
"score": float(score),
|
|
||||||
"chunks": [],
|
|
||||||
"document": {
|
|
||||||
"id": doc.id,
|
|
||||||
"title": doc.title,
|
|
||||||
"document_type": doc.document_type.value
|
|
||||||
if getattr(doc, "document_type", None)
|
|
||||||
else None,
|
|
||||||
"metadata": doc.document_metadata or {},
|
|
||||||
},
|
|
||||||
"source": doc.document_type.value
|
|
||||||
if getattr(doc, "document_type", None)
|
if getattr(doc, "document_type", None)
|
||||||
else None,
|
else None,
|
||||||
|
"metadata": doc.document_metadata or {},
|
||||||
}
|
}
|
||||||
for doc, score in documents_with_scores
|
doc_score_cache[doc.id] = float(score)
|
||||||
|
doc_source_cache[doc.id] = (
|
||||||
|
doc.document_type.value if getattr(doc, "document_type", None) else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# SQL-level per-document chunk limit using ROW_NUMBER().
|
||||||
|
# Avoids loading hundreds of chunks per large document only to
|
||||||
|
# discard them in Python.
|
||||||
|
numbered = (
|
||||||
|
select(
|
||||||
|
Chunk.id.label("chunk_id"),
|
||||||
|
func.row_number()
|
||||||
|
.over(partition_by=Chunk.document_id, order_by=Chunk.id)
|
||||||
|
.label("rn"),
|
||||||
|
)
|
||||||
|
.where(Chunk.document_id.in_(doc_ids))
|
||||||
|
.subquery("numbered")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select only the columns we need (skip Chunk.embedding ~12KB/row).
|
||||||
|
chunks_query = (
|
||||||
|
select(Chunk.id, Chunk.content, Chunk.document_id)
|
||||||
|
.join(numbered, Chunk.id == numbered.c.chunk_id)
|
||||||
|
.where(numbered.c.rn <= _MAX_FETCH_CHUNKS_PER_DOC)
|
||||||
|
.order_by(Chunk.document_id, Chunk.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
t_fetch = time.perf_counter()
|
||||||
|
chunks_result = await self.db_session.execute(chunks_query)
|
||||||
|
fetched_chunks = chunks_result.all()
|
||||||
|
perf.debug(
|
||||||
|
"[doc_search] chunk fetch in %.3fs rows=%d",
|
||||||
|
time.perf_counter() - t_fetch,
|
||||||
|
len(fetched_chunks),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assemble doc-grouped results using pre-cached metadata.
|
||||||
|
doc_map: dict[int, dict] = {
|
||||||
|
doc_id: {
|
||||||
|
"document_id": doc_id,
|
||||||
|
"content": "",
|
||||||
|
"score": doc_score_cache.get(doc_id, 0.0),
|
||||||
|
"chunks": [],
|
||||||
|
"matched_chunk_ids": [],
|
||||||
|
"document": doc_meta_cache.get(doc_id, {}),
|
||||||
|
"source": doc_source_cache.get(doc_id),
|
||||||
|
}
|
||||||
|
for doc_id in doc_ids
|
||||||
}
|
}
|
||||||
|
|
||||||
for chunk in chunks:
|
for row in fetched_chunks:
|
||||||
doc_id = chunk.document_id
|
doc_id = row.document_id
|
||||||
if doc_id not in doc_map:
|
if doc_id not in doc_map:
|
||||||
continue
|
continue
|
||||||
doc_map[doc_id]["chunks"].append(
|
doc_map[doc_id]["chunks"].append(
|
||||||
{"chunk_id": chunk.id, "content": chunk.content}
|
{"chunk_id": row.id, "content": row.content}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fill concatenated content (useful for reranking)
|
# Fill concatenated content (useful for reranking)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ import asyncio
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -22,9 +21,9 @@ from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.db import Document, DocumentType, Permission, User, get_async_session
|
from app.db import Document, DocumentType, Permission, User, get_async_session
|
||||||
from app.routes.reports_routes import (
|
from app.routes.reports_routes import (
|
||||||
ExportFormat,
|
|
||||||
_FILE_EXTENSIONS,
|
_FILE_EXTENSIONS,
|
||||||
_MEDIA_TYPES,
|
_MEDIA_TYPES,
|
||||||
|
ExportFormat,
|
||||||
_normalize_latex_delimiters,
|
_normalize_latex_delimiters,
|
||||||
_strip_wrapping_code_fences,
|
_strip_wrapping_code_fences,
|
||||||
)
|
)
|
||||||
|
|
@ -238,9 +237,7 @@ async def save_document(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get("/search-spaces/{search_space_id}/documents/{document_id}/export")
|
||||||
"/search-spaces/{search_space_id}/documents/{document_id}/export"
|
|
||||||
)
|
|
||||||
async def export_document(
|
async def export_document(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
document_id: int,
|
document_id: int,
|
||||||
|
|
@ -284,9 +281,7 @@ async def export_document(
|
||||||
markdown_content = "\n\n".join(chunk.content for chunk in chunks)
|
markdown_content = "\n\n".join(chunk.content for chunk in chunks)
|
||||||
|
|
||||||
if not markdown_content or not markdown_content.strip():
|
if not markdown_content or not markdown_content.strip():
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=400, detail="Document has no content to export")
|
||||||
status_code=400, detail="Document has no content to export"
|
|
||||||
)
|
|
||||||
|
|
||||||
markdown_content = _strip_wrapping_code_fences(markdown_content)
|
markdown_content = _strip_wrapping_code_fences(markdown_content)
|
||||||
markdown_content = _normalize_latex_delimiters(markdown_content)
|
markdown_content = _normalize_latex_delimiters(markdown_content)
|
||||||
|
|
@ -308,8 +303,10 @@ async def export_document(
|
||||||
extra_args=[
|
extra_args=[
|
||||||
"--standalone",
|
"--standalone",
|
||||||
f"--template={typst_template}",
|
f"--template={typst_template}",
|
||||||
"-V", "mainfont:Libertinus Serif",
|
"-V",
|
||||||
"-V", "codefont:DejaVu Sans Mono",
|
"mainfont:Libertinus Serif",
|
||||||
|
"-V",
|
||||||
|
"codefont:DejaVu Sans Mono",
|
||||||
*meta_args,
|
*meta_args,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -318,7 +315,11 @@ async def export_document(
|
||||||
if format == ExportFormat.DOCX:
|
if format == ExportFormat.DOCX:
|
||||||
return _pandoc_to_tempfile(
|
return _pandoc_to_tempfile(
|
||||||
format.value,
|
format.value,
|
||||||
["--standalone", f"--reference-doc={get_reference_docx_path()}", *meta_args],
|
[
|
||||||
|
"--standalone",
|
||||||
|
f"--reference-doc={get_reference_docx_path()}",
|
||||||
|
*meta_args,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if format == ExportFormat.HTML:
|
if format == ExportFormat.HTML:
|
||||||
|
|
@ -327,7 +328,8 @@ async def export_document(
|
||||||
"html5",
|
"html5",
|
||||||
format=input_fmt,
|
format=input_fmt,
|
||||||
extra_args=[
|
extra_args=[
|
||||||
"--standalone", "--embed-resources",
|
"--standalone",
|
||||||
|
"--embed-resources",
|
||||||
f"--css={get_html_css_path()}",
|
f"--css={get_html_css_path()}",
|
||||||
"--syntax-highlighting=pygments",
|
"--syntax-highlighting=pygments",
|
||||||
*meta_args,
|
*meta_args,
|
||||||
|
|
@ -343,13 +345,17 @@ async def export_document(
|
||||||
|
|
||||||
if format == ExportFormat.LATEX:
|
if format == ExportFormat.LATEX:
|
||||||
tex_str: str = pypandoc.convert_text(
|
tex_str: str = pypandoc.convert_text(
|
||||||
markdown_content, "latex", format=input_fmt,
|
markdown_content,
|
||||||
|
"latex",
|
||||||
|
format=input_fmt,
|
||||||
extra_args=["--standalone", *meta_args],
|
extra_args=["--standalone", *meta_args],
|
||||||
)
|
)
|
||||||
return tex_str.encode("utf-8")
|
return tex_str.encode("utf-8")
|
||||||
|
|
||||||
plain_str: str = pypandoc.convert_text(
|
plain_str: str = pypandoc.convert_text(
|
||||||
markdown_content, "plain", format=input_fmt,
|
markdown_content,
|
||||||
|
"plain",
|
||||||
|
format=input_fmt,
|
||||||
extra_args=["--wrap=auto", "--columns=80"],
|
extra_args=["--wrap=auto", "--columns=80"],
|
||||||
)
|
)
|
||||||
return plain_str.encode("utf-8")
|
return plain_str.encode("utf-8")
|
||||||
|
|
@ -359,8 +365,11 @@ async def export_document(
|
||||||
os.close(fd)
|
os.close(fd)
|
||||||
try:
|
try:
|
||||||
pypandoc.convert_text(
|
pypandoc.convert_text(
|
||||||
markdown_content, output_format, format=input_fmt,
|
markdown_content,
|
||||||
extra_args=extra_args, outputfile=tmp_path,
|
output_format,
|
||||||
|
format=input_fmt,
|
||||||
|
extra_args=extra_args,
|
||||||
|
outputfile=tmp_path,
|
||||||
)
|
)
|
||||||
with open(tmp_path, "rb") as f:
|
with open(tmp_path, "rb") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
@ -375,8 +384,7 @@ async def export_document(
|
||||||
raise HTTPException(status_code=500, detail=f"Export failed: {e!s}") from e
|
raise HTTPException(status_code=500, detail=f"Export failed: {e!s}") from e
|
||||||
|
|
||||||
safe_title = (
|
safe_title = (
|
||||||
"".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title)
|
"".join(c if c.isalnum() or c in " -_" else "_" for c in doc_title).strip()[:80]
|
||||||
.strip()[:80]
|
|
||||||
or "document"
|
or "document"
|
||||||
)
|
)
|
||||||
ext = _FILE_EXTENSIONS[format]
|
ext = _FILE_EXTENSIONS[format]
|
||||||
|
|
|
||||||
|
|
@ -2406,7 +2406,11 @@ async def run_google_drive_indexing(
|
||||||
if items.files:
|
if items.files:
|
||||||
try:
|
try:
|
||||||
file_tuples = [(f.id, f.name) for f in items.files]
|
file_tuples = [(f.id, f.name) for f in items.files]
|
||||||
indexed_count, _skipped, file_errors = await index_google_drive_selected_files(
|
(
|
||||||
|
indexed_count,
|
||||||
|
_skipped,
|
||||||
|
file_errors,
|
||||||
|
) = await index_google_drive_selected_files(
|
||||||
session,
|
session,
|
||||||
connector_id,
|
connector_id,
|
||||||
search_space_id,
|
search_space_id,
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ Supports loading LLM configurations from:
|
||||||
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
|
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import gc
|
import gc
|
||||||
|
|
@ -36,10 +37,6 @@ from app.agents.new_chat.llm_config import (
|
||||||
load_agent_config,
|
load_agent_config,
|
||||||
load_llm_config_from_yaml,
|
load_llm_config_from_yaml,
|
||||||
)
|
)
|
||||||
from app.agents.new_chat.sandbox import (
|
|
||||||
get_or_create_sandbox,
|
|
||||||
is_sandbox_enabled,
|
|
||||||
)
|
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatVisibility,
|
ChatVisibility,
|
||||||
Document,
|
Document,
|
||||||
|
|
@ -212,7 +209,7 @@ class StreamResult:
|
||||||
accumulated_text: str = ""
|
accumulated_text: str = ""
|
||||||
is_interrupted: bool = False
|
is_interrupted: bool = False
|
||||||
interrupt_value: dict[str, Any] | None = None
|
interrupt_value: dict[str, Any] | None = None
|
||||||
sandbox_files: list[str] = field(default_factory=list)
|
sandbox_files: list[str] = field(default_factory=list) # unused, kept for compat
|
||||||
|
|
||||||
|
|
||||||
async def _stream_agent_events(
|
async def _stream_agent_events(
|
||||||
|
|
@ -281,6 +278,8 @@ async def _stream_agent_events(
|
||||||
if event_type == "on_chat_model_stream":
|
if event_type == "on_chat_model_stream":
|
||||||
if active_tool_depth > 0:
|
if active_tool_depth > 0:
|
||||||
continue # Suppress inner-tool LLM tokens from leaking into chat
|
continue # Suppress inner-tool LLM tokens from leaking into chat
|
||||||
|
if "surfsense:internal" in event.get("tags", []):
|
||||||
|
continue # Suppress middleware-internal LLM tokens (e.g. KB search classification)
|
||||||
chunk = event.get("data", {}).get("chunk")
|
chunk = event.get("data", {}).get("chunk")
|
||||||
if chunk and hasattr(chunk, "content"):
|
if chunk and hasattr(chunk, "content"):
|
||||||
content = chunk.content
|
content = chunk.content
|
||||||
|
|
@ -319,19 +318,114 @@ async def _stream_agent_events(
|
||||||
tool_step_ids[run_id] = tool_step_id
|
tool_step_ids[run_id] = tool_step_id
|
||||||
last_active_step_id = tool_step_id
|
last_active_step_id = tool_step_id
|
||||||
|
|
||||||
if tool_name == "search_knowledge_base":
|
if tool_name == "ls":
|
||||||
query = (
|
ls_path = (
|
||||||
tool_input.get("query", "")
|
tool_input.get("path", "/")
|
||||||
if isinstance(tool_input, dict)
|
if isinstance(tool_input, dict)
|
||||||
else str(tool_input)
|
else str(tool_input)
|
||||||
)
|
)
|
||||||
last_active_step_title = "Searching knowledge base"
|
last_active_step_title = "Listing files"
|
||||||
|
last_active_step_items = [ls_path]
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=tool_step_id,
|
||||||
|
title="Listing files",
|
||||||
|
status="in_progress",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "read_file":
|
||||||
|
fp = (
|
||||||
|
tool_input.get("file_path", "")
|
||||||
|
if isinstance(tool_input, dict)
|
||||||
|
else str(tool_input)
|
||||||
|
)
|
||||||
|
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||||
|
last_active_step_title = "Reading file"
|
||||||
|
last_active_step_items = [display_fp]
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=tool_step_id,
|
||||||
|
title="Reading file",
|
||||||
|
status="in_progress",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "write_file":
|
||||||
|
fp = (
|
||||||
|
tool_input.get("file_path", "")
|
||||||
|
if isinstance(tool_input, dict)
|
||||||
|
else str(tool_input)
|
||||||
|
)
|
||||||
|
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||||
|
last_active_step_title = "Writing file"
|
||||||
|
last_active_step_items = [display_fp]
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=tool_step_id,
|
||||||
|
title="Writing file",
|
||||||
|
status="in_progress",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "edit_file":
|
||||||
|
fp = (
|
||||||
|
tool_input.get("file_path", "")
|
||||||
|
if isinstance(tool_input, dict)
|
||||||
|
else str(tool_input)
|
||||||
|
)
|
||||||
|
display_fp = fp if len(fp) <= 80 else "…" + fp[-77:]
|
||||||
|
last_active_step_title = "Editing file"
|
||||||
|
last_active_step_items = [display_fp]
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=tool_step_id,
|
||||||
|
title="Editing file",
|
||||||
|
status="in_progress",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "glob":
|
||||||
|
pat = (
|
||||||
|
tool_input.get("pattern", "")
|
||||||
|
if isinstance(tool_input, dict)
|
||||||
|
else str(tool_input)
|
||||||
|
)
|
||||||
|
base_path = (
|
||||||
|
tool_input.get("path", "/") if isinstance(tool_input, dict) else "/"
|
||||||
|
)
|
||||||
|
last_active_step_title = "Searching files"
|
||||||
|
last_active_step_items = [f"{pat} in {base_path}"]
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=tool_step_id,
|
||||||
|
title="Searching files",
|
||||||
|
status="in_progress",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "grep":
|
||||||
|
pat = (
|
||||||
|
tool_input.get("pattern", "")
|
||||||
|
if isinstance(tool_input, dict)
|
||||||
|
else str(tool_input)
|
||||||
|
)
|
||||||
|
grep_path = (
|
||||||
|
tool_input.get("path", "") if isinstance(tool_input, dict) else ""
|
||||||
|
)
|
||||||
|
display_pat = pat[:60] + ("…" if len(pat) > 60 else "")
|
||||||
|
last_active_step_title = "Searching content"
|
||||||
last_active_step_items = [
|
last_active_step_items = [
|
||||||
f"Query: {query[:100]}{'...' if len(query) > 100 else ''}"
|
f'"{display_pat}"' + (f" in {grep_path}" if grep_path else "")
|
||||||
]
|
]
|
||||||
yield streaming_service.format_thinking_step(
|
yield streaming_service.format_thinking_step(
|
||||||
step_id=tool_step_id,
|
step_id=tool_step_id,
|
||||||
title="Searching knowledge base",
|
title="Searching content",
|
||||||
|
status="in_progress",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "save_document":
|
||||||
|
doc_title = (
|
||||||
|
tool_input.get("title", "")
|
||||||
|
if isinstance(tool_input, dict)
|
||||||
|
else str(tool_input)
|
||||||
|
)
|
||||||
|
display_title = doc_title[:60] + ("…" if len(doc_title) > 60 else "")
|
||||||
|
last_active_step_title = "Saving document"
|
||||||
|
last_active_step_items = [display_title]
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=tool_step_id,
|
||||||
|
title="Saving document",
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
|
|
@ -441,10 +535,22 @@ async def _stream_agent_events(
|
||||||
else streaming_service.generate_tool_call_id()
|
else streaming_service.generate_tool_call_id()
|
||||||
)
|
)
|
||||||
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
|
yield streaming_service.format_tool_input_start(tool_call_id, tool_name)
|
||||||
|
# Sanitize tool_input: strip runtime-injected non-serializable
|
||||||
|
# values (e.g. LangChain ToolRuntime) before sending over SSE.
|
||||||
|
if isinstance(tool_input, dict):
|
||||||
|
_safe_input: dict[str, Any] = {}
|
||||||
|
for _k, _v in tool_input.items():
|
||||||
|
try:
|
||||||
|
json.dumps(_v)
|
||||||
|
_safe_input[_k] = _v
|
||||||
|
except (TypeError, ValueError, OverflowError):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_safe_input = {"input": tool_input}
|
||||||
yield streaming_service.format_tool_input_available(
|
yield streaming_service.format_tool_input_available(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
tool_name,
|
tool_name,
|
||||||
tool_input if isinstance(tool_input, dict) else {"input": tool_input},
|
_safe_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif event_type == "on_tool_end":
|
elif event_type == "on_tool_end":
|
||||||
|
|
@ -475,16 +581,55 @@ async def _stream_agent_events(
|
||||||
)
|
)
|
||||||
completed_step_ids.add(original_step_id)
|
completed_step_ids.add(original_step_id)
|
||||||
|
|
||||||
if tool_name == "search_knowledge_base":
|
if tool_name == "read_file":
|
||||||
result_info = "Search completed"
|
|
||||||
if isinstance(tool_output, dict):
|
|
||||||
result_len = tool_output.get("result_length", 0)
|
|
||||||
if result_len > 0:
|
|
||||||
result_info = f"Found relevant information ({result_len} chars)"
|
|
||||||
completed_items = [*last_active_step_items, result_info]
|
|
||||||
yield streaming_service.format_thinking_step(
|
yield streaming_service.format_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Searching knowledge base",
|
title="Reading file",
|
||||||
|
status="completed",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "write_file":
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=original_step_id,
|
||||||
|
title="Writing file",
|
||||||
|
status="completed",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "edit_file":
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=original_step_id,
|
||||||
|
title="Editing file",
|
||||||
|
status="completed",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "glob":
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=original_step_id,
|
||||||
|
title="Searching files",
|
||||||
|
status="completed",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "grep":
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=original_step_id,
|
||||||
|
title="Searching content",
|
||||||
|
status="completed",
|
||||||
|
items=last_active_step_items,
|
||||||
|
)
|
||||||
|
elif tool_name == "save_document":
|
||||||
|
result_str = (
|
||||||
|
tool_output.get("result", "")
|
||||||
|
if isinstance(tool_output, dict)
|
||||||
|
else str(tool_output)
|
||||||
|
)
|
||||||
|
is_error = "Error" in result_str
|
||||||
|
completed_items = [
|
||||||
|
*last_active_step_items,
|
||||||
|
result_str[:80] if is_error else "Saved to knowledge base",
|
||||||
|
]
|
||||||
|
yield streaming_service.format_thinking_step(
|
||||||
|
step_id=original_step_id,
|
||||||
|
title="Saving document",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=completed_items,
|
items=completed_items,
|
||||||
)
|
)
|
||||||
|
|
@ -690,14 +835,23 @@ async def _stream_agent_events(
|
||||||
ls_output = str(tool_output) if tool_output else ""
|
ls_output = str(tool_output) if tool_output else ""
|
||||||
file_names: list[str] = []
|
file_names: list[str] = []
|
||||||
if ls_output:
|
if ls_output:
|
||||||
for line in ls_output.strip().split("\n"):
|
paths: list[str] = []
|
||||||
line = line.strip()
|
try:
|
||||||
if line:
|
parsed = ast.literal_eval(ls_output)
|
||||||
name = line.rstrip("/").split("/")[-1]
|
if isinstance(parsed, list):
|
||||||
if name and len(name) <= 40:
|
paths = [str(p) for p in parsed]
|
||||||
file_names.append(name)
|
except (ValueError, SyntaxError):
|
||||||
elif name:
|
paths = [
|
||||||
file_names.append(name[:37] + "...")
|
line.strip()
|
||||||
|
for line in ls_output.strip().split("\n")
|
||||||
|
if line.strip()
|
||||||
|
]
|
||||||
|
for p in paths:
|
||||||
|
name = p.rstrip("/").split("/")[-1]
|
||||||
|
if name and len(name) <= 40:
|
||||||
|
file_names.append(name)
|
||||||
|
elif name:
|
||||||
|
file_names.append(name[:37] + "...")
|
||||||
if file_names:
|
if file_names:
|
||||||
if len(file_names) <= 5:
|
if len(file_names) <= 5:
|
||||||
completed_items = [f"[{name}]" for name in file_names]
|
completed_items = [f"[{name}]" for name in file_names]
|
||||||
|
|
@ -708,7 +862,7 @@ async def _stream_agent_events(
|
||||||
completed_items = ["No files found"]
|
completed_items = ["No files found"]
|
||||||
yield streaming_service.format_thinking_step(
|
yield streaming_service.format_thinking_step(
|
||||||
step_id=original_step_id,
|
step_id=original_step_id,
|
||||||
title="Exploring files",
|
title="Listing files",
|
||||||
status="completed",
|
status="completed",
|
||||||
items=completed_items,
|
items=completed_items,
|
||||||
)
|
)
|
||||||
|
|
@ -832,14 +986,6 @@ async def _stream_agent_events(
|
||||||
f"Scrape failed: {error_msg}",
|
f"Scrape failed: {error_msg}",
|
||||||
"error",
|
"error",
|
||||||
)
|
)
|
||||||
elif tool_name == "search_knowledge_base":
|
|
||||||
yield streaming_service.format_tool_output_available(
|
|
||||||
tool_call_id,
|
|
||||||
{"status": "completed", "result_length": len(str(tool_output))},
|
|
||||||
)
|
|
||||||
yield streaming_service.format_terminal_info(
|
|
||||||
"Knowledge base search completed", "success"
|
|
||||||
)
|
|
||||||
elif tool_name == "generate_report":
|
elif tool_name == "generate_report":
|
||||||
# Stream the full report result so frontend can render the ReportCard
|
# Stream the full report result so frontend can render the ReportCard
|
||||||
yield streaming_service.format_tool_output_available(
|
yield streaming_service.format_tool_output_available(
|
||||||
|
|
@ -973,6 +1119,19 @@ async def _stream_agent_events(
|
||||||
items=last_active_step_items,
|
items=last_active_step_items,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif (
|
||||||
|
event_type == "on_custom_event" and event.get("name") == "document_created"
|
||||||
|
):
|
||||||
|
data = event.get("data", {})
|
||||||
|
if data.get("id"):
|
||||||
|
yield streaming_service.format_data(
|
||||||
|
"documents-updated",
|
||||||
|
{
|
||||||
|
"action": "created",
|
||||||
|
"document": data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
elif event_type in ("on_chain_end", "on_agent_end"):
|
elif event_type in ("on_chain_end", "on_agent_end"):
|
||||||
if current_text_id is not None:
|
if current_text_id is not None:
|
||||||
yield streaming_service.format_text_end(current_text_id)
|
yield streaming_service.format_text_end(current_text_id)
|
||||||
|
|
@ -995,38 +1154,6 @@ async def _stream_agent_events(
|
||||||
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
yield streaming_service.format_interrupt_request(result.interrupt_value)
|
||||||
|
|
||||||
|
|
||||||
def _try_persist_and_delete_sandbox(
|
|
||||||
thread_id: int,
|
|
||||||
sandbox_files: list[str],
|
|
||||||
) -> None:
|
|
||||||
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
|
|
||||||
from app.agents.new_chat.sandbox import (
|
|
||||||
is_sandbox_enabled,
|
|
||||||
persist_and_delete_sandbox,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_sandbox_enabled():
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _run() -> None:
|
|
||||||
try:
|
|
||||||
await persist_and_delete_sandbox(thread_id, sandbox_files)
|
|
||||||
except Exception:
|
|
||||||
logging.getLogger(__name__).warning(
|
|
||||||
"persist_and_delete_sandbox failed for thread %s",
|
|
||||||
thread_id,
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
task = loop.create_task(_run())
|
|
||||||
_background_tasks.add(task)
|
|
||||||
task.add_done_callback(_background_tasks.discard)
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_new_chat(
|
async def stream_new_chat(
|
||||||
user_query: str,
|
user_query: str,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
@ -1141,22 +1268,6 @@ async def stream_new_chat(
|
||||||
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
||||||
sandbox_backend = None
|
|
||||||
_t0 = time.perf_counter()
|
|
||||||
if is_sandbox_enabled():
|
|
||||||
try:
|
|
||||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
|
||||||
except Exception as sandbox_err:
|
|
||||||
logging.getLogger(__name__).warning(
|
|
||||||
"Sandbox creation failed, continuing without execute tool: %s",
|
|
||||||
sandbox_err,
|
|
||||||
)
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
|
|
||||||
time.perf_counter() - _t0,
|
|
||||||
sandbox_backend is not None,
|
|
||||||
)
|
|
||||||
|
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
agent = await create_surfsense_deep_agent(
|
agent = await create_surfsense_deep_agent(
|
||||||
|
|
@ -1170,7 +1281,6 @@ async def stream_new_chat(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
sandbox_backend=sandbox_backend,
|
|
||||||
disabled_tools=disabled_tools,
|
disabled_tools=disabled_tools,
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
|
|
@ -1531,8 +1641,6 @@ async def stream_new_chat(
|
||||||
"Failed to clear AI responding state for thread %s", chat_id
|
"Failed to clear AI responding state for thread %s", chat_id
|
||||||
)
|
)
|
||||||
|
|
||||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
|
||||||
|
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
session.expunge_all()
|
session.expunge_all()
|
||||||
|
|
||||||
|
|
@ -1541,7 +1649,7 @@ async def stream_new_chat(
|
||||||
|
|
||||||
# Break circular refs held by the agent graph, tools, and LLM
|
# Break circular refs held by the agent graph, tools, and LLM
|
||||||
# wrappers so the GC can reclaim them in a single pass.
|
# wrappers so the GC can reclaim them in a single pass.
|
||||||
agent = llm = connector_service = sandbox_backend = None
|
agent = llm = connector_service = None
|
||||||
input_state = stream_result = None
|
input_state = stream_result = None
|
||||||
session = None
|
session = None
|
||||||
|
|
||||||
|
|
@ -1627,22 +1735,6 @@ async def stream_resume_chat(
|
||||||
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
"[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||||
)
|
)
|
||||||
|
|
||||||
sandbox_backend = None
|
|
||||||
_t0 = time.perf_counter()
|
|
||||||
if is_sandbox_enabled():
|
|
||||||
try:
|
|
||||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
|
||||||
except Exception as sandbox_err:
|
|
||||||
logging.getLogger(__name__).warning(
|
|
||||||
"Sandbox creation failed, continuing without execute tool: %s",
|
|
||||||
sandbox_err,
|
|
||||||
)
|
|
||||||
_perf_log.info(
|
|
||||||
"[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
|
|
||||||
time.perf_counter() - _t0,
|
|
||||||
sandbox_backend is not None,
|
|
||||||
)
|
|
||||||
|
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
|
|
||||||
_t0 = time.perf_counter()
|
_t0 = time.perf_counter()
|
||||||
|
|
@ -1657,7 +1749,6 @@ async def stream_resume_chat(
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
firecrawl_api_key=firecrawl_api_key,
|
firecrawl_api_key=firecrawl_api_key,
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
sandbox_backend=sandbox_backend,
|
|
||||||
)
|
)
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
"[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
|
|
@ -1742,15 +1833,13 @@ async def stream_resume_chat(
|
||||||
"Failed to clear AI responding state for thread %s", chat_id
|
"Failed to clear AI responding state for thread %s", chat_id
|
||||||
)
|
)
|
||||||
|
|
||||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
|
||||||
|
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
session.expunge_all()
|
session.expunge_all()
|
||||||
|
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
agent = llm = connector_service = sandbox_backend = None
|
agent = llm = connector_service = None
|
||||||
stream_result = None
|
stream_result = None
|
||||||
session = None
|
session = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,10 @@ from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||||
from app.db import DocumentType, SearchSourceConnectorType
|
from app.db import DocumentType, SearchSourceConnectorType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
|
IndexingPipelineService,
|
||||||
|
PlaceholderInfo,
|
||||||
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
|
|
||||||
|
|
@ -194,6 +197,27 @@ async def index_confluence_pages(
|
||||||
await confluence_client.close()
|
await confluence_client.close()
|
||||||
return 0, 0, None
|
return 0, 0, None
|
||||||
|
|
||||||
|
# ── Create placeholders for instant UI feedback ───────────────
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
placeholders = [
|
||||||
|
PlaceholderInfo(
|
||||||
|
title=page.get("title", ""),
|
||||||
|
document_type=DocumentType.CONFLUENCE_CONNECTOR,
|
||||||
|
unique_id=page.get("id", ""),
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
connector_id=connector_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
metadata={
|
||||||
|
"page_id": page.get("id", ""),
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"connector_type": "Confluence",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for page in pages
|
||||||
|
if page.get("id") and page.get("title")
|
||||||
|
]
|
||||||
|
await pipeline.create_placeholder_documents(placeholders)
|
||||||
|
|
||||||
documents_skipped = 0
|
documents_skipped = 0
|
||||||
duplicate_content_count = 0
|
duplicate_content_count = 0
|
||||||
connector_docs: list[ConnectorDocument] = []
|
connector_docs: list[ConnectorDocument] = []
|
||||||
|
|
@ -202,7 +226,7 @@ async def index_confluence_pages(
|
||||||
try:
|
try:
|
||||||
page_id = page.get("id")
|
page_id = page.get("id")
|
||||||
page_title = page.get("title", "")
|
page_title = page.get("title", "")
|
||||||
space_id = page.get("spaceId", "")
|
page.get("spaceId", "")
|
||||||
|
|
||||||
if not page_id or not page_title:
|
if not page_id or not page_title:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -265,11 +289,12 @@ async def index_confluence_pages(
|
||||||
connector_docs.append(doc)
|
connector_docs.append(doc)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error building ConnectorDocument for page: {e!s}", exc_info=True)
|
logger.error(
|
||||||
|
f"Error building ConnectorDocument for page: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
pipeline = IndexingPipelineService(session)
|
|
||||||
await pipeline.migrate_legacy_docs(connector_docs)
|
await pipeline.migrate_legacy_docs(connector_docs)
|
||||||
|
|
||||||
async def _get_llm(s: AsyncSession):
|
async def _get_llm(s: AsyncSession):
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,10 @@ from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||||
from app.db import DocumentType, SearchSourceConnectorType
|
from app.db import DocumentType, SearchSourceConnectorType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
|
IndexingPipelineService,
|
||||||
|
PlaceholderInfo,
|
||||||
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.utils.google_credentials import (
|
from app.utils.google_credentials import (
|
||||||
|
|
@ -73,9 +76,7 @@ def _build_connector_doc(
|
||||||
"connector_type": "Google Calendar",
|
"connector_type": "Google Calendar",
|
||||||
}
|
}
|
||||||
|
|
||||||
fallback_summary = (
|
fallback_summary = f"Google Calendar Event: {event_summary}\n\n{event_markdown}"
|
||||||
f"Google Calendar Event: {event_summary}\n\n{event_markdown}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ConnectorDocument(
|
return ConnectorDocument(
|
||||||
title=event_summary,
|
title=event_summary,
|
||||||
|
|
@ -344,6 +345,27 @@ async def index_google_calendar_events(
|
||||||
logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True)
|
logger.error(f"Error fetching Google Calendar events: {e!s}", exc_info=True)
|
||||||
return 0, 0, f"Error fetching Google Calendar events: {e!s}"
|
return 0, 0, f"Error fetching Google Calendar events: {e!s}"
|
||||||
|
|
||||||
|
# ── Create placeholders for instant UI feedback ───────────────
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
placeholders = [
|
||||||
|
PlaceholderInfo(
|
||||||
|
title=event.get("summary", "No Title"),
|
||||||
|
document_type=DocumentType.GOOGLE_CALENDAR_CONNECTOR,
|
||||||
|
unique_id=event.get("id", ""),
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
connector_id=connector_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
metadata={
|
||||||
|
"event_id": event.get("id", ""),
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"connector_type": "Google Calendar",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for event in events
|
||||||
|
if event.get("id")
|
||||||
|
]
|
||||||
|
await pipeline.create_placeholder_documents(placeholders)
|
||||||
|
|
||||||
# ── Build ConnectorDocuments ──────────────────────────────────
|
# ── Build ConnectorDocuments ──────────────────────────────────
|
||||||
connector_docs: list[ConnectorDocument] = []
|
connector_docs: list[ConnectorDocument] = []
|
||||||
documents_skipped = 0
|
documents_skipped = 0
|
||||||
|
|
@ -391,13 +413,13 @@ async def index_google_calendar_events(
|
||||||
connector_docs.append(doc)
|
connector_docs.append(doc)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error building ConnectorDocument for event: {e!s}", exc_info=True)
|
logger.error(
|
||||||
|
f"Error building ConnectorDocument for event: {e!s}", exc_info=True
|
||||||
|
)
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# ── Pipeline: migrate legacy docs + parallel index ─────────────
|
# ── Pipeline: migrate legacy docs + parallel index ─────────────
|
||||||
pipeline = IndexingPipelineService(session)
|
|
||||||
|
|
||||||
await pipeline.migrate_legacy_docs(connector_docs)
|
await pipeline.migrate_legacy_docs(connector_docs)
|
||||||
|
|
||||||
async def _get_llm(s):
|
async def _get_llm(s):
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,10 @@ from app.connectors.google_drive.file_types import should_skip_file as skip_mime
|
||||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
|
IndexingPipelineService,
|
||||||
|
PlaceholderInfo,
|
||||||
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.tasks.connector_indexers.base import (
|
from app.tasks.connector_indexers.base import (
|
||||||
|
|
@ -57,6 +60,7 @@ logger = logging.getLogger(__name__)
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def _should_skip_file(
|
async def _should_skip_file(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
file: dict,
|
file: dict,
|
||||||
|
|
@ -97,11 +101,14 @@ async def _should_skip_file(
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
Document.search_space_id == search_space_id,
|
Document.search_space_id == search_space_id,
|
||||||
Document.document_type.in_([
|
Document.document_type.in_(
|
||||||
DocumentType.GOOGLE_DRIVE_FILE,
|
[
|
||||||
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
DocumentType.GOOGLE_DRIVE_FILE,
|
||||||
]),
|
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||||
cast(Document.document_metadata["google_drive_file_id"], String) == file_id,
|
]
|
||||||
|
),
|
||||||
|
cast(Document.document_metadata["google_drive_file_id"], String)
|
||||||
|
== file_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
existing = result.scalar_one_or_none()
|
existing = result.scalar_one_or_none()
|
||||||
|
|
@ -191,6 +198,50 @@ def _build_connector_doc(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_drive_placeholders(
|
||||||
|
session: AsyncSession,
|
||||||
|
files: list[dict],
|
||||||
|
*,
|
||||||
|
connector_id: int,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Create placeholder document rows for discovered Drive files.
|
||||||
|
|
||||||
|
Called immediately after file discovery (Phase 1) so documents appear
|
||||||
|
in the UI via Zero sync before the slow download/ETL phase begins.
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
return
|
||||||
|
|
||||||
|
placeholders = []
|
||||||
|
for file in files:
|
||||||
|
file_id = file.get("id")
|
||||||
|
file_name = file.get("name", "Unknown")
|
||||||
|
if not file_id:
|
||||||
|
continue
|
||||||
|
placeholders.append(
|
||||||
|
PlaceholderInfo(
|
||||||
|
title=file_name,
|
||||||
|
document_type=DocumentType.GOOGLE_DRIVE_FILE,
|
||||||
|
unique_id=file_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
connector_id=connector_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
metadata={
|
||||||
|
"google_drive_file_id": file_id,
|
||||||
|
"FILE_NAME": file_name,
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"connector_type": "Google Drive",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if placeholders:
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
await pipeline.create_placeholder_documents(placeholders)
|
||||||
|
|
||||||
|
|
||||||
async def _download_files_parallel(
|
async def _download_files_parallel(
|
||||||
drive_client: GoogleDriveClient,
|
drive_client: GoogleDriveClient,
|
||||||
files: list[dict],
|
files: list[dict],
|
||||||
|
|
@ -246,9 +297,7 @@ async def _download_files_parallel(
|
||||||
|
|
||||||
failed = 0
|
failed = 0
|
||||||
for outcome in outcomes:
|
for outcome in outcomes:
|
||||||
if isinstance(outcome, Exception):
|
if isinstance(outcome, Exception) or outcome is None:
|
||||||
failed += 1
|
|
||||||
elif outcome is None:
|
|
||||||
failed += 1
|
failed += 1
|
||||||
else:
|
else:
|
||||||
results.append(outcome)
|
results.append(outcome)
|
||||||
|
|
@ -300,14 +349,18 @@ async def _process_single_file(
|
||||||
if not documents:
|
if not documents:
|
||||||
return 0, 1, 0
|
return 0, 1, 0
|
||||||
|
|
||||||
from app.indexing_pipeline.document_hashing import compute_unique_identifier_hash
|
from app.indexing_pipeline.document_hashing import (
|
||||||
|
compute_unique_identifier_hash,
|
||||||
|
)
|
||||||
|
|
||||||
doc_map = {compute_unique_identifier_hash(doc): doc}
|
doc_map = {compute_unique_identifier_hash(doc): doc}
|
||||||
for document in documents:
|
for document in documents:
|
||||||
connector_doc = doc_map.get(document.unique_identifier_hash)
|
connector_doc = doc_map.get(document.unique_identifier_hash)
|
||||||
if not connector_doc:
|
if not connector_doc:
|
||||||
continue
|
continue
|
||||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
user_llm = await get_user_long_context_llm(
|
||||||
|
session, user_id, search_space_id
|
||||||
|
)
|
||||||
await pipeline.index(document, connector_doc, user_llm)
|
await pipeline.index(document, connector_doc, user_llm)
|
||||||
|
|
||||||
logger.info(f"Successfully indexed Google Drive file: {file_name}")
|
logger.info(f"Successfully indexed Google Drive file: {file_name}")
|
||||||
|
|
@ -335,11 +388,14 @@ async def _remove_document(session: AsyncSession, file_id: str, search_space_id:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Document).where(
|
select(Document).where(
|
||||||
Document.search_space_id == search_space_id,
|
Document.search_space_id == search_space_id,
|
||||||
Document.document_type.in_([
|
Document.document_type.in_(
|
||||||
DocumentType.GOOGLE_DRIVE_FILE,
|
[
|
||||||
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
DocumentType.GOOGLE_DRIVE_FILE,
|
||||||
]),
|
DocumentType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
|
||||||
cast(Document.document_metadata["google_drive_file_id"], String) == file_id,
|
]
|
||||||
|
),
|
||||||
|
cast(Document.document_metadata["google_drive_file_id"], String)
|
||||||
|
== file_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
existing = result.scalar_one_or_none()
|
existing = result.scalar_one_or_none()
|
||||||
|
|
@ -383,7 +439,9 @@ async def _download_and_index(
|
||||||
return await get_user_long_context_llm(s, user_id, search_space_id)
|
return await get_user_long_context_llm(s, user_id, search_space_id)
|
||||||
|
|
||||||
_, batch_indexed, batch_failed = await pipeline.index_batch_parallel(
|
_, batch_indexed, batch_failed = await pipeline.index_batch_parallel(
|
||||||
connector_docs, _get_llm, max_concurrency=3,
|
connector_docs,
|
||||||
|
_get_llm,
|
||||||
|
max_concurrency=3,
|
||||||
on_heartbeat=on_heartbeat,
|
on_heartbeat=on_heartbeat,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -430,10 +488,22 @@ async def _index_selected_files(
|
||||||
|
|
||||||
files_to_download.append(file)
|
files_to_download.append(file)
|
||||||
|
|
||||||
batch_indexed, failed = await _download_and_index(
|
await _create_drive_placeholders(
|
||||||
drive_client, session, files_to_download,
|
session,
|
||||||
connector_id=connector_id, search_space_id=search_space_id,
|
files_to_download,
|
||||||
user_id=user_id, enable_summary=enable_summary,
|
connector_id=connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_indexed, _failed = await _download_and_index(
|
||||||
|
drive_client,
|
||||||
|
session,
|
||||||
|
files_to_download,
|
||||||
|
connector_id=connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
enable_summary=enable_summary,
|
||||||
on_heartbeat=on_heartbeat,
|
on_heartbeat=on_heartbeat,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -444,6 +514,7 @@ async def _index_selected_files(
|
||||||
# Scan strategies
|
# Scan strategies
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def _index_full_scan(
|
async def _index_full_scan(
|
||||||
drive_client: GoogleDriveClient,
|
drive_client: GoogleDriveClient,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
|
|
@ -464,7 +535,11 @@ async def _index_full_scan(
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})",
|
f"Starting full scan of folder: {folder_name} (include_subfolders={include_subfolders})",
|
||||||
{"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders},
|
{
|
||||||
|
"stage": "full_scan",
|
||||||
|
"folder_id": folder_id,
|
||||||
|
"include_subfolders": include_subfolders,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -483,7 +558,10 @@ async def _index_full_scan(
|
||||||
|
|
||||||
while files_processed < max_files:
|
while files_processed < max_files:
|
||||||
files, next_token, error = await get_files_in_folder(
|
files, next_token, error = await get_files_in_folder(
|
||||||
drive_client, cur_id, include_subfolders=True, page_token=page_token,
|
drive_client,
|
||||||
|
cur_id,
|
||||||
|
include_subfolders=True,
|
||||||
|
page_token=page_token,
|
||||||
)
|
)
|
||||||
if error:
|
if error:
|
||||||
logger.error(f"Error listing files in {cur_name}: {error}")
|
logger.error(f"Error listing files in {cur_name}: {error}")
|
||||||
|
|
@ -500,7 +578,9 @@ async def _index_full_scan(
|
||||||
mime = file.get("mimeType", "")
|
mime = file.get("mimeType", "")
|
||||||
if mime == "application/vnd.google-apps.folder":
|
if mime == "application/vnd.google-apps.folder":
|
||||||
if include_subfolders:
|
if include_subfolders:
|
||||||
folders_to_process.append((file["id"], file.get("name", "Unknown")))
|
folders_to_process.append(
|
||||||
|
(file["id"], file.get("name", "Unknown"))
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
files_processed += 1
|
files_processed += 1
|
||||||
|
|
@ -521,24 +601,45 @@ async def _index_full_scan(
|
||||||
|
|
||||||
if not files_processed and first_error:
|
if not files_processed and first_error:
|
||||||
err_lower = first_error.lower()
|
err_lower = first_error.lower()
|
||||||
if "401" in first_error or "invalid credentials" in err_lower or "authError" in first_error:
|
if (
|
||||||
|
"401" in first_error
|
||||||
|
or "invalid credentials" in err_lower
|
||||||
|
or "authError" in first_error
|
||||||
|
):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Google Drive authentication failed. Please re-authenticate. (Error: {first_error})"
|
f"Google Drive authentication failed. Please re-authenticate. (Error: {first_error})"
|
||||||
)
|
)
|
||||||
raise Exception(f"Failed to list Google Drive files: {first_error}")
|
raise Exception(f"Failed to list Google Drive files: {first_error}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Phase 1.5: create placeholders for instant UI feedback
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
await _create_drive_placeholders(
|
||||||
|
session,
|
||||||
|
files_to_download,
|
||||||
|
connector_id=connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Phase 2+3 (parallel): download, ETL, index
|
# Phase 2+3 (parallel): download, ETL, index
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
batch_indexed, failed = await _download_and_index(
|
batch_indexed, failed = await _download_and_index(
|
||||||
drive_client, session, files_to_download,
|
drive_client,
|
||||||
connector_id=connector_id, search_space_id=search_space_id,
|
session,
|
||||||
user_id=user_id, enable_summary=enable_summary,
|
files_to_download,
|
||||||
|
connector_id=connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
enable_summary=enable_summary,
|
||||||
on_heartbeat=on_heartbeat_callback,
|
on_heartbeat=on_heartbeat_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
indexed = renamed_count + batch_indexed
|
indexed = renamed_count + batch_indexed
|
||||||
logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed")
|
logger.info(
|
||||||
|
f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||||
|
)
|
||||||
return indexed, skipped
|
return indexed, skipped
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -565,7 +666,9 @@ async def _index_with_delta_sync(
|
||||||
{"stage": "delta_sync", "start_token": start_page_token},
|
{"stage": "delta_sync", "start_token": start_page_token},
|
||||||
)
|
)
|
||||||
|
|
||||||
changes, _final_token, error = await fetch_all_changes(drive_client, start_page_token, folder_id)
|
changes, _final_token, error = await fetch_all_changes(
|
||||||
|
drive_client, start_page_token, folder_id
|
||||||
|
)
|
||||||
if error:
|
if error:
|
||||||
err_lower = error.lower()
|
err_lower = error.lower()
|
||||||
if "401" in error or "invalid credentials" in err_lower or "authError" in error:
|
if "401" in error or "invalid credentials" in err_lower or "authError" in error:
|
||||||
|
|
@ -614,18 +717,35 @@ async def _index_with_delta_sync(
|
||||||
|
|
||||||
files_to_download.append(file)
|
files_to_download.append(file)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Phase 1.5: create placeholders for instant UI feedback
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
await _create_drive_placeholders(
|
||||||
|
session,
|
||||||
|
files_to_download,
|
||||||
|
connector_id=connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Phase 2+3 (parallel): download, ETL, index
|
# Phase 2+3 (parallel): download, ETL, index
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
batch_indexed, failed = await _download_and_index(
|
batch_indexed, failed = await _download_and_index(
|
||||||
drive_client, session, files_to_download,
|
drive_client,
|
||||||
connector_id=connector_id, search_space_id=search_space_id,
|
session,
|
||||||
user_id=user_id, enable_summary=enable_summary,
|
files_to_download,
|
||||||
|
connector_id=connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
enable_summary=enable_summary,
|
||||||
on_heartbeat=on_heartbeat_callback,
|
on_heartbeat=on_heartbeat_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
indexed = renamed_count + batch_indexed
|
indexed = renamed_count + batch_indexed
|
||||||
logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed")
|
logger.info(
|
||||||
|
f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed"
|
||||||
|
)
|
||||||
return indexed, skipped
|
return indexed, skipped
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -633,6 +753,7 @@ async def _index_with_delta_sync(
|
||||||
# Public entry points
|
# Public entry points
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def index_google_drive_files(
|
async def index_google_drive_files(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
connector_id: int,
|
connector_id: int,
|
||||||
|
|
@ -653,8 +774,11 @@ async def index_google_drive_files(
|
||||||
source="connector_indexing_task",
|
source="connector_indexing_task",
|
||||||
message=f"Starting Google Drive indexing for connector {connector_id}",
|
message=f"Starting Google Drive indexing for connector {connector_id}",
|
||||||
metadata={
|
metadata={
|
||||||
"connector_id": connector_id, "user_id": str(user_id),
|
"connector_id": connector_id,
|
||||||
"folder_id": folder_id, "use_delta_sync": use_delta_sync, "max_files": max_files,
|
"user_id": str(user_id),
|
||||||
|
"folder_id": folder_id,
|
||||||
|
"use_delta_sync": use_delta_sync,
|
||||||
|
"max_files": max_files,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -666,11 +790,14 @@ async def index_google_drive_files(
|
||||||
break
|
break
|
||||||
if not connector:
|
if not connector:
|
||||||
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||||
|
)
|
||||||
return 0, 0, error_msg
|
return 0, 0, error_msg
|
||||||
|
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry, f"Initializing Google Drive client for connector {connector_id}",
|
log_entry,
|
||||||
|
f"Initializing Google Drive client for connector {connector_id}",
|
||||||
{"stage": "client_initialization"},
|
{"stage": "client_initialization"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -679,24 +806,39 @@ async def index_google_drive_files(
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
if not connected_account_id:
|
if not connected_account_id:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry,
|
||||||
|
error_msg,
|
||||||
|
"Missing Composio account",
|
||||||
|
{"error_type": "MissingComposioAccount"},
|
||||||
|
)
|
||||||
return 0, 0, error_msg
|
return 0, 0, error_msg
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
pre_built_credentials = build_composio_credentials(connected_account_id)
|
||||||
else:
|
else:
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
if token_encrypted and not config.SECRET_KEY:
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry, "SECRET_KEY not configured but credentials are encrypted",
|
log_entry,
|
||||||
"Missing SECRET_KEY", {"error_type": "MissingSecretKey"},
|
"SECRET_KEY not configured but credentials are encrypted",
|
||||||
|
"Missing SECRET_KEY",
|
||||||
|
{"error_type": "MissingSecretKey"},
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
"SECRET_KEY not configured but credentials are marked as encrypted",
|
||||||
)
|
)
|
||||||
return 0, 0, "SECRET_KEY not configured but credentials are marked as encrypted"
|
|
||||||
|
|
||||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||||
drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials)
|
drive_client = GoogleDriveClient(
|
||||||
|
session, connector_id, credentials=pre_built_credentials
|
||||||
|
)
|
||||||
|
|
||||||
if not folder_id:
|
if not folder_id:
|
||||||
error_msg = "folder_id is required for Google Drive indexing"
|
error_msg = "folder_id is required for Google Drive indexing"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, {"error_type": "MissingParameter"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry, error_msg, {"error_type": "MissingParameter"}
|
||||||
|
)
|
||||||
return 0, 0, error_msg
|
return 0, 0, error_msg
|
||||||
|
|
||||||
target_folder_id = folder_id
|
target_folder_id = folder_id
|
||||||
|
|
@ -704,29 +846,64 @@ async def index_google_drive_files(
|
||||||
|
|
||||||
folder_tokens = connector.config.get("folder_tokens", {})
|
folder_tokens = connector.config.get("folder_tokens", {})
|
||||||
start_page_token = folder_tokens.get(target_folder_id)
|
start_page_token = folder_tokens.get(target_folder_id)
|
||||||
can_use_delta = use_delta_sync and start_page_token and connector.last_indexed_at
|
can_use_delta = (
|
||||||
|
use_delta_sync and start_page_token and connector.last_indexed_at
|
||||||
|
)
|
||||||
|
|
||||||
if can_use_delta:
|
if can_use_delta:
|
||||||
logger.info(f"Using delta sync for connector {connector_id}")
|
logger.info(f"Using delta sync for connector {connector_id}")
|
||||||
documents_indexed, documents_skipped = await _index_with_delta_sync(
|
documents_indexed, documents_skipped = await _index_with_delta_sync(
|
||||||
drive_client, session, connector, connector_id, search_space_id, user_id,
|
drive_client,
|
||||||
target_folder_id, start_page_token, task_logger, log_entry, max_files,
|
session,
|
||||||
include_subfolders, on_heartbeat_callback, connector_enable_summary,
|
connector,
|
||||||
|
connector_id,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
target_folder_id,
|
||||||
|
start_page_token,
|
||||||
|
task_logger,
|
||||||
|
log_entry,
|
||||||
|
max_files,
|
||||||
|
include_subfolders,
|
||||||
|
on_heartbeat_callback,
|
||||||
|
connector_enable_summary,
|
||||||
)
|
)
|
||||||
logger.info("Running reconciliation scan after delta sync")
|
logger.info("Running reconciliation scan after delta sync")
|
||||||
ri, rs = await _index_full_scan(
|
ri, rs = await _index_full_scan(
|
||||||
drive_client, session, connector, connector_id, search_space_id, user_id,
|
drive_client,
|
||||||
target_folder_id, target_folder_name, task_logger, log_entry, max_files,
|
session,
|
||||||
include_subfolders, on_heartbeat_callback, connector_enable_summary,
|
connector,
|
||||||
|
connector_id,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
target_folder_id,
|
||||||
|
target_folder_name,
|
||||||
|
task_logger,
|
||||||
|
log_entry,
|
||||||
|
max_files,
|
||||||
|
include_subfolders,
|
||||||
|
on_heartbeat_callback,
|
||||||
|
connector_enable_summary,
|
||||||
)
|
)
|
||||||
documents_indexed += ri
|
documents_indexed += ri
|
||||||
documents_skipped += rs
|
documents_skipped += rs
|
||||||
else:
|
else:
|
||||||
logger.info(f"Using full scan for connector {connector_id}")
|
logger.info(f"Using full scan for connector {connector_id}")
|
||||||
documents_indexed, documents_skipped = await _index_full_scan(
|
documents_indexed, documents_skipped = await _index_full_scan(
|
||||||
drive_client, session, connector, connector_id, search_space_id, user_id,
|
drive_client,
|
||||||
target_folder_id, target_folder_name, task_logger, log_entry, max_files,
|
session,
|
||||||
include_subfolders, on_heartbeat_callback, connector_enable_summary,
|
connector,
|
||||||
|
connector_id,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
target_folder_id,
|
||||||
|
target_folder_name,
|
||||||
|
task_logger,
|
||||||
|
log_entry,
|
||||||
|
max_files,
|
||||||
|
include_subfolders,
|
||||||
|
on_heartbeat_callback,
|
||||||
|
connector_enable_summary,
|
||||||
)
|
)
|
||||||
|
|
||||||
if documents_indexed > 0 or can_use_delta:
|
if documents_indexed > 0 or can_use_delta:
|
||||||
|
|
@ -745,26 +922,34 @@ async def index_google_drive_files(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Successfully completed Google Drive indexing for connector {connector_id}",
|
f"Successfully completed Google Drive indexing for connector {connector_id}",
|
||||||
{
|
{
|
||||||
"files_processed": documents_indexed, "files_skipped": documents_skipped,
|
"files_processed": documents_indexed,
|
||||||
"sync_type": "delta" if can_use_delta else "full", "folder": target_folder_name,
|
"files_skipped": documents_skipped,
|
||||||
|
"sync_type": "delta" if can_use_delta else "full",
|
||||||
|
"folder": target_folder_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"Google Drive indexing completed: {documents_indexed} indexed, {documents_skipped} skipped")
|
logger.info(
|
||||||
|
f"Google Drive indexing completed: {documents_indexed} indexed, {documents_skipped} skipped"
|
||||||
|
)
|
||||||
return documents_indexed, documents_skipped, None
|
return documents_indexed, documents_skipped, None
|
||||||
|
|
||||||
except SQLAlchemyError as db_error:
|
except SQLAlchemyError as db_error:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry, f"Database error during Google Drive indexing for connector {connector_id}",
|
log_entry,
|
||||||
str(db_error), {"error_type": "SQLAlchemyError"},
|
f"Database error during Google Drive indexing for connector {connector_id}",
|
||||||
|
str(db_error),
|
||||||
|
{"error_type": "SQLAlchemyError"},
|
||||||
)
|
)
|
||||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||||
return 0, 0, f"Database error: {db_error!s}"
|
return 0, 0, f"Database error: {db_error!s}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry, f"Failed to index Google Drive files for connector {connector_id}",
|
log_entry,
|
||||||
str(e), {"error_type": type(e).__name__},
|
f"Failed to index Google Drive files for connector {connector_id}",
|
||||||
|
str(e),
|
||||||
|
{"error_type": type(e).__name__},
|
||||||
)
|
)
|
||||||
logger.error(f"Failed to index Google Drive files: {e!s}", exc_info=True)
|
logger.error(f"Failed to index Google Drive files: {e!s}", exc_info=True)
|
||||||
return 0, 0, f"Failed to index Google Drive files: {e!s}"
|
return 0, 0, f"Failed to index Google Drive files: {e!s}"
|
||||||
|
|
@ -784,7 +969,12 @@ async def index_google_drive_single_file(
|
||||||
task_name="google_drive_single_file_indexing",
|
task_name="google_drive_single_file_indexing",
|
||||||
source="connector_indexing_task",
|
source="connector_indexing_task",
|
||||||
message=f"Starting Google Drive single file indexing for file {file_id}",
|
message=f"Starting Google Drive single file indexing for file {file_id}",
|
||||||
metadata={"connector_id": connector_id, "user_id": str(user_id), "file_id": file_id, "file_name": file_name},
|
metadata={
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"file_id": file_id,
|
||||||
|
"file_name": file_name,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -795,7 +985,9 @@ async def index_google_drive_single_file(
|
||||||
break
|
break
|
||||||
if not connector:
|
if not connector:
|
||||||
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||||
|
)
|
||||||
return 0, error_msg
|
return 0, error_msg
|
||||||
|
|
||||||
pre_built_credentials = None
|
pre_built_credentials = None
|
||||||
|
|
@ -803,43 +995,65 @@ async def index_google_drive_single_file(
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
if not connected_account_id:
|
if not connected_account_id:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry,
|
||||||
|
error_msg,
|
||||||
|
"Missing Composio account",
|
||||||
|
{"error_type": "MissingComposioAccount"},
|
||||||
|
)
|
||||||
return 0, error_msg
|
return 0, error_msg
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
pre_built_credentials = build_composio_credentials(connected_account_id)
|
||||||
else:
|
else:
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
if token_encrypted and not config.SECRET_KEY:
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry, "SECRET_KEY not configured but credentials are encrypted",
|
log_entry,
|
||||||
"Missing SECRET_KEY", {"error_type": "MissingSecretKey"},
|
"SECRET_KEY not configured but credentials are encrypted",
|
||||||
|
"Missing SECRET_KEY",
|
||||||
|
{"error_type": "MissingSecretKey"},
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
0,
|
||||||
|
"SECRET_KEY not configured but credentials are marked as encrypted",
|
||||||
)
|
)
|
||||||
return 0, "SECRET_KEY not configured but credentials are marked as encrypted"
|
|
||||||
|
|
||||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||||
drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials)
|
drive_client = GoogleDriveClient(
|
||||||
|
session, connector_id, credentials=pre_built_credentials
|
||||||
|
)
|
||||||
|
|
||||||
file, error = await get_file_by_id(drive_client, file_id)
|
file, error = await get_file_by_id(drive_client, file_id)
|
||||||
if error or not file:
|
if error or not file:
|
||||||
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
|
error_msg = f"Failed to fetch file {file_id}: {error or 'File not found'}"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, {"error_type": "FileNotFound"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry, error_msg, {"error_type": "FileNotFound"}
|
||||||
|
)
|
||||||
return 0, error_msg
|
return 0, error_msg
|
||||||
|
|
||||||
display_name = file_name or file.get("name", "Unknown")
|
display_name = file_name or file.get("name", "Unknown")
|
||||||
|
|
||||||
indexed, _skipped, failed = await _process_single_file(
|
indexed, _skipped, failed = await _process_single_file(
|
||||||
drive_client, session, file,
|
drive_client,
|
||||||
connector_id, search_space_id, user_id, connector_enable_summary,
|
session,
|
||||||
|
file,
|
||||||
|
connector_id,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
connector_enable_summary,
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
if failed > 0:
|
if failed > 0:
|
||||||
error_msg = f"Failed to index file {display_name}"
|
error_msg = f"Failed to index file {display_name}"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, {"file_name": display_name, "file_id": file_id})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry, error_msg, {"file_name": display_name, "file_id": file_id}
|
||||||
|
)
|
||||||
return 0, error_msg
|
return 0, error_msg
|
||||||
|
|
||||||
if indexed > 0:
|
if indexed > 0:
|
||||||
await task_logger.log_task_success(
|
await task_logger.log_task_success(
|
||||||
log_entry, f"Successfully indexed file {display_name}",
|
log_entry,
|
||||||
|
f"Successfully indexed file {display_name}",
|
||||||
{"file_name": display_name, "file_id": file_id},
|
{"file_name": display_name, "file_id": file_id},
|
||||||
)
|
)
|
||||||
return 1, None
|
return 1, None
|
||||||
|
|
@ -848,12 +1062,22 @@ async def index_google_drive_single_file(
|
||||||
|
|
||||||
except SQLAlchemyError as db_error:
|
except SQLAlchemyError as db_error:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(log_entry, "Database error during file indexing", str(db_error), {"error_type": "SQLAlchemyError"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry,
|
||||||
|
"Database error during file indexing",
|
||||||
|
str(db_error),
|
||||||
|
{"error_type": "SQLAlchemyError"},
|
||||||
|
)
|
||||||
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
logger.error(f"Database error: {db_error!s}", exc_info=True)
|
||||||
return 0, f"Database error: {db_error!s}"
|
return 0, f"Database error: {db_error!s}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
await task_logger.log_task_failure(log_entry, "Failed to index Google Drive file", str(e), {"error_type": type(e).__name__})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry,
|
||||||
|
"Failed to index Google Drive file",
|
||||||
|
str(e),
|
||||||
|
{"error_type": type(e).__name__},
|
||||||
|
)
|
||||||
logger.error(f"Failed to index Google Drive file: {e!s}", exc_info=True)
|
logger.error(f"Failed to index Google Drive file: {e!s}", exc_info=True)
|
||||||
return 0, f"Failed to index Google Drive file: {e!s}"
|
return 0, f"Failed to index Google Drive file: {e!s}"
|
||||||
|
|
||||||
|
|
@ -878,7 +1102,11 @@ async def index_google_drive_selected_files(
|
||||||
task_name="google_drive_selected_files_indexing",
|
task_name="google_drive_selected_files_indexing",
|
||||||
source="connector_indexing_task",
|
source="connector_indexing_task",
|
||||||
message=f"Starting Google Drive batch file indexing for {len(files)} files",
|
message=f"Starting Google Drive batch file indexing for {len(files)} files",
|
||||||
metadata={"connector_id": connector_id, "user_id": str(user_id), "file_count": len(files)},
|
metadata={
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"file_count": len(files),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -889,7 +1117,9 @@ async def index_google_drive_selected_files(
|
||||||
break
|
break
|
||||||
if not connector:
|
if not connector:
|
||||||
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
error_msg = f"Google Drive connector with ID {connector_id} not found"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}
|
||||||
|
)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
|
|
||||||
pre_built_credentials = None
|
pre_built_credentials = None
|
||||||
|
|
@ -897,25 +1127,41 @@ async def index_google_drive_selected_files(
|
||||||
connected_account_id = connector.config.get("composio_connected_account_id")
|
connected_account_id = connector.config.get("composio_connected_account_id")
|
||||||
if not connected_account_id:
|
if not connected_account_id:
|
||||||
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
error_msg = f"Composio connected_account_id not found for connector {connector_id}"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, "Missing Composio account", {"error_type": "MissingComposioAccount"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry,
|
||||||
|
error_msg,
|
||||||
|
"Missing Composio account",
|
||||||
|
{"error_type": "MissingComposioAccount"},
|
||||||
|
)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
pre_built_credentials = build_composio_credentials(connected_account_id)
|
pre_built_credentials = build_composio_credentials(connected_account_id)
|
||||||
else:
|
else:
|
||||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||||
if token_encrypted and not config.SECRET_KEY:
|
if token_encrypted and not config.SECRET_KEY:
|
||||||
error_msg = "SECRET_KEY not configured but credentials are marked as encrypted"
|
error_msg = (
|
||||||
|
"SECRET_KEY not configured but credentials are marked as encrypted"
|
||||||
|
)
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"},
|
log_entry,
|
||||||
|
error_msg,
|
||||||
|
"Missing SECRET_KEY",
|
||||||
|
{"error_type": "MissingSecretKey"},
|
||||||
)
|
)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
|
|
||||||
connector_enable_summary = getattr(connector, "enable_summary", True)
|
connector_enable_summary = getattr(connector, "enable_summary", True)
|
||||||
drive_client = GoogleDriveClient(session, connector_id, credentials=pre_built_credentials)
|
drive_client = GoogleDriveClient(
|
||||||
|
session, connector_id, credentials=pre_built_credentials
|
||||||
|
)
|
||||||
|
|
||||||
indexed, skipped, errors = await _index_selected_files(
|
indexed, skipped, errors = await _index_selected_files(
|
||||||
drive_client, session, files,
|
drive_client,
|
||||||
connector_id=connector_id, search_space_id=search_space_id,
|
session,
|
||||||
user_id=user_id, enable_summary=connector_enable_summary,
|
files,
|
||||||
|
connector_id=connector_id,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
user_id=user_id,
|
||||||
|
enable_summary=connector_enable_summary,
|
||||||
on_heartbeat=on_heartbeat_callback,
|
on_heartbeat=on_heartbeat_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -935,18 +1181,24 @@ async def index_google_drive_selected_files(
|
||||||
{"indexed": indexed, "skipped": skipped},
|
{"indexed": indexed, "skipped": skipped},
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Selected files indexing: {indexed} indexed, {skipped} skipped, {len(errors)} errors")
|
logger.info(
|
||||||
|
f"Selected files indexing: {indexed} indexed, {skipped} skipped, {len(errors)} errors"
|
||||||
|
)
|
||||||
return indexed, skipped, errors
|
return indexed, skipped, errors
|
||||||
|
|
||||||
except SQLAlchemyError as db_error:
|
except SQLAlchemyError as db_error:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
error_msg = f"Database error: {db_error!s}"
|
error_msg = f"Database error: {db_error!s}"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, str(db_error), {"error_type": "SQLAlchemyError"})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry, error_msg, str(db_error), {"error_type": "SQLAlchemyError"}
|
||||||
|
)
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
error_msg = f"Failed to index Google Drive files: {e!s}"
|
error_msg = f"Failed to index Google Drive files: {e!s}"
|
||||||
await task_logger.log_task_failure(log_entry, error_msg, str(e), {"error_type": type(e).__name__})
|
await task_logger.log_task_failure(
|
||||||
|
log_entry, error_msg, str(e), {"error_type": type(e).__name__}
|
||||||
|
)
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
return 0, 0, [error_msg]
|
return 0, 0, [error_msg]
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,10 @@ from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||||
from app.db import DocumentType, SearchSourceConnectorType
|
from app.db import DocumentType, SearchSourceConnectorType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
|
IndexingPipelineService,
|
||||||
|
PlaceholderInfo,
|
||||||
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.utils.google_credentials import (
|
from app.utils.google_credentials import (
|
||||||
|
|
@ -282,6 +285,34 @@ async def index_google_gmail_messages(
|
||||||
|
|
||||||
logger.info(f"Found {len(messages)} Google gmail messages to index")
|
logger.info(f"Found {len(messages)} Google gmail messages to index")
|
||||||
|
|
||||||
|
# ── Create placeholders for instant UI feedback ───────────────
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
|
||||||
|
def _gmail_subject(msg: dict) -> str:
|
||||||
|
for h in msg.get("payload", {}).get("headers", []):
|
||||||
|
if h.get("name", "").lower() == "subject":
|
||||||
|
return h.get("value", "No Subject")
|
||||||
|
return "No Subject"
|
||||||
|
|
||||||
|
placeholders = [
|
||||||
|
PlaceholderInfo(
|
||||||
|
title=_gmail_subject(msg),
|
||||||
|
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||||
|
unique_id=msg.get("id", ""),
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
connector_id=connector_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
metadata={
|
||||||
|
"message_id": msg.get("id", ""),
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"connector_type": "Google Gmail",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for msg in messages
|
||||||
|
if msg.get("id")
|
||||||
|
]
|
||||||
|
await pipeline.create_placeholder_documents(placeholders)
|
||||||
|
|
||||||
# ── Build ConnectorDocuments ──────────────────────────────────
|
# ── Build ConnectorDocuments ──────────────────────────────────
|
||||||
connector_docs: list[ConnectorDocument] = []
|
connector_docs: list[ConnectorDocument] = []
|
||||||
documents_skipped = 0
|
documents_skipped = 0
|
||||||
|
|
@ -327,13 +358,14 @@ async def index_google_gmail_messages(
|
||||||
connector_docs.append(doc)
|
connector_docs.append(doc)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error building ConnectorDocument for message: {e!s}", exc_info=True)
|
logger.error(
|
||||||
|
f"Error building ConnectorDocument for message: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# ── Pipeline: migrate legacy docs + parallel index ─────────────
|
# ── Pipeline: migrate legacy docs + parallel index ─────────────
|
||||||
pipeline = IndexingPipelineService(session)
|
|
||||||
|
|
||||||
await pipeline.migrate_legacy_docs(connector_docs)
|
await pipeline.migrate_legacy_docs(connector_docs)
|
||||||
|
|
||||||
async def _get_llm(s):
|
async def _get_llm(s):
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,10 @@ from app.connectors.jira_history import JiraHistoryConnector
|
||||||
from app.db import DocumentType, SearchSourceConnectorType
|
from app.db import DocumentType, SearchSourceConnectorType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
|
IndexingPipelineService,
|
||||||
|
PlaceholderInfo,
|
||||||
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
|
|
||||||
|
|
@ -191,6 +194,27 @@ async def index_jira_issues(
|
||||||
await jira_client.close()
|
await jira_client.close()
|
||||||
return 0, 0, None
|
return 0, 0, None
|
||||||
|
|
||||||
|
# ── Create placeholders for instant UI feedback ───────────────
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
placeholders = [
|
||||||
|
PlaceholderInfo(
|
||||||
|
title=f"{issue.get('key', '')}: {issue.get('id', '')}",
|
||||||
|
document_type=DocumentType.JIRA_CONNECTOR,
|
||||||
|
unique_id=issue.get("key", ""),
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
connector_id=connector_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
metadata={
|
||||||
|
"issue_id": issue.get("key", ""),
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"connector_type": "Jira",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for issue in issues
|
||||||
|
if issue.get("key") and issue.get("id")
|
||||||
|
]
|
||||||
|
await pipeline.create_placeholder_documents(placeholders)
|
||||||
|
|
||||||
connector_docs: list[ConnectorDocument] = []
|
connector_docs: list[ConnectorDocument] = []
|
||||||
documents_skipped = 0
|
documents_skipped = 0
|
||||||
duplicate_content_count = 0
|
duplicate_content_count = 0
|
||||||
|
|
@ -253,7 +277,6 @@ async def index_jira_issues(
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
pipeline = IndexingPipelineService(session)
|
|
||||||
await pipeline.migrate_legacy_docs(connector_docs)
|
await pipeline.migrate_legacy_docs(connector_docs)
|
||||||
|
|
||||||
async def _get_llm(s: AsyncSession):
|
async def _get_llm(s: AsyncSession):
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,10 @@ from app.connectors.linear_connector import LinearConnector
|
||||||
from app.db import DocumentType, SearchSourceConnectorType
|
from app.db import DocumentType, SearchSourceConnectorType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
|
IndexingPipelineService,
|
||||||
|
PlaceholderInfo,
|
||||||
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
|
|
||||||
|
|
@ -199,9 +202,7 @@ async def index_linear_issues(
|
||||||
logger.info(f"Retrieved {len(issues)} issues from Linear API")
|
logger.info(f"Retrieved {len(issues)} issues from Linear API")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Exception when calling Linear API: {e!s}", exc_info=True)
|
||||||
f"Exception when calling Linear API: {e!s}", exc_info=True
|
|
||||||
)
|
|
||||||
return 0, 0, f"Failed to get Linear issues: {e!s}"
|
return 0, 0, f"Failed to get Linear issues: {e!s}"
|
||||||
|
|
||||||
if not issues:
|
if not issues:
|
||||||
|
|
@ -213,6 +214,28 @@ async def index_linear_issues(
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return 0, 0, None
|
return 0, 0, None
|
||||||
|
|
||||||
|
# ── Create placeholders for instant UI feedback ───────────────
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
placeholders = [
|
||||||
|
PlaceholderInfo(
|
||||||
|
title=f"{issue.get('identifier', '')}: {issue.get('title', '')}",
|
||||||
|
document_type=DocumentType.LINEAR_CONNECTOR,
|
||||||
|
unique_id=issue.get("id", ""),
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
connector_id=connector_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
metadata={
|
||||||
|
"issue_id": issue.get("id", ""),
|
||||||
|
"issue_identifier": issue.get("identifier", ""),
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"connector_type": "Linear",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for issue in issues
|
||||||
|
if issue.get("id") and issue.get("title")
|
||||||
|
]
|
||||||
|
await pipeline.create_placeholder_documents(placeholders)
|
||||||
|
|
||||||
# ── Build ConnectorDocuments ──────────────────────────────────
|
# ── Build ConnectorDocuments ──────────────────────────────────
|
||||||
connector_docs: list[ConnectorDocument] = []
|
connector_docs: list[ConnectorDocument] = []
|
||||||
documents_skipped = 0
|
documents_skipped = 0
|
||||||
|
|
@ -238,9 +261,7 @@ async def index_linear_issues(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
formatted_issue = linear_client.format_issue(issue)
|
formatted_issue = linear_client.format_issue(issue)
|
||||||
issue_content = linear_client.format_issue_to_markdown(
|
issue_content = linear_client.format_issue_to_markdown(formatted_issue)
|
||||||
formatted_issue
|
|
||||||
)
|
|
||||||
|
|
||||||
if not issue_content:
|
if not issue_content:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
@ -284,8 +305,6 @@ async def index_linear_issues(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# ── Pipeline: migrate legacy docs + parallel index ────────────
|
# ── Pipeline: migrate legacy docs + parallel index ────────────
|
||||||
pipeline = IndexingPipelineService(session)
|
|
||||||
|
|
||||||
await pipeline.migrate_legacy_docs(connector_docs)
|
await pipeline.migrate_legacy_docs(connector_docs)
|
||||||
|
|
||||||
async def _get_llm(s):
|
async def _get_llm(s):
|
||||||
|
|
@ -302,9 +321,7 @@ async def index_linear_issues(
|
||||||
# ── Finalize ──────────────────────────────────────────────────
|
# ── Finalize ──────────────────────────────────────────────────
|
||||||
await update_connector_last_indexed(session, connector, update_last_indexed)
|
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"Final commit: Total {documents_indexed} Linear issues processed")
|
||||||
f"Final commit: Total {documents_indexed} Linear issues processed"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,10 @@ from app.connectors.notion_history import NotionHistoryConnector
|
||||||
from app.db import DocumentType, SearchSourceConnectorType
|
from app.db import DocumentType, SearchSourceConnectorType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_hashing import compute_content_hash
|
from app.indexing_pipeline.document_hashing import compute_content_hash
|
||||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
|
IndexingPipelineService,
|
||||||
|
PlaceholderInfo,
|
||||||
|
)
|
||||||
from app.services.llm_service import get_user_long_context_llm
|
from app.services.llm_service import get_user_long_context_llm
|
||||||
from app.services.task_logging_service import TaskLoggingService
|
from app.services.task_logging_service import TaskLoggingService
|
||||||
from app.utils.notion_utils import process_blocks
|
from app.utils.notion_utils import process_blocks
|
||||||
|
|
@ -245,13 +248,32 @@ async def index_notion_pages(
|
||||||
{"pages_found": 0},
|
{"pages_found": 0},
|
||||||
)
|
)
|
||||||
logger.info("No Notion pages found to index")
|
logger.info("No Notion pages found to index")
|
||||||
await update_connector_last_indexed(
|
await update_connector_last_indexed(session, connector, update_last_indexed)
|
||||||
session, connector, update_last_indexed
|
|
||||||
)
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await notion_client.close()
|
await notion_client.close()
|
||||||
return 0, 0, None
|
return 0, 0, None
|
||||||
|
|
||||||
|
# ── Create placeholders for instant UI feedback ───────────────
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
placeholders = [
|
||||||
|
PlaceholderInfo(
|
||||||
|
title=page.get("title", f"Untitled page ({page.get('page_id', '')})"),
|
||||||
|
document_type=DocumentType.NOTION_CONNECTOR,
|
||||||
|
unique_id=page.get("page_id", ""),
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
connector_id=connector_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
metadata={
|
||||||
|
"page_id": page.get("page_id", ""),
|
||||||
|
"connector_id": connector_id,
|
||||||
|
"connector_type": "Notion",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for page in pages
|
||||||
|
if page.get("page_id")
|
||||||
|
]
|
||||||
|
await pipeline.create_placeholder_documents(placeholders)
|
||||||
|
|
||||||
# ── Build ConnectorDocuments ──────────────────────────────────
|
# ── Build ConnectorDocuments ──────────────────────────────────
|
||||||
connector_docs: list[ConnectorDocument] = []
|
connector_docs: list[ConnectorDocument] = []
|
||||||
documents_skipped = 0
|
documents_skipped = 0
|
||||||
|
|
@ -282,9 +304,7 @@ async def index_notion_pages(
|
||||||
markdown_content += process_blocks(page_content)
|
markdown_content += process_blocks(page_content)
|
||||||
|
|
||||||
if not markdown_content.strip():
|
if not markdown_content.strip():
|
||||||
logger.warning(
|
logger.warning(f"Skipping page with empty markdown: {page_title}")
|
||||||
f"Skipping page with empty markdown: {page_title}"
|
|
||||||
)
|
|
||||||
documents_skipped += 1
|
documents_skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -322,8 +342,6 @@ async def index_notion_pages(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# ── Pipeline: migrate legacy docs + parallel index ────────────
|
# ── Pipeline: migrate legacy docs + parallel index ────────────
|
||||||
pipeline = IndexingPipelineService(session)
|
|
||||||
|
|
||||||
await pipeline.migrate_legacy_docs(connector_docs)
|
await pipeline.migrate_legacy_docs(connector_docs)
|
||||||
|
|
||||||
async def _get_llm(s):
|
async def _get_llm(s):
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,6 @@ dependencies = [
|
||||||
"redis>=5.2.1",
|
"redis>=5.2.1",
|
||||||
"firecrawl-py>=4.9.0",
|
"firecrawl-py>=4.9.0",
|
||||||
"boto3>=1.35.0",
|
"boto3>=1.35.0",
|
||||||
"langchain-community>=0.3.31",
|
|
||||||
"litellm>=1.80.10",
|
"litellm>=1.80.10",
|
||||||
"langchain-litellm>=0.3.5",
|
"langchain-litellm>=0.3.5",
|
||||||
"fake-useragent>=2.2.0",
|
"fake-useragent>=2.2.0",
|
||||||
|
|
@ -60,20 +59,21 @@ dependencies = [
|
||||||
"sse-starlette>=3.1.1,<3.1.2",
|
"sse-starlette>=3.1.1,<3.1.2",
|
||||||
"gitingest>=0.3.1",
|
"gitingest>=0.3.1",
|
||||||
"composio>=0.10.9",
|
"composio>=0.10.9",
|
||||||
"langchain>=1.2.6",
|
|
||||||
"langgraph>=1.0.5",
|
|
||||||
"unstructured[all-docs]>=0.18.31",
|
"unstructured[all-docs]>=0.18.31",
|
||||||
"unstructured-client>=0.42.3",
|
"unstructured-client>=0.42.3",
|
||||||
"langchain-unstructured>=1.0.1",
|
"langchain-unstructured>=1.0.1",
|
||||||
"slowapi>=0.1.9",
|
"slowapi>=0.1.9",
|
||||||
"pypandoc_binary>=1.16.2",
|
"pypandoc_binary>=1.16.2",
|
||||||
"typst>=0.14.0",
|
"typst>=0.14.0",
|
||||||
"deepagents>=0.4.3",
|
|
||||||
"daytona>=0.146.0",
|
"daytona>=0.146.0",
|
||||||
"langchain-daytona>=0.0.2",
|
"langchain-daytona>=0.0.2",
|
||||||
"pypandoc>=1.16.2",
|
"pypandoc>=1.16.2",
|
||||||
"notion-markdown>=0.7.0",
|
"notion-markdown>=0.7.0",
|
||||||
"fractional-indexing>=0.1.3",
|
"fractional-indexing>=0.1.3",
|
||||||
|
"langchain>=1.2.13",
|
||||||
|
"langgraph>=1.1.3",
|
||||||
|
"langchain-community>=0.4.1",
|
||||||
|
"deepagents>=0.4.12",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
|
def _cal_doc(
|
||||||
|
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
|
||||||
|
) -> ConnectorDocument:
|
||||||
return ConnectorDocument(
|
return ConnectorDocument(
|
||||||
title=f"Event {unique_id}",
|
title=f"Event {unique_id}",
|
||||||
source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
|
source_markdown=f"## Calendar Event\n\nDetails for {unique_id}",
|
||||||
|
|
@ -34,7 +36,9 @@ def _cal_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
@pytest.mark.usefixtures(
|
||||||
|
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||||
|
)
|
||||||
async def test_calendar_pipeline_creates_ready_document(
|
async def test_calendar_pipeline_creates_ready_document(
|
||||||
db_session, db_search_space, db_connector, db_user, mocker
|
db_session, db_search_space, db_connector, db_user, mocker
|
||||||
):
|
):
|
||||||
|
|
@ -63,7 +67,9 @@ async def test_calendar_pipeline_creates_ready_document(
|
||||||
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
|
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
@pytest.mark.usefixtures(
|
||||||
|
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||||
|
)
|
||||||
async def test_calendar_legacy_doc_migrated(
|
async def test_calendar_legacy_doc_migrated(
|
||||||
db_session, db_search_space, db_connector, db_user, mocker
|
db_session, db_search_space, db_connector, db_user, mocker
|
||||||
):
|
):
|
||||||
|
|
@ -101,7 +107,9 @@ async def test_calendar_legacy_doc_migrated(
|
||||||
service = IndexingPipelineService(session=db_session)
|
service = IndexingPipelineService(session=db_session)
|
||||||
await service.migrate_legacy_docs([connector_doc])
|
await service.migrate_legacy_docs([connector_doc])
|
||||||
|
|
||||||
result = await db_session.execute(select(Document).filter(Document.id == original_id))
|
result = await db_session.execute(
|
||||||
|
select(Document).filter(Document.id == original_id)
|
||||||
|
)
|
||||||
row = result.scalars().first()
|
row = result.scalars().first()
|
||||||
|
|
||||||
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
|
assert row.document_type == DocumentType.GOOGLE_CALENDAR_CONNECTOR
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
|
def _drive_doc(
|
||||||
|
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
|
||||||
|
) -> ConnectorDocument:
|
||||||
return ConnectorDocument(
|
return ConnectorDocument(
|
||||||
title=f"File {unique_id}.pdf",
|
title=f"File {unique_id}.pdf",
|
||||||
source_markdown=f"## Document Content\n\nText from file {unique_id}",
|
source_markdown=f"## Document Content\n\nText from file {unique_id}",
|
||||||
|
|
@ -33,7 +35,9 @@ def _drive_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
@pytest.mark.usefixtures(
|
||||||
|
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||||
|
)
|
||||||
async def test_drive_pipeline_creates_ready_document(
|
async def test_drive_pipeline_creates_ready_document(
|
||||||
db_session, db_search_space, db_connector, db_user, mocker
|
db_session, db_search_space, db_connector, db_user, mocker
|
||||||
):
|
):
|
||||||
|
|
@ -62,7 +66,9 @@ async def test_drive_pipeline_creates_ready_document(
|
||||||
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
|
assert DocumentStatus.is_state(row.status, DocumentStatus.READY)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
@pytest.mark.usefixtures(
|
||||||
|
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||||
|
)
|
||||||
async def test_drive_legacy_doc_migrated(
|
async def test_drive_legacy_doc_migrated(
|
||||||
db_session, db_search_space, db_connector, db_user, mocker
|
db_session, db_search_space, db_connector, db_user, mocker
|
||||||
):
|
):
|
||||||
|
|
@ -100,7 +106,9 @@ async def test_drive_legacy_doc_migrated(
|
||||||
service = IndexingPipelineService(session=db_session)
|
service = IndexingPipelineService(session=db_session)
|
||||||
await service.migrate_legacy_docs([connector_doc])
|
await service.migrate_legacy_docs([connector_doc])
|
||||||
|
|
||||||
result = await db_session.execute(select(Document).filter(Document.id == original_id))
|
result = await db_session.execute(
|
||||||
|
select(Document).filter(Document.id == original_id)
|
||||||
|
)
|
||||||
row = result.scalars().first()
|
row = result.scalars().first()
|
||||||
|
|
||||||
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
|
assert row.document_type == DocumentType.GOOGLE_DRIVE_FILE
|
||||||
|
|
@ -111,7 +119,9 @@ async def test_drive_legacy_doc_migrated(
|
||||||
|
|
||||||
|
|
||||||
async def test_should_skip_file_skips_failed_document(
|
async def test_should_skip_file_skips_failed_document(
|
||||||
db_session, db_search_space, db_user,
|
db_session,
|
||||||
|
db_search_space,
|
||||||
|
db_user,
|
||||||
):
|
):
|
||||||
"""A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index."""
|
"""A FAILED document with unchanged md5 must be skipped — user can manually retry via Quick Index."""
|
||||||
import importlib
|
import importlib
|
||||||
|
|
@ -162,7 +172,12 @@ async def test_should_skip_file_skips_failed_document(
|
||||||
db_session.add(failed_doc)
|
db_session.add(failed_doc)
|
||||||
await db_session.flush()
|
await db_session.flush()
|
||||||
|
|
||||||
incoming_file = {"id": file_id, "name": "Failed File.pdf", "mimeType": "application/pdf", "md5Checksum": md5}
|
incoming_file = {
|
||||||
|
"id": file_id,
|
||||||
|
"name": "Failed File.pdf",
|
||||||
|
"mimeType": "application/pdf",
|
||||||
|
"md5Checksum": md5,
|
||||||
|
}
|
||||||
|
|
||||||
should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)
|
should_skip, msg = await _should_skip_file(db_session, incoming_file, space_id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ from app.db import Document, DocumentStatus, DocumentType
|
||||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
from app.indexing_pipeline.document_hashing import (
|
from app.indexing_pipeline.document_hashing import (
|
||||||
compute_identifier_hash,
|
compute_identifier_hash,
|
||||||
compute_unique_identifier_hash,
|
|
||||||
)
|
)
|
||||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||||
|
|
||||||
|
|
@ -17,7 +16,9 @@ _EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_id: str) -> ConnectorDocument:
|
def _gmail_doc(
|
||||||
|
*, unique_id: str, search_space_id: int, connector_id: int, user_id: str
|
||||||
|
) -> ConnectorDocument:
|
||||||
"""Build a Gmail-style ConnectorDocument like the real indexer does."""
|
"""Build a Gmail-style ConnectorDocument like the real indexer does."""
|
||||||
return ConnectorDocument(
|
return ConnectorDocument(
|
||||||
title=f"Subject for {unique_id}",
|
title=f"Subject for {unique_id}",
|
||||||
|
|
@ -37,7 +38,9 @@ def _gmail_doc(*, unique_id: str, search_space_id: int, connector_id: int, user_
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
@pytest.mark.usefixtures(
|
||||||
|
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||||
|
)
|
||||||
async def test_gmail_pipeline_creates_ready_document(
|
async def test_gmail_pipeline_creates_ready_document(
|
||||||
db_session, db_search_space, db_connector, db_user, mocker
|
db_session, db_search_space, db_connector, db_user, mocker
|
||||||
):
|
):
|
||||||
|
|
@ -67,7 +70,9 @@ async def test_gmail_pipeline_creates_ready_document(
|
||||||
assert row.source_markdown == doc.source_markdown
|
assert row.source_markdown == doc.source_markdown
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
@pytest.mark.usefixtures(
|
||||||
|
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||||
|
)
|
||||||
async def test_gmail_legacy_doc_migrated_then_reused(
|
async def test_gmail_legacy_doc_migrated_then_reused(
|
||||||
db_session, db_search_space, db_connector, db_user, mocker
|
db_session, db_search_space, db_connector, db_user, mocker
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,9 @@ from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineServ
|
||||||
pytestmark = pytest.mark.integration
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
@pytest.mark.usefixtures(
|
||||||
|
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||||
|
)
|
||||||
async def test_index_batch_creates_ready_documents(
|
async def test_index_batch_creates_ready_documents(
|
||||||
db_session, db_search_space, make_connector_document, mocker
|
db_session, db_search_space, make_connector_document, mocker
|
||||||
):
|
):
|
||||||
|
|
@ -47,7 +49,9 @@ async def test_index_batch_creates_ready_documents(
|
||||||
assert row.embedding is not None
|
assert row.embedding is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("patched_summarize", "patched_embed_texts", "patched_chunk_text")
|
@pytest.mark.usefixtures(
|
||||||
|
"patched_summarize", "patched_embed_texts", "patched_chunk_text"
|
||||||
|
)
|
||||||
async def test_index_batch_empty_returns_empty(db_session, mocker):
|
async def test_index_batch_empty_returns_empty(db_session, mocker):
|
||||||
"""index_batch with empty input returns an empty list."""
|
"""index_batch with empty input returns an empty list."""
|
||||||
service = IndexingPipelineService(session=db_session)
|
service = IndexingPipelineService(session=db_session)
|
||||||
|
|
|
||||||
106
surfsense_backend/tests/integration/retriever/conftest.py
Normal file
106
surfsense_backend/tests/integration/retriever/conftest.py
Normal file
|
|
@ -0,0 +1,106 @@
|
||||||
|
"""Shared fixtures for retriever integration tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config import config as app_config
|
||||||
|
from app.db import Chunk, Document, DocumentType, SearchSpace, User
|
||||||
|
|
||||||
|
EMBEDDING_DIM = app_config.embedding_model_instance.dimension
|
||||||
|
DUMMY_EMBEDDING = [0.1] * EMBEDDING_DIM
|
||||||
|
|
||||||
|
|
||||||
|
def _make_document(
|
||||||
|
*,
|
||||||
|
title: str,
|
||||||
|
document_type: DocumentType,
|
||||||
|
content: str,
|
||||||
|
search_space_id: int,
|
||||||
|
created_by_id: str,
|
||||||
|
) -> Document:
|
||||||
|
uid = uuid.uuid4().hex[:12]
|
||||||
|
return Document(
|
||||||
|
title=title,
|
||||||
|
document_type=document_type,
|
||||||
|
content=content,
|
||||||
|
content_hash=f"content-{uid}",
|
||||||
|
unique_identifier_hash=f"uid-{uid}",
|
||||||
|
source_markdown=content,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
created_by_id=created_by_id,
|
||||||
|
embedding=DUMMY_EMBEDDING,
|
||||||
|
updated_at=datetime.now(UTC),
|
||||||
|
status={"state": "ready"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chunk(*, content: str, document_id: int) -> Chunk:
|
||||||
|
return Chunk(
|
||||||
|
content=content,
|
||||||
|
document_id=document_id,
|
||||||
|
embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def seed_large_doc(
|
||||||
|
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||||
|
):
|
||||||
|
"""Insert a document with 35 chunks (more than _MAX_FETCH_CHUNKS_PER_DOC=20).
|
||||||
|
|
||||||
|
Also inserts a small 3-chunk document for diversity testing.
|
||||||
|
Returns a dict with ``large_doc``, ``small_doc``, ``search_space``, ``user``,
|
||||||
|
and ``large_chunk_ids`` (all 35 chunk IDs).
|
||||||
|
"""
|
||||||
|
user_id = str(db_user.id)
|
||||||
|
space_id = db_search_space.id
|
||||||
|
|
||||||
|
large_doc = _make_document(
|
||||||
|
title="Large PDF Document",
|
||||||
|
document_type=DocumentType.FILE,
|
||||||
|
content="large document about quarterly performance reviews and budgets",
|
||||||
|
search_space_id=space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
)
|
||||||
|
small_doc = _make_document(
|
||||||
|
title="Small Note",
|
||||||
|
document_type=DocumentType.NOTE,
|
||||||
|
content="quarterly performance review summary note",
|
||||||
|
search_space_id=space_id,
|
||||||
|
created_by_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_session.add_all([large_doc, small_doc])
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
large_chunks = []
|
||||||
|
for i in range(35):
|
||||||
|
chunk = _make_chunk(
|
||||||
|
content=f"chunk {i} about quarterly performance review section {i}",
|
||||||
|
document_id=large_doc.id,
|
||||||
|
)
|
||||||
|
large_chunks.append(chunk)
|
||||||
|
|
||||||
|
small_chunks = [
|
||||||
|
_make_chunk(
|
||||||
|
content="quarterly performance review summary note content",
|
||||||
|
document_id=small_doc.id,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
db_session.add_all(large_chunks + small_chunks)
|
||||||
|
await db_session.flush()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"large_doc": large_doc,
|
||||||
|
"small_doc": small_doc,
|
||||||
|
"large_chunk_ids": [c.id for c in large_chunks],
|
||||||
|
"small_chunk_ids": [c.id for c in small_chunks],
|
||||||
|
"search_space": db_search_space,
|
||||||
|
"user": db_user,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,116 @@
|
||||||
|
"""Integration tests for optimized ChucksHybridSearchRetriever.
|
||||||
|
|
||||||
|
Verifies the SQL ROW_NUMBER per-doc chunk limit, column pruning,
|
||||||
|
and doc metadata caching from RRF results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.retriever.chunks_hybrid_search import (
|
||||||
|
_MAX_FETCH_CHUNKS_PER_DOC,
|
||||||
|
ChucksHybridSearchRetriever,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .conftest import DUMMY_EMBEDDING
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
|
||||||
|
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
|
||||||
|
space_id = seed_large_doc["search_space"].id
|
||||||
|
|
||||||
|
retriever = ChucksHybridSearchRetriever(db_session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text="quarterly performance review",
|
||||||
|
top_k=10,
|
||||||
|
search_space_id=space_id,
|
||||||
|
query_embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
large_doc_id = seed_large_doc["large_doc"].id
|
||||||
|
for result in results:
|
||||||
|
if result["document"].get("id") == large_doc_id:
|
||||||
|
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
|
||||||
|
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pytest.fail("Large doc not found in search results")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_doc_metadata_populated_from_rrf(db_session, seed_large_doc):
|
||||||
|
"""Document metadata (title, type, etc.) should be present even without joinedload."""
|
||||||
|
space_id = seed_large_doc["search_space"].id
|
||||||
|
|
||||||
|
retriever = ChucksHybridSearchRetriever(db_session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text="quarterly performance review",
|
||||||
|
top_k=10,
|
||||||
|
search_space_id=space_id,
|
||||||
|
query_embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) >= 1
|
||||||
|
for result in results:
|
||||||
|
doc = result["document"]
|
||||||
|
assert "id" in doc
|
||||||
|
assert "title" in doc
|
||||||
|
assert doc["title"]
|
||||||
|
assert "document_type" in doc
|
||||||
|
assert doc["document_type"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_matched_chunk_ids_tracked(db_session, seed_large_doc):
|
||||||
|
"""matched_chunk_ids should contain the chunk IDs that appeared in the RRF results."""
|
||||||
|
space_id = seed_large_doc["search_space"].id
|
||||||
|
|
||||||
|
retriever = ChucksHybridSearchRetriever(db_session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text="quarterly performance review",
|
||||||
|
top_k=10,
|
||||||
|
search_space_id=space_id,
|
||||||
|
query_embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
matched = result.get("matched_chunk_ids", [])
|
||||||
|
chunk_ids_in_result = {c["chunk_id"] for c in result["chunks"]}
|
||||||
|
for mid in matched:
|
||||||
|
assert mid in chunk_ids_in_result, (
|
||||||
|
f"matched_chunk_id {mid} not found in chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
|
||||||
|
"""Chunks within each document should be ordered by chunk ID (original order)."""
|
||||||
|
space_id = seed_large_doc["search_space"].id
|
||||||
|
|
||||||
|
retriever = ChucksHybridSearchRetriever(db_session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text="quarterly performance review",
|
||||||
|
top_k=10,
|
||||||
|
search_space_id=space_id,
|
||||||
|
query_embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
|
||||||
|
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_score_is_positive_float(db_session, seed_large_doc):
|
||||||
|
"""Each result should have a positive float score from RRF."""
|
||||||
|
space_id = seed_large_doc["search_space"].id
|
||||||
|
|
||||||
|
retriever = ChucksHybridSearchRetriever(db_session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text="quarterly performance review",
|
||||||
|
top_k=10,
|
||||||
|
search_space_id=space_id,
|
||||||
|
query_embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) >= 1
|
||||||
|
for result in results:
|
||||||
|
assert isinstance(result["score"], float)
|
||||||
|
assert result["score"] > 0
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""Integration tests for optimized DocumentHybridSearchRetriever.
|
||||||
|
|
||||||
|
Verifies the SQL ROW_NUMBER per-doc chunk limit and column pruning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.retriever.documents_hybrid_search import (
|
||||||
|
_MAX_FETCH_CHUNKS_PER_DOC,
|
||||||
|
DocumentHybridSearchRetriever,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .conftest import DUMMY_EMBEDDING
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.integration
|
||||||
|
|
||||||
|
|
||||||
|
async def test_per_doc_chunk_limit_respected(db_session, seed_large_doc):
|
||||||
|
"""A document with 35 chunks should have at most _MAX_FETCH_CHUNKS_PER_DOC chunks returned."""
|
||||||
|
space_id = seed_large_doc["search_space"].id
|
||||||
|
|
||||||
|
retriever = DocumentHybridSearchRetriever(db_session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text="quarterly performance review",
|
||||||
|
top_k=10,
|
||||||
|
search_space_id=space_id,
|
||||||
|
query_embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
large_doc_id = seed_large_doc["large_doc"].id
|
||||||
|
for result in results:
|
||||||
|
if result["document"].get("id") == large_doc_id:
|
||||||
|
assert len(result["chunks"]) <= _MAX_FETCH_CHUNKS_PER_DOC
|
||||||
|
assert len(result["chunks"]) == _MAX_FETCH_CHUNKS_PER_DOC
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pytest.fail("Large doc not found in search results")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_doc_metadata_populated(db_session, seed_large_doc):
|
||||||
|
"""Document metadata should be present from the RRF results."""
|
||||||
|
space_id = seed_large_doc["search_space"].id
|
||||||
|
|
||||||
|
retriever = DocumentHybridSearchRetriever(db_session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text="quarterly performance review",
|
||||||
|
top_k=10,
|
||||||
|
search_space_id=space_id,
|
||||||
|
query_embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) >= 1
|
||||||
|
for result in results:
|
||||||
|
doc = result["document"]
|
||||||
|
assert "id" in doc
|
||||||
|
assert "title" in doc
|
||||||
|
assert doc["title"]
|
||||||
|
assert "document_type" in doc
|
||||||
|
assert doc["document_type"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_chunks_ordered_by_id(db_session, seed_large_doc):
|
||||||
|
"""Chunks within each document should be ordered by chunk ID."""
|
||||||
|
space_id = seed_large_doc["search_space"].id
|
||||||
|
|
||||||
|
retriever = DocumentHybridSearchRetriever(db_session)
|
||||||
|
results = await retriever.hybrid_search(
|
||||||
|
query_text="quarterly performance review",
|
||||||
|
top_k=10,
|
||||||
|
search_space_id=space_id,
|
||||||
|
query_embedding=DUMMY_EMBEDDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
chunk_ids = [c["chunk_id"] for c in result["chunks"]]
|
||||||
|
assert chunk_ids == sorted(chunk_ids), "Chunks not ordered by ID"
|
||||||
|
|
@ -42,14 +42,11 @@ def _to_markdown(page: dict) -> str:
|
||||||
if comments:
|
if comments:
|
||||||
comments_content = "\n\n## Comments\n\n"
|
comments_content = "\n\n## Comments\n\n"
|
||||||
for comment in comments:
|
for comment in comments:
|
||||||
comment_body = (
|
comment_body = comment.get("body", {}).get("storage", {}).get("value", "")
|
||||||
comment.get("body", {}).get("storage", {}).get("value", "")
|
|
||||||
)
|
|
||||||
comment_author = comment.get("version", {}).get("authorId", "Unknown")
|
comment_author = comment.get("version", {}).get("authorId", "Unknown")
|
||||||
comment_date = comment.get("version", {}).get("createdAt", "")
|
comment_date = comment.get("version", {}).get("createdAt", "")
|
||||||
comments_content += (
|
comments_content += (
|
||||||
f"**Comment by {comment_author}** ({comment_date}):\n"
|
f"**Comment by {comment_author}** ({comment_date}):\n{comment_body}\n\n"
|
||||||
f"{comment_body}\n\n"
|
|
||||||
)
|
)
|
||||||
return f"# {page_title}\n\n{page_content}{comments_content}"
|
return f"# {page_title}\n\n{page_content}{comments_content}"
|
||||||
|
|
||||||
|
|
@ -138,22 +135,32 @@ def confluence_mocks(monkeypatch):
|
||||||
|
|
||||||
mock_connector = _mock_connector()
|
mock_connector = _mock_connector()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
|
_mod,
|
||||||
|
"get_connector_by_id",
|
||||||
|
AsyncMock(return_value=mock_connector),
|
||||||
)
|
)
|
||||||
|
|
||||||
confluence_client = _mock_confluence_client(pages=[_make_page()])
|
confluence_client = _mock_confluence_client(pages=[_make_page()])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "ConfluenceHistoryConnector", MagicMock(return_value=confluence_client),
|
_mod,
|
||||||
|
"ConfluenceHistoryConnector",
|
||||||
|
MagicMock(return_value=confluence_client),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
|
_mod,
|
||||||
|
"check_duplicate_document_by_hash",
|
||||||
|
AsyncMock(return_value=None),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "update_connector_last_indexed", AsyncMock(),
|
_mod,
|
||||||
|
"update_connector_last_indexed",
|
||||||
|
AsyncMock(),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
_mod,
|
||||||
|
"calculate_date_range",
|
||||||
|
MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_task_logger = MagicMock()
|
mock_task_logger = MagicMock()
|
||||||
|
|
@ -162,15 +169,20 @@ def confluence_mocks(monkeypatch):
|
||||||
mock_task_logger.log_task_success = AsyncMock()
|
mock_task_logger.log_task_success = AsyncMock()
|
||||||
mock_task_logger.log_task_failure = AsyncMock()
|
mock_task_logger.log_task_failure = AsyncMock()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
|
_mod,
|
||||||
|
"TaskLoggingService",
|
||||||
|
MagicMock(return_value=mock_task_logger),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_mock = AsyncMock(return_value=([], 1, 0))
|
batch_mock = AsyncMock(return_value=([], 1, 0))
|
||||||
pipeline_mock = MagicMock()
|
pipeline_mock = MagicMock()
|
||||||
pipeline_mock.index_batch_parallel = batch_mock
|
pipeline_mock.index_batch_parallel = batch_mock
|
||||||
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
||||||
|
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
_mod,
|
||||||
|
"IndexingPipelineService",
|
||||||
|
MagicMock(return_value=pipeline_mock),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ def mock_drive_client():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patch_extract(monkeypatch):
|
def patch_extract(monkeypatch):
|
||||||
"""Provide a helper to set the download_and_extract_content mock."""
|
"""Provide a helper to set the download_and_extract_content mock."""
|
||||||
|
|
||||||
def _patch(side_effect=None, return_value=None):
|
def _patch(side_effect=None, return_value=None):
|
||||||
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
|
mock = AsyncMock(side_effect=side_effect, return_value=return_value)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
|
|
@ -48,11 +49,13 @@ def patch_extract(monkeypatch):
|
||||||
mock,
|
mock,
|
||||||
)
|
)
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
return _patch
|
return _patch
|
||||||
|
|
||||||
|
|
||||||
async def test_single_file_returns_one_connector_document(
|
async def test_single_file_returns_one_connector_document(
|
||||||
mock_drive_client, patch_extract,
|
mock_drive_client,
|
||||||
|
patch_extract,
|
||||||
):
|
):
|
||||||
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
|
"""Tracer bullet: downloading one file produces one ConnectorDocument."""
|
||||||
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
|
patch_extract(return_value=_mock_extract_ok("f1", "test.txt"))
|
||||||
|
|
@ -73,7 +76,8 @@ async def test_single_file_returns_one_connector_document(
|
||||||
|
|
||||||
|
|
||||||
async def test_multiple_files_all_produce_documents(
|
async def test_multiple_files_all_produce_documents(
|
||||||
mock_drive_client, patch_extract,
|
mock_drive_client,
|
||||||
|
patch_extract,
|
||||||
):
|
):
|
||||||
"""All files are downloaded and converted to ConnectorDocuments."""
|
"""All files are downloaded and converted to ConnectorDocuments."""
|
||||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||||
|
|
@ -96,7 +100,8 @@ async def test_multiple_files_all_produce_documents(
|
||||||
|
|
||||||
|
|
||||||
async def test_one_download_exception_does_not_block_others(
|
async def test_one_download_exception_does_not_block_others(
|
||||||
mock_drive_client, patch_extract,
|
mock_drive_client,
|
||||||
|
patch_extract,
|
||||||
):
|
):
|
||||||
"""A RuntimeError in one download still lets the other files succeed."""
|
"""A RuntimeError in one download still lets the other files succeed."""
|
||||||
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(3)]
|
||||||
|
|
@ -123,7 +128,8 @@ async def test_one_download_exception_does_not_block_others(
|
||||||
|
|
||||||
|
|
||||||
async def test_etl_error_counts_as_download_failure(
|
async def test_etl_error_counts_as_download_failure(
|
||||||
mock_drive_client, patch_extract,
|
mock_drive_client,
|
||||||
|
patch_extract,
|
||||||
):
|
):
|
||||||
"""download_and_extract_content returning an error is counted as failed."""
|
"""download_and_extract_content returning an error is counted as failed."""
|
||||||
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
|
files = [_make_file_dict("f0", "good.txt"), _make_file_dict("f1", "bad.txt")]
|
||||||
|
|
@ -148,7 +154,8 @@ async def test_etl_error_counts_as_download_failure(
|
||||||
|
|
||||||
|
|
||||||
async def test_concurrency_bounded_by_semaphore(
|
async def test_concurrency_bounded_by_semaphore(
|
||||||
mock_drive_client, monkeypatch,
|
mock_drive_client,
|
||||||
|
monkeypatch,
|
||||||
):
|
):
|
||||||
"""Peak concurrent downloads never exceeds max_concurrency."""
|
"""Peak concurrent downloads never exceeds max_concurrency."""
|
||||||
lock = asyncio.Lock()
|
lock = asyncio.Lock()
|
||||||
|
|
@ -189,7 +196,8 @@ async def test_concurrency_bounded_by_semaphore(
|
||||||
|
|
||||||
|
|
||||||
async def test_heartbeat_fires_during_parallel_downloads(
|
async def test_heartbeat_fires_during_parallel_downloads(
|
||||||
mock_drive_client, monkeypatch,
|
mock_drive_client,
|
||||||
|
monkeypatch,
|
||||||
):
|
):
|
||||||
"""on_heartbeat is called at least once when downloads take time."""
|
"""on_heartbeat is called at least once when downloads take time."""
|
||||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||||
|
|
@ -231,8 +239,13 @@ async def test_heartbeat_fires_during_parallel_downloads(
|
||||||
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
|
# Slice 6, 6b, 6c -- _index_full_scan three-phase pipeline
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _folder_dict(file_id: str, name: str) -> dict:
|
def _folder_dict(file_id: str, name: str) -> dict:
|
||||||
return {"id": file_id, "name": name, "mimeType": "application/vnd.google-apps.folder"}
|
return {
|
||||||
|
"id": file_id,
|
||||||
|
"name": name,
|
||||||
|
"mimeType": "application/vnd.google-apps.folder",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -259,12 +272,17 @@ def full_scan_mocks(mock_drive_client, monkeypatch):
|
||||||
batch_mock = AsyncMock(return_value=([], 0, 0))
|
batch_mock = AsyncMock(return_value=([], 0, 0))
|
||||||
pipeline_mock = MagicMock()
|
pipeline_mock = MagicMock()
|
||||||
pipeline_mock.index_batch_parallel = batch_mock
|
pipeline_mock.index_batch_parallel = batch_mock
|
||||||
|
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
_mod,
|
||||||
|
"IndexingPipelineService",
|
||||||
|
MagicMock(return_value=pipeline_mock),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
|
_mod,
|
||||||
|
"get_user_long_context_llm",
|
||||||
|
AsyncMock(return_value=MagicMock()),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -312,12 +330,16 @@ async def test_full_scan_three_phase_counts(full_scan_mocks, monkeypatch):
|
||||||
]
|
]
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_files_in_folder",
|
_mod,
|
||||||
|
"get_files_in_folder",
|
||||||
AsyncMock(return_value=(page_files, None, None)),
|
AsyncMock(return_value=(page_files, None, None)),
|
||||||
)
|
)
|
||||||
|
|
||||||
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
|
full_scan_mocks["skip_results"]["skip1"] = (True, "unchanged")
|
||||||
full_scan_mocks["skip_results"]["rename1"] = (True, "File renamed: 'old' → 'renamed.txt'")
|
full_scan_mocks["skip_results"]["rename1"] = (
|
||||||
|
True,
|
||||||
|
"File renamed: 'old' → 'renamed.txt'",
|
||||||
|
)
|
||||||
|
|
||||||
mock_docs = [MagicMock(), MagicMock()]
|
mock_docs = [MagicMock(), MagicMock()]
|
||||||
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
|
full_scan_mocks["download_mock"].return_value = (mock_docs, 0)
|
||||||
|
|
@ -341,7 +363,8 @@ async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
|
||||||
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
|
page_files = [_make_file_dict(f"f{i}", f"file{i}.txt") for i in range(10)]
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_files_in_folder",
|
_mod,
|
||||||
|
"get_files_in_folder",
|
||||||
AsyncMock(return_value=(page_files, None, None)),
|
AsyncMock(return_value=(page_files, None, None)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -355,14 +378,16 @@ async def test_full_scan_respects_max_files(full_scan_mocks, monkeypatch):
|
||||||
|
|
||||||
|
|
||||||
async def test_full_scan_uses_max_concurrency_3_for_indexing(
|
async def test_full_scan_uses_max_concurrency_3_for_indexing(
|
||||||
full_scan_mocks, monkeypatch,
|
full_scan_mocks,
|
||||||
|
monkeypatch,
|
||||||
):
|
):
|
||||||
"""index_batch_parallel is called with max_concurrency=3."""
|
"""index_batch_parallel is called with max_concurrency=3."""
|
||||||
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
import app.tasks.connector_indexers.google_drive_indexer as _mod
|
||||||
|
|
||||||
page_files = [_make_file_dict("f1", "file1.txt")]
|
page_files = [_make_file_dict("f1", "file1.txt")]
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_files_in_folder",
|
_mod,
|
||||||
|
"get_files_in_folder",
|
||||||
AsyncMock(return_value=(page_files, None, None)),
|
AsyncMock(return_value=(page_files, None, None)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -382,6 +407,7 @@ async def test_full_scan_uses_max_concurrency_3_for_indexing(
|
||||||
# Slice 7 -- _index_with_delta_sync three-phase pipeline
|
# Slice 7 -- _index_with_delta_sync three-phase pipeline
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
||||||
"""Removed/trashed changes call _remove_document; the rest go through
|
"""Removed/trashed changes call _remove_document; the rest go through
|
||||||
_download_files_parallel and index_batch_parallel."""
|
_download_files_parallel and index_batch_parallel."""
|
||||||
|
|
@ -396,7 +422,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
||||||
]
|
]
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "fetch_all_changes",
|
_mod,
|
||||||
|
"fetch_all_changes",
|
||||||
AsyncMock(return_value=(changes, "new-token", None)),
|
AsyncMock(return_value=(changes, "new-token", None)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -408,7 +435,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
||||||
"mod2": "modified",
|
"mod2": "modified",
|
||||||
}
|
}
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "categorize_change",
|
_mod,
|
||||||
|
"categorize_change",
|
||||||
lambda change: change_types[change["fileId"]],
|
lambda change: change_types[change["fileId"]],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -420,7 +448,8 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
||||||
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
|
monkeypatch.setattr(_mod, "_remove_document", _fake_remove)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "_should_skip_file",
|
_mod,
|
||||||
|
"_should_skip_file",
|
||||||
AsyncMock(return_value=(False, None)),
|
AsyncMock(return_value=(False, None)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -431,11 +460,16 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
||||||
batch_mock = AsyncMock(return_value=([], 2, 0))
|
batch_mock = AsyncMock(return_value=([], 2, 0))
|
||||||
pipeline_mock = MagicMock()
|
pipeline_mock = MagicMock()
|
||||||
pipeline_mock.index_batch_parallel = batch_mock
|
pipeline_mock.index_batch_parallel = batch_mock
|
||||||
|
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
_mod,
|
||||||
|
"IndexingPipelineService",
|
||||||
|
MagicMock(return_value=pipeline_mock),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_user_long_context_llm", AsyncMock(return_value=MagicMock()),
|
_mod,
|
||||||
|
"get_user_long_context_llm",
|
||||||
|
AsyncMock(return_value=MagicMock()),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
|
|
@ -472,6 +506,7 @@ async def test_delta_sync_removals_serial_rest_parallel(monkeypatch):
|
||||||
# _index_selected_files -- parallel indexing of user-selected files
|
# _index_selected_files -- parallel indexing of user-selected files
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def selected_files_mocks(mock_drive_client, monkeypatch):
|
def selected_files_mocks(mock_drive_client, monkeypatch):
|
||||||
"""Wire up mocks for _index_selected_files tests."""
|
"""Wire up mocks for _index_selected_files tests."""
|
||||||
|
|
@ -496,6 +531,14 @@ def selected_files_mocks(mock_drive_client, monkeypatch):
|
||||||
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
download_and_index_mock = AsyncMock(return_value=(0, 0))
|
||||||
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
monkeypatch.setattr(_mod, "_download_and_index", download_and_index_mock)
|
||||||
|
|
||||||
|
pipeline_mock = MagicMock()
|
||||||
|
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
_mod,
|
||||||
|
"IndexingPipelineService",
|
||||||
|
MagicMock(return_value=pipeline_mock),
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"drive_client": mock_drive_client,
|
"drive_client": mock_drive_client,
|
||||||
"session": mock_session,
|
"session": mock_session,
|
||||||
|
|
@ -526,7 +569,8 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
|
||||||
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
|
selected_files_mocks["download_and_index_mock"].return_value = (1, 0)
|
||||||
|
|
||||||
indexed, skipped, errors = await _run_selected(
|
indexed, skipped, errors = await _run_selected(
|
||||||
selected_files_mocks, [("f1", "report.pdf")],
|
selected_files_mocks,
|
||||||
|
[("f1", "report.pdf")],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert indexed == 1
|
assert indexed == 1
|
||||||
|
|
@ -538,11 +582,13 @@ async def test_selected_files_single_file_indexed(selected_files_mocks):
|
||||||
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
||||||
"""get_file_by_id failing for one file collects an error; others still indexed."""
|
"""get_file_by_id failing for one file collects an error; others still indexed."""
|
||||||
selected_files_mocks["get_file_results"]["f1"] = (
|
selected_files_mocks["get_file_results"]["f1"] = (
|
||||||
_make_file_dict("f1", "first.txt"), None,
|
_make_file_dict("f1", "first.txt"),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
|
selected_files_mocks["get_file_results"]["f2"] = (None, "HTTP 404")
|
||||||
selected_files_mocks["get_file_results"]["f3"] = (
|
selected_files_mocks["get_file_results"]["f3"] = (
|
||||||
_make_file_dict("f3", "third.txt"), None,
|
_make_file_dict("f3", "third.txt"),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||||
|
|
||||||
|
|
@ -561,30 +607,46 @@ async def test_selected_files_fetch_failure_isolation(selected_files_mocks):
|
||||||
async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
||||||
"""Unchanged files are skipped, renames counted as indexed,
|
"""Unchanged files are skipped, renames counted as indexed,
|
||||||
and only new files are sent to _download_and_index."""
|
and only new files are sent to _download_and_index."""
|
||||||
for fid, fname in [("s1", "unchanged.txt"), ("r1", "renamed.txt"),
|
for fid, fname in [
|
||||||
("n1", "new1.txt"), ("n2", "new2.txt")]:
|
("s1", "unchanged.txt"),
|
||||||
|
("r1", "renamed.txt"),
|
||||||
|
("n1", "new1.txt"),
|
||||||
|
("n2", "new2.txt"),
|
||||||
|
]:
|
||||||
selected_files_mocks["get_file_results"][fid] = (
|
selected_files_mocks["get_file_results"][fid] = (
|
||||||
_make_file_dict(fid, fname), None,
|
_make_file_dict(fid, fname),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
|
selected_files_mocks["skip_results"]["s1"] = (True, "unchanged")
|
||||||
selected_files_mocks["skip_results"]["r1"] = (True, "File renamed: 'old' \u2192 'renamed.txt'")
|
selected_files_mocks["skip_results"]["r1"] = (
|
||||||
|
True,
|
||||||
|
"File renamed: 'old' \u2192 'renamed.txt'",
|
||||||
|
)
|
||||||
|
|
||||||
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
selected_files_mocks["download_and_index_mock"].return_value = (2, 0)
|
||||||
|
|
||||||
indexed, skipped, errors = await _run_selected(
|
indexed, skipped, errors = await _run_selected(
|
||||||
selected_files_mocks,
|
selected_files_mocks,
|
||||||
[("s1", "unchanged.txt"), ("r1", "renamed.txt"),
|
[
|
||||||
("n1", "new1.txt"), ("n2", "new2.txt")],
|
("s1", "unchanged.txt"),
|
||||||
|
("r1", "renamed.txt"),
|
||||||
|
("n1", "new1.txt"),
|
||||||
|
("n2", "new2.txt"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert indexed == 3 # 1 renamed + 2 batch
|
assert indexed == 3 # 1 renamed + 2 batch
|
||||||
assert skipped == 1 # 1 unchanged
|
assert skipped == 1 # 1 unchanged
|
||||||
assert errors == []
|
assert errors == []
|
||||||
|
|
||||||
mock = selected_files_mocks["download_and_index_mock"]
|
mock = selected_files_mocks["download_and_index_mock"]
|
||||||
mock.assert_called_once()
|
mock.assert_called_once()
|
||||||
call_files = mock.call_args[1].get("files") if "files" in (mock.call_args[1] or {}) else mock.call_args[0][2]
|
call_files = (
|
||||||
|
mock.call_args[1].get("files")
|
||||||
|
if "files" in (mock.call_args[1] or {})
|
||||||
|
else mock.call_args[0][2]
|
||||||
|
)
|
||||||
assert len(call_files) == 2
|
assert len(call_files) == 2
|
||||||
assert {f["id"] for f in call_files} == {"n1", "n2"}
|
assert {f["id"] for f in call_files} == {"n1", "n2"}
|
||||||
|
|
||||||
|
|
@ -593,6 +655,7 @@ async def test_selected_files_skip_rename_counting(selected_files_mocks):
|
||||||
# asyncio.to_thread verification — prove blocking calls run in parallel
|
# asyncio.to_thread verification — prove blocking calls run in parallel
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def test_client_download_file_runs_in_thread_parallel():
|
async def test_client_download_file_runs_in_thread_parallel():
|
||||||
"""Calling download_file concurrently via asyncio.gather should overlap
|
"""Calling download_file concurrently via asyncio.gather should overlap
|
||||||
blocking work on separate threads, proving to_thread is effective.
|
blocking work on separate threads, proving to_thread is effective.
|
||||||
|
|
@ -602,11 +665,11 @@ async def test_client_download_file_runs_in_thread_parallel():
|
||||||
"""
|
"""
|
||||||
from app.connectors.google_drive.client import GoogleDriveClient
|
from app.connectors.google_drive.client import GoogleDriveClient
|
||||||
|
|
||||||
BLOCK_SECONDS = 0.2
|
block_seconds = 0.2
|
||||||
NUM_CALLS = 3
|
num_calls = 3
|
||||||
|
|
||||||
def _blocking_download(service, file_id, credentials):
|
def _blocking_download(service, file_id, credentials):
|
||||||
time.sleep(BLOCK_SECONDS)
|
time.sleep(block_seconds)
|
||||||
return b"fake-content", None
|
return b"fake-content", None
|
||||||
|
|
||||||
client = GoogleDriveClient.__new__(GoogleDriveClient)
|
client = GoogleDriveClient.__new__(GoogleDriveClient)
|
||||||
|
|
@ -615,11 +678,13 @@ async def test_client_download_file_runs_in_thread_parallel():
|
||||||
client._service_lock = asyncio.Lock()
|
client._service_lock = asyncio.Lock()
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
GoogleDriveClient, "_sync_download_file", staticmethod(_blocking_download),
|
GoogleDriveClient,
|
||||||
|
"_sync_download_file",
|
||||||
|
staticmethod(_blocking_download),
|
||||||
):
|
):
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*(client.download_file(f"file-{i}") for i in range(NUM_CALLS))
|
*(client.download_file(f"file-{i}") for i in range(num_calls))
|
||||||
)
|
)
|
||||||
elapsed = time.monotonic() - start
|
elapsed = time.monotonic() - start
|
||||||
|
|
||||||
|
|
@ -627,7 +692,7 @@ async def test_client_download_file_runs_in_thread_parallel():
|
||||||
assert content == b"fake-content"
|
assert content == b"fake-content"
|
||||||
assert error is None
|
assert error is None
|
||||||
|
|
||||||
serial_minimum = BLOCK_SECONDS * NUM_CALLS
|
serial_minimum = block_seconds * num_calls
|
||||||
assert elapsed < serial_minimum, (
|
assert elapsed < serial_minimum, (
|
||||||
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
|
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
|
||||||
f"downloads are not running in parallel"
|
f"downloads are not running in parallel"
|
||||||
|
|
@ -638,11 +703,11 @@ async def test_client_export_google_file_runs_in_thread_parallel():
|
||||||
"""Same strategy for export_google_file — verify to_thread parallelism."""
|
"""Same strategy for export_google_file — verify to_thread parallelism."""
|
||||||
from app.connectors.google_drive.client import GoogleDriveClient
|
from app.connectors.google_drive.client import GoogleDriveClient
|
||||||
|
|
||||||
BLOCK_SECONDS = 0.2
|
block_seconds = 0.2
|
||||||
NUM_CALLS = 3
|
num_calls = 3
|
||||||
|
|
||||||
def _blocking_export(service, file_id, mime_type, credentials):
|
def _blocking_export(service, file_id, mime_type, credentials):
|
||||||
time.sleep(BLOCK_SECONDS)
|
time.sleep(block_seconds)
|
||||||
return b"exported", None
|
return b"exported", None
|
||||||
|
|
||||||
client = GoogleDriveClient.__new__(GoogleDriveClient)
|
client = GoogleDriveClient.__new__(GoogleDriveClient)
|
||||||
|
|
@ -651,12 +716,16 @@ async def test_client_export_google_file_runs_in_thread_parallel():
|
||||||
client._service_lock = asyncio.Lock()
|
client._service_lock = asyncio.Lock()
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
GoogleDriveClient, "_sync_export_google_file", staticmethod(_blocking_export),
|
GoogleDriveClient,
|
||||||
|
"_sync_export_google_file",
|
||||||
|
staticmethod(_blocking_export),
|
||||||
):
|
):
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*(client.export_google_file(f"file-{i}", "application/pdf")
|
*(
|
||||||
for i in range(NUM_CALLS))
|
client.export_google_file(f"file-{i}", "application/pdf")
|
||||||
|
for i in range(num_calls)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
elapsed = time.monotonic() - start
|
elapsed = time.monotonic() - start
|
||||||
|
|
||||||
|
|
@ -664,7 +733,7 @@ async def test_client_export_google_file_runs_in_thread_parallel():
|
||||||
assert content == b"exported"
|
assert content == b"exported"
|
||||||
assert error is None
|
assert error is None
|
||||||
|
|
||||||
serial_minimum = BLOCK_SECONDS * NUM_CALLS
|
serial_minimum = block_seconds * num_calls
|
||||||
assert elapsed < serial_minimum, (
|
assert elapsed < serial_minimum, (
|
||||||
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
|
f"Elapsed {elapsed:.2f}s >= serial minimum {serial_minimum:.2f}s — "
|
||||||
f"exports are not running in parallel"
|
f"exports are not running in parallel"
|
||||||
|
|
|
||||||
|
|
@ -145,22 +145,32 @@ def jira_mocks(monkeypatch):
|
||||||
|
|
||||||
mock_connector = _mock_connector()
|
mock_connector = _mock_connector()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
|
_mod,
|
||||||
|
"get_connector_by_id",
|
||||||
|
AsyncMock(return_value=mock_connector),
|
||||||
)
|
)
|
||||||
|
|
||||||
jira_client = _mock_jira_client(issues=[_make_issue()])
|
jira_client = _mock_jira_client(issues=[_make_issue()])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "JiraHistoryConnector", MagicMock(return_value=jira_client),
|
_mod,
|
||||||
|
"JiraHistoryConnector",
|
||||||
|
MagicMock(return_value=jira_client),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
|
_mod,
|
||||||
|
"check_duplicate_document_by_hash",
|
||||||
|
AsyncMock(return_value=None),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "update_connector_last_indexed", AsyncMock(),
|
_mod,
|
||||||
|
"update_connector_last_indexed",
|
||||||
|
AsyncMock(),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
_mod,
|
||||||
|
"calculate_date_range",
|
||||||
|
MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_task_logger = MagicMock()
|
mock_task_logger = MagicMock()
|
||||||
|
|
@ -169,15 +179,20 @@ def jira_mocks(monkeypatch):
|
||||||
mock_task_logger.log_task_success = AsyncMock()
|
mock_task_logger.log_task_success = AsyncMock()
|
||||||
mock_task_logger.log_task_failure = AsyncMock()
|
mock_task_logger.log_task_failure = AsyncMock()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
|
_mod,
|
||||||
|
"TaskLoggingService",
|
||||||
|
MagicMock(return_value=mock_task_logger),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_mock = AsyncMock(return_value=([], 1, 0))
|
batch_mock = AsyncMock(return_value=([], 1, 0))
|
||||||
pipeline_mock = MagicMock()
|
pipeline_mock = MagicMock()
|
||||||
pipeline_mock.index_batch_parallel = batch_mock
|
pipeline_mock.index_batch_parallel = batch_mock
|
||||||
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
||||||
|
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
_mod,
|
||||||
|
"IndexingPipelineService",
|
||||||
|
MagicMock(return_value=pipeline_mock),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -128,13 +128,17 @@ def _mock_linear_client(issues=None, error=None):
|
||||||
client.get_issues_by_date_range = AsyncMock(
|
client.get_issues_by_date_range = AsyncMock(
|
||||||
return_value=(issues if issues is not None else [], error),
|
return_value=(issues if issues is not None else [], error),
|
||||||
)
|
)
|
||||||
client.format_issue = MagicMock(side_effect=lambda i: _make_formatted_issue(
|
client.format_issue = MagicMock(
|
||||||
issue_id=i.get("id", ""),
|
side_effect=lambda i: _make_formatted_issue(
|
||||||
identifier=i.get("identifier", ""),
|
issue_id=i.get("id", ""),
|
||||||
title=i.get("title", ""),
|
identifier=i.get("identifier", ""),
|
||||||
))
|
title=i.get("title", ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
client.format_issue_to_markdown = MagicMock(
|
client.format_issue_to_markdown = MagicMock(
|
||||||
side_effect=lambda fi: f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
|
side_effect=lambda fi: (
|
||||||
|
f"# {fi.get('identifier', '')}: {fi.get('title', '')}\n\nContent"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
@ -147,24 +151,34 @@ def linear_mocks(monkeypatch):
|
||||||
|
|
||||||
mock_connector = _mock_connector()
|
mock_connector = _mock_connector()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
|
_mod,
|
||||||
|
"get_connector_by_id",
|
||||||
|
AsyncMock(return_value=mock_connector),
|
||||||
)
|
)
|
||||||
|
|
||||||
linear_client = _mock_linear_client(issues=[_make_issue()])
|
linear_client = _mock_linear_client(issues=[_make_issue()])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "LinearConnector", MagicMock(return_value=linear_client),
|
_mod,
|
||||||
|
"LinearConnector",
|
||||||
|
MagicMock(return_value=linear_client),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
|
_mod,
|
||||||
|
"check_duplicate_document_by_hash",
|
||||||
|
AsyncMock(return_value=None),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "update_connector_last_indexed", AsyncMock(),
|
_mod,
|
||||||
|
"update_connector_last_indexed",
|
||||||
|
AsyncMock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
_mod,
|
||||||
|
"calculate_date_range",
|
||||||
|
MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_task_logger = MagicMock()
|
mock_task_logger = MagicMock()
|
||||||
|
|
@ -173,15 +187,20 @@ def linear_mocks(monkeypatch):
|
||||||
mock_task_logger.log_task_success = AsyncMock()
|
mock_task_logger.log_task_success = AsyncMock()
|
||||||
mock_task_logger.log_task_failure = AsyncMock()
|
mock_task_logger.log_task_failure = AsyncMock()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
|
_mod,
|
||||||
|
"TaskLoggingService",
|
||||||
|
MagicMock(return_value=mock_task_logger),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_mock = AsyncMock(return_value=([], 1, 0))
|
batch_mock = AsyncMock(return_value=([], 1, 0))
|
||||||
pipeline_mock = MagicMock()
|
pipeline_mock = MagicMock()
|
||||||
pipeline_mock.index_batch_parallel = batch_mock
|
pipeline_mock.index_batch_parallel = batch_mock
|
||||||
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
||||||
|
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
_mod,
|
||||||
|
"IndexingPipelineService",
|
||||||
|
MagicMock(return_value=pipeline_mock),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -255,7 +274,7 @@ async def test_issues_with_missing_id_are_skipped(linear_mocks):
|
||||||
]
|
]
|
||||||
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
|
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
|
||||||
|
|
||||||
indexed, skipped, _ = await _run_index(linear_mocks)
|
_indexed, skipped, _ = await _run_index(linear_mocks)
|
||||||
|
|
||||||
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
||||||
assert len(connector_docs) == 1
|
assert len(connector_docs) == 1
|
||||||
|
|
@ -271,7 +290,7 @@ async def test_issues_with_missing_title_are_skipped(linear_mocks):
|
||||||
]
|
]
|
||||||
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
|
linear_mocks["linear_client"].get_issues_by_date_range.return_value = (issues, None)
|
||||||
|
|
||||||
indexed, skipped, _ = await _run_index(linear_mocks)
|
_indexed, skipped, _ = await _run_index(linear_mocks)
|
||||||
|
|
||||||
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
||||||
assert len(connector_docs) == 1
|
assert len(connector_docs) == 1
|
||||||
|
|
@ -305,7 +324,7 @@ async def test_duplicate_content_issues_are_skipped(linear_mocks, monkeypatch):
|
||||||
|
|
||||||
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
|
monkeypatch.setattr(_mod, "check_duplicate_document_by_hash", _check_dup)
|
||||||
|
|
||||||
indexed, skipped, _ = await _run_index(linear_mocks)
|
_indexed, skipped, _ = await _run_index(linear_mocks)
|
||||||
|
|
||||||
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
connector_docs = linear_mocks["batch_mock"].call_args[0][0]
|
||||||
assert len(connector_docs) == 1
|
assert len(connector_docs) == 1
|
||||||
|
|
|
||||||
|
|
@ -107,28 +107,40 @@ def notion_mocks(monkeypatch):
|
||||||
|
|
||||||
mock_connector = _mock_connector()
|
mock_connector = _mock_connector()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "get_connector_by_id", AsyncMock(return_value=mock_connector),
|
_mod,
|
||||||
|
"get_connector_by_id",
|
||||||
|
AsyncMock(return_value=mock_connector),
|
||||||
)
|
)
|
||||||
|
|
||||||
notion_client = _mock_notion_client(pages=[_make_page()])
|
notion_client = _mock_notion_client(pages=[_make_page()])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "NotionHistoryConnector", MagicMock(return_value=notion_client),
|
_mod,
|
||||||
|
"NotionHistoryConnector",
|
||||||
|
MagicMock(return_value=notion_client),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "check_duplicate_document_by_hash", AsyncMock(return_value=None),
|
_mod,
|
||||||
|
"check_duplicate_document_by_hash",
|
||||||
|
AsyncMock(return_value=None),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "update_connector_last_indexed", AsyncMock(),
|
_mod,
|
||||||
|
"update_connector_last_indexed",
|
||||||
|
AsyncMock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "calculate_date_range", MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
_mod,
|
||||||
|
"calculate_date_range",
|
||||||
|
MagicMock(return_value=("2025-01-01", "2025-12-31")),
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "process_blocks", MagicMock(return_value="Converted markdown content"),
|
_mod,
|
||||||
|
"process_blocks",
|
||||||
|
MagicMock(return_value="Converted markdown content"),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_task_logger = MagicMock()
|
mock_task_logger = MagicMock()
|
||||||
|
|
@ -137,15 +149,20 @@ def notion_mocks(monkeypatch):
|
||||||
mock_task_logger.log_task_success = AsyncMock()
|
mock_task_logger.log_task_success = AsyncMock()
|
||||||
mock_task_logger.log_task_failure = AsyncMock()
|
mock_task_logger.log_task_failure = AsyncMock()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "TaskLoggingService", MagicMock(return_value=mock_task_logger),
|
_mod,
|
||||||
|
"TaskLoggingService",
|
||||||
|
MagicMock(return_value=mock_task_logger),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch_mock = AsyncMock(return_value=([], 1, 0))
|
batch_mock = AsyncMock(return_value=([], 1, 0))
|
||||||
pipeline_mock = MagicMock()
|
pipeline_mock = MagicMock()
|
||||||
pipeline_mock.index_batch_parallel = batch_mock
|
pipeline_mock.index_batch_parallel = batch_mock
|
||||||
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
pipeline_mock.migrate_legacy_docs = AsyncMock()
|
||||||
|
pipeline_mock.create_placeholder_documents = AsyncMock(return_value=0)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
_mod, "IndexingPipelineService", MagicMock(return_value=pipeline_mock),
|
_mod,
|
||||||
|
"IndexingPipelineService",
|
||||||
|
MagicMock(return_value=pipeline_mock),
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -216,7 +233,10 @@ async def test_pages_with_missing_id_are_skipped(notion_mocks, monkeypatch):
|
||||||
"""Pages without page_id are skipped and not passed to the pipeline."""
|
"""Pages without page_id are skipped and not passed to the pipeline."""
|
||||||
pages = [
|
pages = [
|
||||||
_make_page(page_id="valid-1"),
|
_make_page(page_id="valid-1"),
|
||||||
{"title": "No ID page", "content": [{"type": "paragraph", "content": "text", "children": []}]},
|
{
|
||||||
|
"title": "No ID page",
|
||||||
|
"content": [{"type": "paragraph", "content": "text", "children": []}],
|
||||||
|
},
|
||||||
]
|
]
|
||||||
notion_mocks["notion_client"].get_all_pages.return_value = pages
|
notion_mocks["notion_client"].get_all_pages.return_value = pages
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,131 @@
|
||||||
|
"""Unit tests for IndexingPipelineService.create_placeholder_documents."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
from app.db import DocumentStatus, DocumentType
|
||||||
|
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||||
|
from app.indexing_pipeline.indexing_pipeline_service import (
|
||||||
|
IndexingPipelineService,
|
||||||
|
PlaceholderInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_placeholder(**overrides) -> PlaceholderInfo:
|
||||||
|
defaults = {
|
||||||
|
"title": "Test Doc",
|
||||||
|
"document_type": DocumentType.GOOGLE_DRIVE_FILE,
|
||||||
|
"unique_id": "file-001",
|
||||||
|
"search_space_id": 1,
|
||||||
|
"connector_id": 42,
|
||||||
|
"created_by_id": "00000000-0000-0000-0000-000000000001",
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return PlaceholderInfo(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _uid_hash(p: PlaceholderInfo) -> str:
|
||||||
|
return compute_identifier_hash(
|
||||||
|
p.document_type.value, p.unique_id, p.search_space_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _session_with_existing_hashes(existing: set[str] | None = None):
|
||||||
|
"""Build an AsyncMock session whose batch-query returns *existing* hashes."""
|
||||||
|
session = AsyncMock()
|
||||||
|
result = MagicMock()
|
||||||
|
result.scalars.return_value.all.return_value = list(existing or [])
|
||||||
|
session.execute = AsyncMock(return_value=result)
|
||||||
|
session.add = MagicMock()
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_empty_input_returns_zero_without_db_calls():
|
||||||
|
session = AsyncMock()
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
|
||||||
|
result = await pipeline.create_placeholder_documents([])
|
||||||
|
|
||||||
|
assert result == 0
|
||||||
|
session.execute.assert_not_awaited()
|
||||||
|
session.commit.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_creates_documents_with_pending_status_and_commits():
|
||||||
|
session = _session_with_existing_hashes(set())
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
p = _make_placeholder(title="My File", unique_id="file-abc")
|
||||||
|
|
||||||
|
result = await pipeline.create_placeholder_documents([p])
|
||||||
|
|
||||||
|
assert result == 1
|
||||||
|
session.add.assert_called_once()
|
||||||
|
|
||||||
|
doc = session.add.call_args[0][0]
|
||||||
|
assert doc.title == "My File"
|
||||||
|
assert doc.document_type == DocumentType.GOOGLE_DRIVE_FILE
|
||||||
|
assert doc.content == "Pending..."
|
||||||
|
assert DocumentStatus.is_state(doc.status, DocumentStatus.PENDING)
|
||||||
|
assert doc.search_space_id == 1
|
||||||
|
assert doc.connector_id == 42
|
||||||
|
|
||||||
|
session.commit.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_existing_documents_are_skipped():
|
||||||
|
"""Placeholders whose unique_identifier_hash already exists are not re-created."""
|
||||||
|
existing_p = _make_placeholder(unique_id="already-there")
|
||||||
|
new_p = _make_placeholder(unique_id="brand-new")
|
||||||
|
|
||||||
|
existing_hash = _uid_hash(existing_p)
|
||||||
|
session = _session_with_existing_hashes({existing_hash})
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
|
||||||
|
result = await pipeline.create_placeholder_documents([existing_p, new_p])
|
||||||
|
|
||||||
|
assert result == 1
|
||||||
|
doc = session.add.call_args[0][0]
|
||||||
|
assert doc.unique_identifier_hash == _uid_hash(new_p)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_duplicate_unique_ids_within_input_are_deduped():
|
||||||
|
"""Same unique_id passed twice only produces one placeholder."""
|
||||||
|
p1 = _make_placeholder(unique_id="dup-id", title="First")
|
||||||
|
p2 = _make_placeholder(unique_id="dup-id", title="Second")
|
||||||
|
|
||||||
|
session = _session_with_existing_hashes(set())
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
|
||||||
|
result = await pipeline.create_placeholder_documents([p1, p2])
|
||||||
|
|
||||||
|
assert result == 1
|
||||||
|
session.add.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_integrity_error_on_commit_returns_zero():
|
||||||
|
"""IntegrityError during commit (race condition) is swallowed gracefully."""
|
||||||
|
session = _session_with_existing_hashes(set())
|
||||||
|
session.commit = AsyncMock(side_effect=IntegrityError("dup", {}, None))
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
p = _make_placeholder()
|
||||||
|
|
||||||
|
result = await pipeline.create_placeholder_documents([p])
|
||||||
|
|
||||||
|
assert result == 0
|
||||||
|
session.rollback.assert_awaited_once()
|
||||||
|
|
@ -19,9 +19,7 @@ def pipeline(mock_session):
|
||||||
return IndexingPipelineService(mock_session)
|
return IndexingPipelineService(mock_session)
|
||||||
|
|
||||||
|
|
||||||
async def test_calls_prepare_then_index_per_document(
|
async def test_calls_prepare_then_index_per_document(pipeline, make_connector_document):
|
||||||
pipeline, make_connector_document
|
|
||||||
):
|
|
||||||
"""index_batch calls prepare_for_indexing, then index() for each returned doc."""
|
"""index_batch calls prepare_for_indexing, then index() for each returned doc."""
|
||||||
doc1 = make_connector_document(
|
doc1 = make_connector_document(
|
||||||
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
document_type=DocumentType.GOOGLE_GMAIL_CONNECTOR,
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -57,7 +57,9 @@ async def test_index_calls_embed_and_chunk_via_to_thread(
|
||||||
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
|
"app.indexing_pipeline.indexing_pipeline_service.chunk_text",
|
||||||
mock_chunk,
|
mock_chunk,
|
||||||
)
|
)
|
||||||
mock_embed = MagicMock(side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts])
|
mock_embed = MagicMock(
|
||||||
|
side_effect=lambda texts: [[0.1] * _EMBEDDING_DIM for _ in texts]
|
||||||
|
)
|
||||||
mock_embed.__name__ = "embed_texts"
|
mock_embed.__name__ = "embed_texts"
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",
|
"app.indexing_pipeline.indexing_pipeline_service.embed_texts",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
"""Unit tests for the duplicate-content safety logic in prepare_for_indexing.
|
||||||
|
|
||||||
|
Verifies that when an existing document's updated content matches another
|
||||||
|
document's content_hash, the system marks it as failed (for placeholders)
|
||||||
|
or leaves it untouched (for ready documents) — never deletes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.db import Document, DocumentStatus, DocumentType
|
||||||
|
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||||
|
from app.indexing_pipeline.document_hashing import (
|
||||||
|
compute_unique_identifier_hash,
|
||||||
|
)
|
||||||
|
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_connector_doc(**overrides) -> ConnectorDocument:
|
||||||
|
defaults = {
|
||||||
|
"title": "Test Doc",
|
||||||
|
"source_markdown": "## Some new content",
|
||||||
|
"unique_id": "file-001",
|
||||||
|
"document_type": DocumentType.GOOGLE_DRIVE_FILE,
|
||||||
|
"search_space_id": 1,
|
||||||
|
"connector_id": 42,
|
||||||
|
"created_by_id": "00000000-0000-0000-0000-000000000001",
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return ConnectorDocument(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_existing_doc(connector_doc: ConnectorDocument, *, status: dict) -> MagicMock:
|
||||||
|
"""Build a MagicMock that looks like an ORM Document with given status."""
|
||||||
|
doc = MagicMock(spec=Document)
|
||||||
|
doc.id = 999
|
||||||
|
doc.unique_identifier_hash = compute_unique_identifier_hash(connector_doc)
|
||||||
|
doc.content_hash = "old-placeholder-content-hash"
|
||||||
|
doc.title = connector_doc.title
|
||||||
|
doc.status = status
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_session_for_dedup(existing_doc, *, has_duplicate: bool):
|
||||||
|
"""Build a session whose sequential execute() calls return:
|
||||||
|
|
||||||
|
1. The *existing_doc* for the unique_identifier_hash lookup.
|
||||||
|
2. A row (or None) for the duplicate content_hash check.
|
||||||
|
"""
|
||||||
|
session = AsyncMock()
|
||||||
|
|
||||||
|
existing_result = MagicMock()
|
||||||
|
existing_result.scalars.return_value.first.return_value = existing_doc
|
||||||
|
|
||||||
|
dup_result = MagicMock()
|
||||||
|
dup_result.scalars.return_value.first.return_value = 42 if has_duplicate else None
|
||||||
|
|
||||||
|
session.execute = AsyncMock(side_effect=[existing_result, dup_result])
|
||||||
|
session.add = MagicMock()
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pending_placeholder_with_duplicate_content_is_marked_failed():
|
||||||
|
"""A placeholder (pending) whose updated content duplicates another doc
|
||||||
|
must be marked as FAILED — never deleted."""
|
||||||
|
cdoc = _make_connector_doc(source_markdown="## Shared content")
|
||||||
|
existing = _make_existing_doc(cdoc, status=DocumentStatus.pending())
|
||||||
|
|
||||||
|
session = _mock_session_for_dedup(existing, has_duplicate=True)
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
|
||||||
|
results = await pipeline.prepare_for_indexing([cdoc])
|
||||||
|
|
||||||
|
assert results == [], "duplicate should not be returned for indexing"
|
||||||
|
|
||||||
|
assert DocumentStatus.is_state(existing.status, DocumentStatus.FAILED)
|
||||||
|
assert "Duplicate content" in existing.status.get("reason", "")
|
||||||
|
session.delete.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ready_document_with_duplicate_content_is_left_untouched():
|
||||||
|
"""A READY document whose updated content duplicates another doc
|
||||||
|
must be left completely untouched — not failed, not deleted."""
|
||||||
|
cdoc = _make_connector_doc(source_markdown="## Shared content")
|
||||||
|
existing = _make_existing_doc(cdoc, status=DocumentStatus.ready())
|
||||||
|
|
||||||
|
session = _mock_session_for_dedup(existing, has_duplicate=True)
|
||||||
|
pipeline = IndexingPipelineService(session)
|
||||||
|
|
||||||
|
results = await pipeline.prepare_for_indexing([cdoc])
|
||||||
|
|
||||||
|
assert results == [], "duplicate should not be returned for indexing"
|
||||||
|
|
||||||
|
assert DocumentStatus.is_state(existing.status, DocumentStatus.READY)
|
||||||
|
session.delete.assert_not_called()
|
||||||
133
surfsense_backend/tests/unit/middleware/test_knowledge_search.py
Normal file
133
surfsense_backend/tests/unit/middleware/test_knowledge_search.py
Normal file
|
|
@ -0,0 +1,133 @@
|
||||||
|
"""Unit tests for knowledge_search middleware helpers.
|
||||||
|
|
||||||
|
These test pure functions that don't require a database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.new_chat.middleware.knowledge_search import (
|
||||||
|
_build_document_xml,
|
||||||
|
_resolve_search_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
# ── _resolve_search_types ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveSearchTypes:
|
||||||
|
def test_returns_none_when_no_inputs(self):
|
||||||
|
assert _resolve_search_types(None, None) is None
|
||||||
|
|
||||||
|
def test_returns_none_when_both_empty(self):
|
||||||
|
assert _resolve_search_types([], []) is None
|
||||||
|
|
||||||
|
def test_includes_legacy_type_for_google_gmail(self):
|
||||||
|
result = _resolve_search_types(["GOOGLE_GMAIL_CONNECTOR"], None)
|
||||||
|
assert "GOOGLE_GMAIL_CONNECTOR" in result
|
||||||
|
assert "COMPOSIO_GMAIL_CONNECTOR" in result
|
||||||
|
|
||||||
|
def test_includes_legacy_type_for_google_drive(self):
|
||||||
|
result = _resolve_search_types(None, ["GOOGLE_DRIVE_FILE"])
|
||||||
|
assert "GOOGLE_DRIVE_FILE" in result
|
||||||
|
assert "COMPOSIO_GOOGLE_DRIVE_CONNECTOR" in result
|
||||||
|
|
||||||
|
def test_includes_legacy_type_for_google_calendar(self):
|
||||||
|
result = _resolve_search_types(["GOOGLE_CALENDAR_CONNECTOR"], None)
|
||||||
|
assert "GOOGLE_CALENDAR_CONNECTOR" in result
|
||||||
|
assert "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR" in result
|
||||||
|
|
||||||
|
def test_no_legacy_expansion_for_unrelated_types(self):
|
||||||
|
result = _resolve_search_types(["FILE", "NOTE"], None)
|
||||||
|
assert set(result) == {"FILE", "NOTE"}
|
||||||
|
|
||||||
|
def test_combines_connectors_and_document_types(self):
|
||||||
|
result = _resolve_search_types(["FILE"], ["NOTE", "CRAWLED_URL"])
|
||||||
|
assert {"FILE", "NOTE", "CRAWLED_URL"}.issubset(set(result))
|
||||||
|
|
||||||
|
def test_deduplicates(self):
|
||||||
|
result = _resolve_search_types(["FILE", "FILE"], ["FILE"])
|
||||||
|
assert result.count("FILE") == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── _build_document_xml ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildDocumentXml:
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_document(self):
|
||||||
|
return {
|
||||||
|
"document_id": 42,
|
||||||
|
"document": {
|
||||||
|
"id": 42,
|
||||||
|
"document_type": "FILE",
|
||||||
|
"title": "Test Doc",
|
||||||
|
"metadata": {"url": "https://example.com"},
|
||||||
|
},
|
||||||
|
"chunks": [
|
||||||
|
{"chunk_id": 101, "content": "First chunk content"},
|
||||||
|
{"chunk_id": 102, "content": "Second chunk content"},
|
||||||
|
{"chunk_id": 103, "content": "Third chunk content"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_contains_document_metadata(self, sample_document):
|
||||||
|
xml = _build_document_xml(sample_document)
|
||||||
|
assert "<document_id>42</document_id>" in xml
|
||||||
|
assert "<document_type>FILE</document_type>" in xml
|
||||||
|
assert "Test Doc" in xml
|
||||||
|
|
||||||
|
def test_contains_chunk_index(self, sample_document):
|
||||||
|
xml = _build_document_xml(sample_document)
|
||||||
|
assert "<chunk_index>" in xml
|
||||||
|
assert "</chunk_index>" in xml
|
||||||
|
assert 'chunk_id="101"' in xml
|
||||||
|
assert 'chunk_id="102"' in xml
|
||||||
|
assert 'chunk_id="103"' in xml
|
||||||
|
|
||||||
|
def test_matched_chunks_flagged_in_index(self, sample_document):
|
||||||
|
xml = _build_document_xml(sample_document, matched_chunk_ids={101, 103})
|
||||||
|
lines = xml.split("\n")
|
||||||
|
for line in lines:
|
||||||
|
if 'chunk_id="101"' in line:
|
||||||
|
assert 'matched="true"' in line
|
||||||
|
if 'chunk_id="102"' in line:
|
||||||
|
assert 'matched="true"' not in line
|
||||||
|
if 'chunk_id="103"' in line:
|
||||||
|
assert 'matched="true"' in line
|
||||||
|
|
||||||
|
def test_chunk_content_in_document_content_section(self, sample_document):
|
||||||
|
xml = _build_document_xml(sample_document)
|
||||||
|
assert "<document_content>" in xml
|
||||||
|
assert "First chunk content" in xml
|
||||||
|
assert "Second chunk content" in xml
|
||||||
|
assert "Third chunk content" in xml
|
||||||
|
|
||||||
|
def test_line_numbers_in_chunk_index_are_accurate(self, sample_document):
|
||||||
|
"""Verify that the line ranges in chunk_index actually point to the right content."""
|
||||||
|
xml = _build_document_xml(sample_document, matched_chunk_ids={101})
|
||||||
|
xml_lines = xml.split("\n")
|
||||||
|
|
||||||
|
for line in xml_lines:
|
||||||
|
if 'chunk_id="101"' in line and "lines=" in line:
|
||||||
|
import re
|
||||||
|
|
||||||
|
m = re.search(r'lines="(\d+)-(\d+)"', line)
|
||||||
|
assert m, f"No lines= attribute found in: {line}"
|
||||||
|
start, _end = int(m.group(1)), int(m.group(2))
|
||||||
|
target_line = xml_lines[start - 1]
|
||||||
|
assert "101" in target_line
|
||||||
|
assert "First chunk content" in target_line
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pytest.fail("chunk_id=101 entry not found in chunk_index")
|
||||||
|
|
||||||
|
def test_splits_into_lines_correctly(self, sample_document):
|
||||||
|
"""Each chunk occupies exactly one line (no embedded newlines)."""
|
||||||
|
xml = _build_document_xml(sample_document)
|
||||||
|
lines = xml.split("\n")
|
||||||
|
chunk_lines = [
|
||||||
|
line for line in lines if "<![CDATA[" in line and "<chunk" in line
|
||||||
|
]
|
||||||
|
assert len(chunk_lines) == 3
|
||||||
8549
surfsense_backend/uv.lock
generated
8549
surfsense_backend/uv.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -30,6 +30,7 @@ import {
|
||||||
// extractWriteTodosFromContent,
|
// extractWriteTodosFromContent,
|
||||||
} from "@/atoms/chat/plan-state.atom";
|
} from "@/atoms/chat/plan-state.atom";
|
||||||
import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom";
|
import { closeReportPanelAtom } from "@/atoms/chat/report-panel.atom";
|
||||||
|
import { type AgentCreatedDocument, agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms";
|
||||||
import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
|
import { closeEditorPanelAtom } from "@/atoms/editor/editor-panel.atom";
|
||||||
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
import { membersAtom } from "@/atoms/members/members-query.atoms";
|
||||||
import { updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
|
import { updateChatTabTitleAtom } from "@/atoms/tabs/tabs.atom";
|
||||||
|
|
@ -191,6 +192,7 @@ export default function NewChatPage() {
|
||||||
const closeReportPanel = useSetAtom(closeReportPanelAtom);
|
const closeReportPanel = useSetAtom(closeReportPanelAtom);
|
||||||
const closeEditorPanel = useSetAtom(closeEditorPanelAtom);
|
const closeEditorPanel = useSetAtom(closeEditorPanelAtom);
|
||||||
const updateChatTabTitle = useSetAtom(updateChatTabTitleAtom);
|
const updateChatTabTitle = useSetAtom(updateChatTabTitleAtom);
|
||||||
|
const setAgentCreatedDocuments = useSetAtom(agentCreatedDocumentsAtom);
|
||||||
|
|
||||||
// Get current user for author info in shared chats
|
// Get current user for author info in shared chats
|
||||||
const { data: currentUser } = useAtomValue(currentUserAtom);
|
const { data: currentUser } = useAtomValue(currentUserAtom);
|
||||||
|
|
@ -740,6 +742,20 @@ export default function NewChatPage() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "data-documents-updated": {
|
||||||
|
const docEvent = parsed.data as {
|
||||||
|
action: string;
|
||||||
|
document: AgentCreatedDocument;
|
||||||
|
};
|
||||||
|
if (docEvent?.document?.id) {
|
||||||
|
setAgentCreatedDocuments((prev) => {
|
||||||
|
if (prev.some((d) => d.id === docEvent.document.id)) return prev;
|
||||||
|
return [...prev, docEvent.document];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
case "data-interrupt-request": {
|
case "data-interrupt-request": {
|
||||||
wasInterrupted = true;
|
wasInterrupted = true;
|
||||||
const interruptData = parsed.data as Record<string, unknown>;
|
const interruptData = parsed.data as Record<string, unknown>;
|
||||||
|
|
@ -1534,7 +1550,7 @@ export default function NewChatPage() {
|
||||||
// For new chats (urlChatId === 0), threadId being null is expected (lazy creation)
|
// For new chats (urlChatId === 0), threadId being null is expected (lazy creation)
|
||||||
if (!threadId && urlChatId > 0) {
|
if (!threadId && urlChatId > 0) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-[calc(100dvh-64px)] flex-col items-center justify-center gap-4">
|
<div className="flex h-full flex-col items-center justify-center gap-4">
|
||||||
<div className="text-destructive">Failed to load chat</div>
|
<div className="text-destructive">Failed to load chat</div>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
|
|
@ -1553,7 +1569,7 @@ export default function NewChatPage() {
|
||||||
return (
|
return (
|
||||||
<AssistantRuntimeProvider runtime={runtime}>
|
<AssistantRuntimeProvider runtime={runtime}>
|
||||||
<ThinkingStepsDataUI />
|
<ThinkingStepsDataUI />
|
||||||
<div key={searchSpaceId} className="flex h-[calc(100dvh-64px)] overflow-hidden">
|
<div key={searchSpaceId} className="flex h-full overflow-hidden">
|
||||||
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
<div className="flex-1 flex flex-col min-w-0 overflow-hidden">
|
||||||
<Thread />
|
<Thread />
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import { Skeleton } from "@/components/ui/skeleton";
|
||||||
|
|
||||||
export default function Loading() {
|
export default function Loading() {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-[calc(100dvh-64px)] flex-col bg-main-panel px-4">
|
<div className="flex h-full flex-col bg-main-panel px-4">
|
||||||
<div className="mx-auto w-full max-w-[44rem] flex flex-1 flex-col gap-6 py-8">
|
<div className="mx-auto w-full max-w-[44rem] flex flex-1 flex-col gap-6 py-8">
|
||||||
{/* User message */}
|
{/* User message */}
|
||||||
<div className="flex justify-end">
|
<div className="flex justify-end">
|
||||||
|
|
|
||||||
|
|
@ -7,3 +7,14 @@ export const globalDocumentsQueryParamsAtom = atom<GetDocumentsRequest["queryPar
|
||||||
});
|
});
|
||||||
|
|
||||||
export const documentsSidebarOpenAtom = atom(false);
|
export const documentsSidebarOpenAtom = atom(false);
|
||||||
|
|
||||||
|
export interface AgentCreatedDocument {
|
||||||
|
id: number;
|
||||||
|
title: string;
|
||||||
|
documentType: string;
|
||||||
|
searchSpaceId: number;
|
||||||
|
folderId: number | null;
|
||||||
|
createdById: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const agentCreatedDocumentsAtom = atom<AgentCreatedDocument[]>([]);
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import { useAtomValue, useSetAtom } from "jotai";
|
||||||
import { AlertTriangle, Cable, Settings } from "lucide-react";
|
import { AlertTriangle, Cable, Settings } from "lucide-react";
|
||||||
import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react";
|
import { forwardRef, useEffect, useImperativeHandle, useMemo, useState } from "react";
|
||||||
import { createPortal } from "react-dom";
|
import { createPortal } from "react-dom";
|
||||||
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
|
|
||||||
import { statusInboxItemsAtom } from "@/atoms/inbox/status-inbox.atom";
|
import { statusInboxItemsAtom } from "@/atoms/inbox/status-inbox.atom";
|
||||||
import {
|
import {
|
||||||
globalNewLLMConfigsAtom,
|
globalNewLLMConfigsAtom,
|
||||||
|
|
@ -22,6 +21,7 @@ import { Tabs, TabsContent } from "@/components/ui/tabs";
|
||||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||||
import { useConnectorsSync } from "@/hooks/use-connectors-sync";
|
import { useConnectorsSync } from "@/hooks/use-connectors-sync";
|
||||||
import { PICKER_CLOSE_EVENT, PICKER_OPEN_EVENT } from "@/hooks/use-google-picker";
|
import { PICKER_CLOSE_EVENT, PICKER_OPEN_EVENT } from "@/hooks/use-google-picker";
|
||||||
|
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { ConnectorDialogHeader } from "./connector-popup/components/connector-dialog-header";
|
import { ConnectorDialogHeader } from "./connector-popup/components/connector-dialog-header";
|
||||||
import { ConnectorConnectView } from "./connector-popup/connector-configs/views/connector-connect-view";
|
import { ConnectorConnectView } from "./connector-popup/connector-configs/views/connector-connect-view";
|
||||||
|
|
|
||||||
|
|
@ -421,7 +421,9 @@ const defaultComponents = memoizeMarkdownComponents({
|
||||||
<code
|
<code
|
||||||
className={cn("aui-md-inline-code rounded border bg-muted font-semibold", className)}
|
className={cn("aui-md-inline-code rounded border bg-muted font-semibold", className)}
|
||||||
{...props}
|
{...props}
|
||||||
/>
|
>
|
||||||
|
{children}
|
||||||
|
</code>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
|
const language = /language-(\w+)/.exec(className || "")?.[1] ?? "text";
|
||||||
|
|
|
||||||
|
|
@ -109,7 +109,7 @@ const ThreadContent: FC = () => {
|
||||||
>
|
>
|
||||||
<ThreadPrimitive.Viewport
|
<ThreadPrimitive.Viewport
|
||||||
turnAnchor="top"
|
turnAnchor="top"
|
||||||
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-auto px-4 pt-4"
|
className="aui-thread-viewport relative flex flex-1 min-h-0 flex-col overflow-y-scroll px-4 pt-4"
|
||||||
>
|
>
|
||||||
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
<AuiIf condition={({ thread }) => thread.isEmpty}>
|
||||||
<ThreadWelcome />
|
<ThreadWelcome />
|
||||||
|
|
@ -1062,7 +1062,7 @@ interface ToolGroup {
|
||||||
const TOOL_GROUPS: ToolGroup[] = [
|
const TOOL_GROUPS: ToolGroup[] = [
|
||||||
{
|
{
|
||||||
label: "Research",
|
label: "Research",
|
||||||
tools: ["search_knowledge_base", "search_surfsense_docs", "scrape_webpage"],
|
tools: ["search_surfsense_docs", "scrape_webpage"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
label: "Generate",
|
label: "Generate",
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,9 @@ export function CreateFolderDialog({
|
||||||
|
|
||||||
<form onSubmit={handleSubmit} className="flex flex-col gap-3 sm:gap-4">
|
<form onSubmit={handleSubmit} className="flex flex-col gap-3 sm:gap-4">
|
||||||
<div className="flex flex-col gap-2">
|
<div className="flex flex-col gap-2">
|
||||||
<Label htmlFor="folder-name" className="text-sm">Folder name</Label>
|
<Label htmlFor="folder-name" className="text-sm">
|
||||||
|
Folder name
|
||||||
|
</Label>
|
||||||
<Input
|
<Input
|
||||||
ref={inputRef}
|
ref={inputRef}
|
||||||
id="folder-name"
|
id="folder-name"
|
||||||
|
|
@ -91,11 +93,7 @@ export function CreateFolderDialog({
|
||||||
>
|
>
|
||||||
Cancel
|
Cancel
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button type="submit" disabled={!name.trim()} className="h-8 sm:h-9 text-xs sm:text-sm">
|
||||||
type="submit"
|
|
||||||
disabled={!name.trim()}
|
|
||||||
className="h-8 sm:h-9 text-xs sm:text-sm"
|
|
||||||
>
|
|
||||||
Create
|
Create
|
||||||
</Button>
|
</Button>
|
||||||
</DialogFooter>
|
</DialogFooter>
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,15 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { AlertCircle, Clock, Download, Eye, MoreHorizontal, Move, PenLine, Trash2 } from "lucide-react";
|
import {
|
||||||
|
AlertCircle,
|
||||||
|
Clock,
|
||||||
|
Download,
|
||||||
|
Eye,
|
||||||
|
MoreHorizontal,
|
||||||
|
Move,
|
||||||
|
PenLine,
|
||||||
|
Trash2,
|
||||||
|
} from "lucide-react";
|
||||||
import React, { useCallback, useRef, useState } from "react";
|
import React, { useCallback, useRef, useState } from "react";
|
||||||
import { useDrag } from "react-dnd";
|
import { useDrag } from "react-dnd";
|
||||||
import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
|
import { getDocumentTypeIcon } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
|
||||||
|
|
@ -112,14 +121,15 @@ export const DocumentNode = React.memo(function DocumentNode({
|
||||||
return (
|
return (
|
||||||
<ContextMenu onOpenChange={onContextMenuOpenChange}>
|
<ContextMenu onOpenChange={onContextMenuOpenChange}>
|
||||||
<ContextMenuTrigger asChild>
|
<ContextMenuTrigger asChild>
|
||||||
|
{/* biome-ignore lint/a11y/useSemanticElements: contains nested interactive children (Checkbox) that render as <button>, making a semantic <button> wrapper invalid */}
|
||||||
<div
|
<div
|
||||||
role="button"
|
role="button"
|
||||||
tabIndex={0}
|
tabIndex={0}
|
||||||
ref={attachRef}
|
ref={attachRef}
|
||||||
className={cn(
|
className={cn(
|
||||||
"group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left",
|
"group flex h-8 w-full items-center gap-2.5 rounded-md px-1 text-sm hover:bg-accent/50 cursor-pointer select-none text-left",
|
||||||
isMentioned && "bg-accent/30",
|
isMentioned && "bg-accent/30",
|
||||||
isDragging && "opacity-40"
|
isDragging && "opacity-40"
|
||||||
)}
|
)}
|
||||||
style={{ paddingLeft: `${depth * 16 + 4}px` }}
|
style={{ paddingLeft: `${depth * 16 + 4}px` }}
|
||||||
onClick={handleCheckChange}
|
onClick={handleCheckChange}
|
||||||
|
|
@ -130,54 +140,54 @@ export const DocumentNode = React.memo(function DocumentNode({
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{(() => {
|
{(() => {
|
||||||
if (statusState === "pending") {
|
if (statusState === "pending") {
|
||||||
|
return (
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
||||||
|
<Clock className="h-3.5 w-3.5 text-muted-foreground/60" />
|
||||||
|
</span>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent side="top">Pending - waiting to be synced</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (statusState === "processing") {
|
||||||
|
return (
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
||||||
|
<Spinner size="xs" className="text-primary" />
|
||||||
|
</span>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent side="top">Syncing</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (statusState === "failed") {
|
||||||
|
return (
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
||||||
|
<AlertCircle className="h-3.5 w-3.5 text-destructive" />
|
||||||
|
</span>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent side="top" className="max-w-xs">
|
||||||
|
{doc.status?.reason || "Processing failed"}
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
);
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<Tooltip>
|
<Checkbox
|
||||||
<TooltipTrigger asChild>
|
checked={isMentioned}
|
||||||
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
onCheckedChange={handleCheckChange}
|
||||||
<Clock className="h-3.5 w-3.5 text-muted-foreground/60" />
|
onClick={(e) => e.stopPropagation()}
|
||||||
</span>
|
className="h-3.5 w-3.5 shrink-0"
|
||||||
</TooltipTrigger>
|
/>
|
||||||
<TooltipContent side="top">Pending - waiting to be synced</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
);
|
);
|
||||||
}
|
})()}
|
||||||
if (statusState === "processing") {
|
|
||||||
return (
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger asChild>
|
|
||||||
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
|
||||||
<Spinner size="xs" className="text-primary" />
|
|
||||||
</span>
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent side="top">Syncing</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
if (statusState === "failed") {
|
|
||||||
return (
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger asChild>
|
|
||||||
<span className="flex h-3.5 w-3.5 shrink-0 items-center justify-center">
|
|
||||||
<AlertCircle className="h-3.5 w-3.5 text-destructive" />
|
|
||||||
</span>
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent side="top" className="max-w-xs">
|
|
||||||
{doc.status?.reason || "Processing failed"}
|
|
||||||
</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<Checkbox
|
|
||||||
checked={isMentioned}
|
|
||||||
onCheckedChange={handleCheckChange}
|
|
||||||
onClick={(e) => e.stopPropagation()}
|
|
||||||
className="h-3.5 w-3.5 shrink-0"
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
})()}
|
|
||||||
|
|
||||||
<span className="flex-1 min-w-0 truncate">{doc.title}</span>
|
<span className="flex-1 min-w-0 truncate">{doc.title}</span>
|
||||||
|
|
||||||
|
|
@ -188,17 +198,19 @@ export const DocumentNode = React.memo(function DocumentNode({
|
||||||
)}
|
)}
|
||||||
</span>
|
</span>
|
||||||
|
|
||||||
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
|
<DropdownMenu open={dropdownOpen} onOpenChange={setDropdownOpen}>
|
||||||
<DropdownMenuTrigger asChild>
|
<DropdownMenuTrigger asChild>
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
className={cn(
|
className={cn(
|
||||||
"hidden sm:inline-flex h-6 w-6 shrink-0 hover:bg-transparent",
|
"hidden sm:inline-flex h-6 w-6 shrink-0 hover:bg-transparent",
|
||||||
dropdownOpen ? "opacity-100 bg-accent hover:bg-accent" : "opacity-0 group-hover:opacity-100"
|
dropdownOpen
|
||||||
)}
|
? "opacity-100 bg-accent hover:bg-accent"
|
||||||
onClick={(e) => e.stopPropagation()}
|
: "opacity-0 group-hover:opacity-100"
|
||||||
>
|
)}
|
||||||
|
onClick={(e) => e.stopPropagation()}
|
||||||
|
>
|
||||||
<MoreHorizontal className="h-3.5 w-3.5" />
|
<MoreHorizontal className="h-3.5 w-3.5" />
|
||||||
</Button>
|
</Button>
|
||||||
</DropdownMenuTrigger>
|
</DropdownMenuTrigger>
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ import React, { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { useDrag, useDrop } from "react-dnd";
|
import { useDrag, useDrop } from "react-dnd";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Checkbox } from "@/components/ui/checkbox";
|
import { Checkbox } from "@/components/ui/checkbox";
|
||||||
import type { FolderSelectionState } from "./FolderTreeView";
|
|
||||||
import {
|
import {
|
||||||
ContextMenu,
|
ContextMenu,
|
||||||
ContextMenuContent,
|
ContextMenuContent,
|
||||||
|
|
@ -29,6 +28,7 @@ import {
|
||||||
DropdownMenuTrigger,
|
DropdownMenuTrigger,
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
|
import type { FolderSelectionState } from "./FolderTreeView";
|
||||||
|
|
||||||
export const DND_TYPES = {
|
export const DND_TYPES = {
|
||||||
FOLDER: "FOLDER",
|
FOLDER: "FOLDER",
|
||||||
|
|
@ -263,7 +263,9 @@ export const FolderNode = React.memo(function FolderNode({
|
||||||
</span>
|
</span>
|
||||||
|
|
||||||
<Checkbox
|
<Checkbox
|
||||||
checked={selectionState === "all" ? true : selectionState === "some" ? "indeterminate" : false}
|
checked={
|
||||||
|
selectionState === "all" ? true : selectionState === "some" ? "indeterminate" : false
|
||||||
|
}
|
||||||
onCheckedChange={handleCheckChange}
|
onCheckedChange={handleCheckChange}
|
||||||
onClick={(e) => e.stopPropagation()}
|
onClick={(e) => e.stopPropagation()}
|
||||||
className="h-3.5 w-3.5 shrink-0"
|
className="h-3.5 w-3.5 shrink-0"
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ interface FolderTreeViewProps {
|
||||||
onMoveDocument: (doc: DocumentNodeDoc) => void;
|
onMoveDocument: (doc: DocumentNodeDoc) => void;
|
||||||
onExportDocument?: (doc: DocumentNodeDoc, format: string) => void;
|
onExportDocument?: (doc: DocumentNodeDoc, format: string) => void;
|
||||||
activeTypes: DocumentTypeEnum[];
|
activeTypes: DocumentTypeEnum[];
|
||||||
|
searchQuery?: string;
|
||||||
onDropIntoFolder?: (
|
onDropIntoFolder?: (
|
||||||
itemType: "folder" | "document",
|
itemType: "folder" | "document",
|
||||||
itemId: number,
|
itemId: number,
|
||||||
|
|
@ -69,6 +70,7 @@ export function FolderTreeView({
|
||||||
onMoveDocument,
|
onMoveDocument,
|
||||||
onExportDocument,
|
onExportDocument,
|
||||||
activeTypes,
|
activeTypes,
|
||||||
|
searchQuery,
|
||||||
onDropIntoFolder,
|
onDropIntoFolder,
|
||||||
onReorderFolder,
|
onReorderFolder,
|
||||||
}: FolderTreeViewProps) {
|
}: FolderTreeViewProps) {
|
||||||
|
|
@ -97,13 +99,13 @@ export function FolderTreeView({
|
||||||
const handleCancelRename = useCallback(() => setRenamingFolderId(null), [setRenamingFolderId]);
|
const handleCancelRename = useCallback(() => setRenamingFolderId(null), [setRenamingFolderId]);
|
||||||
|
|
||||||
const hasDescendantMatch = useMemo(() => {
|
const hasDescendantMatch = useMemo(() => {
|
||||||
if (activeTypes.length === 0) return null;
|
if (activeTypes.length === 0 && !searchQuery) return null;
|
||||||
const match: Record<number, boolean> = {};
|
const match: Record<number, boolean> = {};
|
||||||
|
|
||||||
function check(folderId: number): boolean {
|
function check(folderId: number): boolean {
|
||||||
if (match[folderId] !== undefined) return match[folderId];
|
if (match[folderId] !== undefined) return match[folderId];
|
||||||
const childDocs = (docsByFolder[folderId] ?? []).some((d) =>
|
const childDocs = (docsByFolder[folderId] ?? []).some(
|
||||||
activeTypes.includes(d.document_type as DocumentTypeEnum)
|
(d) => activeTypes.length === 0 || activeTypes.includes(d.document_type as DocumentTypeEnum)
|
||||||
);
|
);
|
||||||
if (childDocs) {
|
if (childDocs) {
|
||||||
match[folderId] = true;
|
match[folderId] = true;
|
||||||
|
|
@ -124,7 +126,7 @@ export function FolderTreeView({
|
||||||
check(f.id);
|
check(f.id);
|
||||||
}
|
}
|
||||||
return match;
|
return match;
|
||||||
}, [folders, docsByFolder, foldersByParent, activeTypes]);
|
}, [folders, docsByFolder, foldersByParent, activeTypes, searchQuery]);
|
||||||
|
|
||||||
const folderSelectionStates = useMemo(() => {
|
const folderSelectionStates = useMemo(() => {
|
||||||
const states: Record<number, FolderSelectionState> = {};
|
const states: Record<number, FolderSelectionState> = {};
|
||||||
|
|
@ -177,12 +179,15 @@ export function FolderTreeView({
|
||||||
after: i < visibleFolders.length - 1 ? visibleFolders[i + 1].position : null,
|
after: i < visibleFolders.length - 1 ? visibleFolders[i + 1].position : null,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const isAutoExpanded = !!searchQuery && !!hasDescendantMatch?.[f.id];
|
||||||
|
const isExpanded = expandedIds.has(f.id) || isAutoExpanded;
|
||||||
|
|
||||||
nodes.push(
|
nodes.push(
|
||||||
<FolderNode
|
<FolderNode
|
||||||
key={`folder-${f.id}`}
|
key={`folder-${f.id}`}
|
||||||
folder={f}
|
folder={f}
|
||||||
depth={depth}
|
depth={depth}
|
||||||
isExpanded={expandedIds.has(f.id)}
|
isExpanded={isExpanded}
|
||||||
isRenaming={renamingFolderId === f.id}
|
isRenaming={renamingFolderId === f.id}
|
||||||
childCount={folderChildCounts[f.id] ?? 0}
|
childCount={folderChildCounts[f.id] ?? 0}
|
||||||
selectionState={folderSelectionStates[f.id] ?? "none"}
|
selectionState={folderSelectionStates[f.id] ?? "none"}
|
||||||
|
|
@ -202,7 +207,7 @@ export function FolderTreeView({
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
|
||||||
if (expandedIds.has(f.id)) {
|
if (isExpanded) {
|
||||||
nodes.push(...renderLevel(f.id, depth + 1));
|
nodes.push(...renderLevel(f.id, depth + 1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -240,7 +245,7 @@ export function FolderTreeView({
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (treeNodes.length === 0 && activeTypes.length > 0) {
|
if (treeNodes.length === 0 && (activeTypes.length > 0 || searchQuery)) {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
|
<div className="flex flex-1 flex-col items-center justify-center gap-3 px-4 py-12 text-muted-foreground">
|
||||||
<CirclePlus className="h-10 w-10 rotate-45" />
|
<CirclePlus className="h-10 w-10 rotate-45" />
|
||||||
|
|
|
||||||
|
|
@ -2,42 +2,50 @@
|
||||||
|
|
||||||
import { useQuery } from "@rocicorp/zero/react";
|
import { useQuery } from "@rocicorp/zero/react";
|
||||||
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
import { useAtom, useAtomValue, useSetAtom } from "jotai";
|
||||||
import { ChevronLeft, ChevronRight, Unplug } from "lucide-react";
|
import { ChevronLeft, ChevronRight, Trash2, Unplug } from "lucide-react";
|
||||||
import { useParams } from "next/navigation";
|
import { useParams } from "next/navigation";
|
||||||
import { useTranslations } from "next-intl";
|
import { useTranslations } from "next-intl";
|
||||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
|
|
||||||
import { DocumentsFilters } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsFilters";
|
import { DocumentsFilters } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsFilters";
|
||||||
import {
|
|
||||||
DocumentsTableShell,
|
|
||||||
type SortKey,
|
|
||||||
} from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell";
|
|
||||||
import { sidebarSelectedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
|
import { sidebarSelectedDocumentsAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||||
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
|
import { connectorDialogOpenAtom } from "@/atoms/connector-dialog/connector-dialog.atoms";
|
||||||
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
||||||
import { deleteDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms";
|
import { deleteDocumentMutationAtom } from "@/atoms/documents/document-mutation.atoms";
|
||||||
import { expandedFolderIdsAtom } from "@/atoms/documents/folder.atoms";
|
import { expandedFolderIdsAtom } from "@/atoms/documents/folder.atoms";
|
||||||
|
import { agentCreatedDocumentsAtom } from "@/atoms/documents/ui.atoms";
|
||||||
import { openDocumentTabAtom } from "@/atoms/tabs/tabs.atom";
|
import { openDocumentTabAtom } from "@/atoms/tabs/tabs.atom";
|
||||||
import { CreateFolderDialog } from "@/components/documents/CreateFolderDialog";
|
import { CreateFolderDialog } from "@/components/documents/CreateFolderDialog";
|
||||||
import type { DocumentNodeDoc } from "@/components/documents/DocumentNode";
|
import type { DocumentNodeDoc } from "@/components/documents/DocumentNode";
|
||||||
import type { FolderDisplay } from "@/components/documents/FolderNode";
|
import type { FolderDisplay } from "@/components/documents/FolderNode";
|
||||||
import { FolderPickerDialog } from "@/components/documents/FolderPickerDialog";
|
import { FolderPickerDialog } from "@/components/documents/FolderPickerDialog";
|
||||||
import { FolderTreeView } from "@/components/documents/FolderTreeView";
|
import { FolderTreeView } from "@/components/documents/FolderTreeView";
|
||||||
|
import { EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
|
||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogAction,
|
||||||
|
AlertDialogCancel,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogDescription,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogTitle,
|
||||||
|
} from "@/components/ui/alert-dialog";
|
||||||
import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar";
|
import { Avatar, AvatarFallback, AvatarGroup } from "@/components/ui/avatar";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
|
||||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||||
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
import type { DocumentTypeEnum } from "@/contracts/types/document.types";
|
||||||
import { useDebouncedValue } from "@/hooks/use-debounced-value";
|
import { useDebouncedValue } from "@/hooks/use-debounced-value";
|
||||||
import { useDocumentSearch } from "@/hooks/use-document-search";
|
|
||||||
import { useDocuments } from "@/hooks/use-documents";
|
|
||||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||||
import { foldersApiService } from "@/lib/apis/folders-api.service";
|
import { foldersApiService } from "@/lib/apis/folders-api.service";
|
||||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||||
import { queries } from "@/zero/queries/index";
|
import { queries } from "@/zero/queries/index";
|
||||||
import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel";
|
import { SidebarSlideOutPanel } from "./SidebarSlideOutPanel";
|
||||||
|
|
||||||
|
const NON_DELETABLE_DOCUMENT_TYPES: readonly string[] = ["SURFSENSE_DOCS"];
|
||||||
|
|
||||||
const SHOWCASE_CONNECTORS = [
|
const SHOWCASE_CONNECTORS = [
|
||||||
{ type: "GOOGLE_DRIVE_CONNECTOR", label: "Google Drive" },
|
{ type: "GOOGLE_DRIVE_CONNECTOR", label: "Google Drive" },
|
||||||
{ type: "GOOGLE_GMAIL_CONNECTOR", label: "Gmail" },
|
{ type: "GOOGLE_GMAIL_CONNECTOR", label: "Gmail" },
|
||||||
|
|
@ -82,8 +90,6 @@ export function DocumentsSidebar({
|
||||||
const [search, setSearch] = useState("");
|
const [search, setSearch] = useState("");
|
||||||
const debouncedSearch = useDebouncedValue(search, 250);
|
const debouncedSearch = useDebouncedValue(search, 250);
|
||||||
const [activeTypes, setActiveTypes] = useState<DocumentTypeEnum[]>([]);
|
const [activeTypes, setActiveTypes] = useState<DocumentTypeEnum[]>([]);
|
||||||
const [sortKey, setSortKey] = useState<SortKey>("created_at");
|
|
||||||
const [sortDesc, setSortDesc] = useState(true);
|
|
||||||
const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom);
|
const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom);
|
||||||
|
|
||||||
const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom);
|
const [sidebarDocs, setSidebarDocs] = useAtom(sidebarSelectedDocumentsAtom);
|
||||||
|
|
@ -110,6 +116,7 @@ export function DocumentsSidebar({
|
||||||
// Zero queries for tree data
|
// Zero queries for tree data
|
||||||
const [zeroFolders] = useQuery(queries.folders.bySpace({ searchSpaceId }));
|
const [zeroFolders] = useQuery(queries.folders.bySpace({ searchSpaceId }));
|
||||||
const [zeroAllDocs] = useQuery(queries.documents.bySpace({ searchSpaceId }));
|
const [zeroAllDocs] = useQuery(queries.documents.bySpace({ searchSpaceId }));
|
||||||
|
const [agentCreatedDocs, setAgentCreatedDocs] = useAtom(agentCreatedDocumentsAtom);
|
||||||
|
|
||||||
const treeFolders: FolderDisplay[] = useMemo(
|
const treeFolders: FolderDisplay[] = useMemo(
|
||||||
() =>
|
() =>
|
||||||
|
|
@ -123,19 +130,41 @@ export function DocumentsSidebar({
|
||||||
[zeroFolders]
|
[zeroFolders]
|
||||||
);
|
);
|
||||||
|
|
||||||
const treeDocuments: DocumentNodeDoc[] = useMemo(
|
const treeDocuments: DocumentNodeDoc[] = useMemo(() => {
|
||||||
() =>
|
const zeroDocs = (zeroAllDocs ?? [])
|
||||||
(zeroAllDocs ?? [])
|
.filter((d) => d.title && d.title.trim() !== "")
|
||||||
.filter((d) => d.title && d.title.trim() !== "")
|
.map((d) => ({
|
||||||
.map((d) => ({
|
id: d.id,
|
||||||
id: d.id,
|
title: d.title,
|
||||||
title: d.title,
|
document_type: d.documentType,
|
||||||
document_type: d.documentType,
|
folderId: (d as { folderId?: number | null }).folderId ?? null,
|
||||||
folderId: (d as { folderId?: number | null }).folderId ?? null,
|
status: d.status as { state: string; reason?: string | null } | undefined,
|
||||||
status: d.status as { state: string; reason?: string | null } | undefined,
|
}));
|
||||||
})),
|
|
||||||
[zeroAllDocs]
|
const zeroIds = new Set(zeroDocs.map((d) => d.id));
|
||||||
);
|
|
||||||
|
const pendingAgentDocs = agentCreatedDocs
|
||||||
|
.filter((d) => d.searchSpaceId === searchSpaceId && !zeroIds.has(d.id))
|
||||||
|
.map((d) => ({
|
||||||
|
id: d.id,
|
||||||
|
title: d.title,
|
||||||
|
document_type: d.documentType,
|
||||||
|
folderId: d.folderId ?? null,
|
||||||
|
status: { state: "ready" } as { state: string; reason?: string | null },
|
||||||
|
}));
|
||||||
|
|
||||||
|
return [...pendingAgentDocs, ...zeroDocs];
|
||||||
|
}, [zeroAllDocs, agentCreatedDocs, searchSpaceId]);
|
||||||
|
|
||||||
|
// Prune agent-created docs once Zero has caught up
|
||||||
|
useEffect(() => {
|
||||||
|
if (!zeroAllDocs?.length || !agentCreatedDocs.length) return;
|
||||||
|
const zeroIds = new Set(zeroAllDocs.map((d) => d.id));
|
||||||
|
const remaining = agentCreatedDocs.filter((d) => !zeroIds.has(d.id));
|
||||||
|
if (remaining.length < agentCreatedDocs.length) {
|
||||||
|
setAgentCreatedDocs(remaining);
|
||||||
|
}
|
||||||
|
}, [zeroAllDocs, agentCreatedDocs, setAgentCreatedDocs]);
|
||||||
|
|
||||||
const foldersByParent = useMemo(() => {
|
const foldersByParent = useMemo(() => {
|
||||||
const map: Record<string, FolderDisplay[]> = {};
|
const map: Record<string, FolderDisplay[]> = {};
|
||||||
|
|
@ -355,7 +384,7 @@ export function DocumentsSidebar({
|
||||||
(d) =>
|
(d) =>
|
||||||
d.folderId === parentId &&
|
d.folderId === parentId &&
|
||||||
d.status?.state !== "pending" &&
|
d.status?.state !== "pending" &&
|
||||||
d.status?.state !== "processing",
|
d.status?.state !== "processing"
|
||||||
);
|
);
|
||||||
const childFolders = foldersByParent[String(parentId)] ?? [];
|
const childFolders = foldersByParent[String(parentId)] ?? [];
|
||||||
const descendantDocs = childFolders.flatMap((cf) => collectSubtreeDocs(cf.id));
|
const descendantDocs = childFolders.flatMap((cf) => collectSubtreeDocs(cf.id));
|
||||||
|
|
@ -382,38 +411,72 @@ export function DocumentsSidebar({
|
||||||
setSidebarDocs((prev) => prev.filter((d) => !idsToRemove.has(d.id)));
|
setSidebarDocs((prev) => prev.filter((d) => !idsToRemove.has(d.id)));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[treeDocuments, foldersByParent, setSidebarDocs],
|
[treeDocuments, foldersByParent, setSidebarDocs]
|
||||||
);
|
);
|
||||||
|
|
||||||
const isSearchMode = !!debouncedSearch.trim();
|
const searchFilteredDocuments = useMemo(() => {
|
||||||
|
const query = debouncedSearch.trim().toLowerCase();
|
||||||
|
if (!query) return treeDocuments;
|
||||||
|
return treeDocuments.filter((d) => d.title.toLowerCase().includes(query));
|
||||||
|
}, [treeDocuments, debouncedSearch]);
|
||||||
|
|
||||||
const {
|
const typeCounts = useMemo(() => {
|
||||||
documents: realtimeDocuments,
|
const counts: Partial<Record<string, number>> = {};
|
||||||
typeCounts: realtimeTypeCounts,
|
for (const d of treeDocuments) {
|
||||||
loading: realtimeLoading,
|
counts[d.document_type] = (counts[d.document_type] || 0) + 1;
|
||||||
loadingMore: realtimeLoadingMore,
|
}
|
||||||
hasMore: realtimeHasMore,
|
return counts;
|
||||||
loadMore: realtimeLoadMore,
|
}, [treeDocuments]);
|
||||||
removeItems: realtimeRemoveItems,
|
|
||||||
error: realtimeError,
|
|
||||||
} = useDocuments(searchSpaceId, activeTypes, sortKey, sortDesc ? "desc" : "asc");
|
|
||||||
|
|
||||||
const {
|
const deletableSelectedIds = useMemo(() => {
|
||||||
documents: searchDocuments,
|
const treeDocMap = new Map(treeDocuments.map((d) => [d.id, d]));
|
||||||
loading: searchLoading,
|
return sidebarDocs
|
||||||
loadingMore: searchLoadingMore,
|
.filter((doc) => {
|
||||||
hasMore: searchHasMore,
|
const fullDoc = treeDocMap.get(doc.id);
|
||||||
loadMore: searchLoadMore,
|
if (!fullDoc) return false;
|
||||||
error: searchError,
|
const state = fullDoc.status?.state ?? "ready";
|
||||||
removeItems: searchRemoveItems,
|
return (
|
||||||
} = useDocumentSearch(searchSpaceId, debouncedSearch, activeTypes, isSearchMode && open);
|
state !== "pending" &&
|
||||||
|
state !== "processing" &&
|
||||||
|
!NON_DELETABLE_DOCUMENT_TYPES.includes(doc.document_type)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.map((doc) => doc.id);
|
||||||
|
}, [sidebarDocs, treeDocuments]);
|
||||||
|
|
||||||
const displayDocs = isSearchMode ? searchDocuments : realtimeDocuments;
|
const [bulkDeleteConfirmOpen, setBulkDeleteConfirmOpen] = useState(false);
|
||||||
const loading = isSearchMode ? searchLoading : realtimeLoading;
|
const [isBulkDeleting, setIsBulkDeleting] = useState(false);
|
||||||
const error = isSearchMode ? searchError : !!realtimeError;
|
|
||||||
const hasMore = isSearchMode ? searchHasMore : realtimeHasMore;
|
const handleBulkDeleteSelected = useCallback(async () => {
|
||||||
const loadingMore = isSearchMode ? searchLoadingMore : realtimeLoadingMore;
|
if (deletableSelectedIds.length === 0) return;
|
||||||
const onLoadMore = isSearchMode ? searchLoadMore : realtimeLoadMore;
|
setIsBulkDeleting(true);
|
||||||
|
try {
|
||||||
|
const results = await Promise.allSettled(
|
||||||
|
deletableSelectedIds.map(async (id) => {
|
||||||
|
await deleteDocumentMutation({ id });
|
||||||
|
return id;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
const successIds = results
|
||||||
|
.filter((r): r is PromiseFulfilledResult<number> => r.status === "fulfilled")
|
||||||
|
.map((r) => r.value);
|
||||||
|
const failed = results.length - successIds.length;
|
||||||
|
if (successIds.length > 0) {
|
||||||
|
setSidebarDocs((prev) => {
|
||||||
|
const idSet = new Set(successIds);
|
||||||
|
return prev.filter((d) => !idSet.has(d.id));
|
||||||
|
});
|
||||||
|
toast.success(`Deleted ${successIds.length} document${successIds.length !== 1 ? "s" : ""}`);
|
||||||
|
}
|
||||||
|
if (failed > 0) {
|
||||||
|
toast.error(`Failed to delete ${failed} document${failed !== 1 ? "s" : ""}`);
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
toast.error("Failed to delete documents");
|
||||||
|
}
|
||||||
|
setIsBulkDeleting(false);
|
||||||
|
setBulkDeleteConfirmOpen(false);
|
||||||
|
}, [deletableSelectedIds, deleteDocumentMutation, setSidebarDocs]);
|
||||||
|
|
||||||
const onToggleType = useCallback((type: DocumentTypeEnum, checked: boolean) => {
|
const onToggleType = useCallback((type: DocumentTypeEnum, checked: boolean) => {
|
||||||
setActiveTypes((prev) => {
|
setActiveTypes((prev) => {
|
||||||
|
|
@ -430,69 +493,15 @@ export function DocumentsSidebar({
|
||||||
await deleteDocumentMutation({ id });
|
await deleteDocumentMutation({ id });
|
||||||
toast.success(t("delete_success") || "Document deleted");
|
toast.success(t("delete_success") || "Document deleted");
|
||||||
setSidebarDocs((prev) => prev.filter((d) => d.id !== id));
|
setSidebarDocs((prev) => prev.filter((d) => d.id !== id));
|
||||||
realtimeRemoveItems([id]);
|
|
||||||
if (isSearchMode) {
|
|
||||||
searchRemoveItems([id]);
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.error("Error deleting document:", e);
|
console.error("Error deleting document:", e);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[
|
[deleteDocumentMutation, t, setSidebarDocs]
|
||||||
deleteDocumentMutation,
|
|
||||||
isSearchMode,
|
|
||||||
t,
|
|
||||||
searchRemoveItems,
|
|
||||||
realtimeRemoveItems,
|
|
||||||
setSidebarDocs,
|
|
||||||
]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleBulkDeleteDocuments = useCallback(
|
|
||||||
async (ids: number[]): Promise<{ success: number; failed: number }> => {
|
|
||||||
const successIds: number[] = [];
|
|
||||||
const results = await Promise.allSettled(
|
|
||||||
ids.map(async (id) => {
|
|
||||||
await deleteDocumentMutation({ id });
|
|
||||||
successIds.push(id);
|
|
||||||
})
|
|
||||||
);
|
|
||||||
if (successIds.length > 0) {
|
|
||||||
setSidebarDocs((prev) => prev.filter((d) => !successIds.includes(d.id)));
|
|
||||||
realtimeRemoveItems(successIds);
|
|
||||||
if (isSearchMode) {
|
|
||||||
searchRemoveItems(successIds);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
const success = results.filter((r) => r.status === "fulfilled").length;
|
|
||||||
const failed = results.filter((r) => r.status === "rejected").length;
|
|
||||||
return { success, failed };
|
|
||||||
},
|
|
||||||
[deleteDocumentMutation, isSearchMode, searchRemoveItems, realtimeRemoveItems, setSidebarDocs]
|
|
||||||
);
|
|
||||||
|
|
||||||
const sortKeyRef = useRef(sortKey);
|
|
||||||
const sortDescRef = useRef(sortDesc);
|
|
||||||
sortKeyRef.current = sortKey;
|
|
||||||
sortDescRef.current = sortDesc;
|
|
||||||
|
|
||||||
const handleSortChange = useCallback((key: SortKey) => {
|
|
||||||
const currentKey = sortKeyRef.current;
|
|
||||||
const currentDesc = sortDescRef.current;
|
|
||||||
|
|
||||||
if (currentKey === key && currentDesc) {
|
|
||||||
setSortKey("created_at");
|
|
||||||
setSortDesc(true);
|
|
||||||
} else if (currentKey === key) {
|
|
||||||
setSortDesc(true);
|
|
||||||
} else {
|
|
||||||
setSortKey(key);
|
|
||||||
setSortDesc(false);
|
|
||||||
}
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const handleEscape = (e: KeyboardEvent) => {
|
const handleEscape = (e: KeyboardEvent) => {
|
||||||
if (e.key === "Escape" && open) {
|
if (e.key === "Escape" && open) {
|
||||||
|
|
@ -627,7 +636,7 @@ export function DocumentsSidebar({
|
||||||
<div className="flex-1 min-h-0 overflow-x-hidden pt-0 flex flex-col">
|
<div className="flex-1 min-h-0 overflow-x-hidden pt-0 flex flex-col">
|
||||||
<div className="px-4 pb-2">
|
<div className="px-4 pb-2">
|
||||||
<DocumentsFilters
|
<DocumentsFilters
|
||||||
typeCounts={realtimeTypeCounts}
|
typeCounts={typeCounts}
|
||||||
onSearch={setSearch}
|
onSearch={setSearch}
|
||||||
searchValue={search}
|
searchValue={search}
|
||||||
onToggleType={onToggleType}
|
onToggleType={onToggleType}
|
||||||
|
|
@ -636,59 +645,54 @@ export function DocumentsSidebar({
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{isSearchMode ? (
|
{deletableSelectedIds.length > 0 && (
|
||||||
<DocumentsTableShell
|
<div className="shrink-0 flex items-center justify-center px-4 py-1.5 animate-in fade-in duration-150">
|
||||||
documents={displayDocs}
|
<button
|
||||||
loading={!!loading}
|
type="button"
|
||||||
error={!!error}
|
onClick={() => setBulkDeleteConfirmOpen(true)}
|
||||||
sortKey={sortKey}
|
className="flex items-center gap-1.5 px-3 py-1 rounded-md bg-destructive text-destructive-foreground shadow-sm text-xs font-medium hover:bg-destructive/90 transition-colors"
|
||||||
sortDesc={sortDesc}
|
>
|
||||||
onSortChange={handleSortChange}
|
<Trash2 size={12} />
|
||||||
deleteDocument={handleDeleteDocument}
|
Delete {deletableSelectedIds.length}{" "}
|
||||||
bulkDeleteDocuments={handleBulkDeleteDocuments}
|
{deletableSelectedIds.length === 1 ? "item" : "items"}
|
||||||
searchSpaceId={String(searchSpaceId)}
|
</button>
|
||||||
hasMore={hasMore}
|
</div>
|
||||||
loadingMore={loadingMore}
|
|
||||||
onLoadMore={onLoadMore}
|
|
||||||
mentionedDocIds={mentionedDocIds}
|
|
||||||
onToggleChatMention={handleToggleChatMention}
|
|
||||||
isSearchMode={isSearchMode || activeTypes.length > 0}
|
|
||||||
/>
|
|
||||||
) : (
|
|
||||||
<FolderTreeView
|
|
||||||
folders={treeFolders}
|
|
||||||
documents={treeDocuments}
|
|
||||||
expandedIds={expandedIds}
|
|
||||||
onToggleExpand={toggleFolderExpand}
|
|
||||||
mentionedDocIds={mentionedDocIds}
|
|
||||||
onToggleChatMention={handleToggleChatMention}
|
|
||||||
onToggleFolderSelect={handleToggleFolderSelect}
|
|
||||||
onRenameFolder={handleRenameFolder}
|
|
||||||
onDeleteFolder={handleDeleteFolder}
|
|
||||||
onMoveFolder={handleMoveFolder}
|
|
||||||
onCreateFolder={handleCreateFolder}
|
|
||||||
onPreviewDocument={(doc) => {
|
|
||||||
openDocumentTab({
|
|
||||||
documentId: doc.id,
|
|
||||||
searchSpaceId,
|
|
||||||
title: doc.title,
|
|
||||||
});
|
|
||||||
}}
|
|
||||||
onEditDocument={(doc) => {
|
|
||||||
openDocumentTab({
|
|
||||||
documentId: doc.id,
|
|
||||||
searchSpaceId,
|
|
||||||
title: doc.title,
|
|
||||||
});
|
|
||||||
}}
|
|
||||||
onDeleteDocument={(doc) => handleDeleteDocument(doc.id)}
|
|
||||||
onMoveDocument={handleMoveDocument}
|
|
||||||
onExportDocument={handleExportDocument}
|
|
||||||
activeTypes={activeTypes}
|
|
||||||
onDropIntoFolder={handleDropIntoFolder}
|
|
||||||
onReorderFolder={handleReorderFolder}
|
|
||||||
/>
|
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
<FolderTreeView
|
||||||
|
folders={treeFolders}
|
||||||
|
documents={searchFilteredDocuments}
|
||||||
|
expandedIds={expandedIds}
|
||||||
|
onToggleExpand={toggleFolderExpand}
|
||||||
|
mentionedDocIds={mentionedDocIds}
|
||||||
|
onToggleChatMention={handleToggleChatMention}
|
||||||
|
onToggleFolderSelect={handleToggleFolderSelect}
|
||||||
|
onRenameFolder={handleRenameFolder}
|
||||||
|
onDeleteFolder={handleDeleteFolder}
|
||||||
|
onMoveFolder={handleMoveFolder}
|
||||||
|
onCreateFolder={handleCreateFolder}
|
||||||
|
searchQuery={debouncedSearch.trim() || undefined}
|
||||||
|
onPreviewDocument={(doc) => {
|
||||||
|
openDocumentTab({
|
||||||
|
documentId: doc.id,
|
||||||
|
searchSpaceId,
|
||||||
|
title: doc.title,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
onEditDocument={(doc) => {
|
||||||
|
openDocumentTab({
|
||||||
|
documentId: doc.id,
|
||||||
|
searchSpaceId,
|
||||||
|
title: doc.title,
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
onDeleteDocument={(doc) => handleDeleteDocument(doc.id)}
|
||||||
|
onMoveDocument={handleMoveDocument}
|
||||||
|
onExportDocument={handleExportDocument}
|
||||||
|
activeTypes={activeTypes}
|
||||||
|
onDropIntoFolder={handleDropIntoFolder}
|
||||||
|
onReorderFolder={handleReorderFolder}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<FolderPickerDialog
|
<FolderPickerDialog
|
||||||
|
|
@ -707,6 +711,40 @@ export function DocumentsSidebar({
|
||||||
parentFolderName={createFolderParentName}
|
parentFolderName={createFolderParentName}
|
||||||
onConfirm={handleCreateFolderConfirm}
|
onConfirm={handleCreateFolderConfirm}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<AlertDialog
|
||||||
|
open={bulkDeleteConfirmOpen}
|
||||||
|
onOpenChange={(open) => !open && !isBulkDeleting && setBulkDeleteConfirmOpen(false)}
|
||||||
|
>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader>
|
||||||
|
<AlertDialogTitle>
|
||||||
|
Delete {deletableSelectedIds.length} document
|
||||||
|
{deletableSelectedIds.length !== 1 ? "s" : ""}?
|
||||||
|
</AlertDialogTitle>
|
||||||
|
<AlertDialogDescription>
|
||||||
|
This action cannot be undone.{" "}
|
||||||
|
{deletableSelectedIds.length === 1
|
||||||
|
? "This document"
|
||||||
|
: `These ${deletableSelectedIds.length} documents`}{" "}
|
||||||
|
will be permanently deleted from your search space.
|
||||||
|
</AlertDialogDescription>
|
||||||
|
</AlertDialogHeader>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<AlertDialogCancel disabled={isBulkDeleting}>Cancel</AlertDialogCancel>
|
||||||
|
<AlertDialogAction
|
||||||
|
onClick={(e) => {
|
||||||
|
e.preventDefault();
|
||||||
|
handleBulkDeleteSelected();
|
||||||
|
}}
|
||||||
|
disabled={isBulkDeleting}
|
||||||
|
className="bg-destructive text-destructive-foreground hover:bg-destructive/90"
|
||||||
|
>
|
||||||
|
{isBulkDeleting ? <Spinner size="sm" /> : "Delete"}
|
||||||
|
</AlertDialogAction>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialog>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,10 @@ import { useTheme } from "next-themes";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { createPortal } from "react-dom";
|
import { createPortal } from "react-dom";
|
||||||
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
import { connectorsAtom } from "@/atoms/connectors/connector-query.atoms";
|
||||||
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
|
|
||||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||||
import { useIsMobile } from "@/hooks/use-mobile";
|
import { useIsMobile } from "@/hooks/use-mobile";
|
||||||
|
import { useZeroDocumentTypeCounts } from "@/hooks/use-zero-document-type-counts";
|
||||||
import { fetchThreads } from "@/lib/chat/thread-persistence";
|
import { fetchThreads } from "@/lib/chat/thread-persistence";
|
||||||
|
|
||||||
interface TourStep {
|
interface TourStep {
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import { currentThreadAtom } from "@/atoms/chat/current-thread.atom";
|
||||||
import { closeReportPanelAtom, reportPanelAtom } from "@/atoms/chat/report-panel.atom";
|
import { closeReportPanelAtom, reportPanelAtom } from "@/atoms/chat/report-panel.atom";
|
||||||
import { PlateEditor } from "@/components/editor/plate-editor";
|
import { PlateEditor } from "@/components/editor/plate-editor";
|
||||||
import { MarkdownViewer } from "@/components/markdown-viewer";
|
import { MarkdownViewer } from "@/components/markdown-viewer";
|
||||||
|
import { EXPORT_FILE_EXTENSIONS, ExportDropdownItems } from "@/components/shared/ExportMenuItems";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer";
|
import { Drawer, DrawerContent, DrawerHandle, DrawerTitle } from "@/components/ui/drawer";
|
||||||
import {
|
import {
|
||||||
|
|
@ -17,7 +18,6 @@ import {
|
||||||
DropdownMenuItem,
|
DropdownMenuItem,
|
||||||
DropdownMenuTrigger,
|
DropdownMenuTrigger,
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
import { ExportDropdownItems, EXPORT_FILE_EXTENSIONS } from "@/components/shared/ExportMenuItems";
|
|
||||||
import { useMediaQuery } from "@/hooks/use-media-query";
|
import { useMediaQuery } from "@/hooks/use-media-query";
|
||||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,12 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { Loader2 } from "lucide-react";
|
import { Loader2 } from "lucide-react";
|
||||||
import { DropdownMenuItem, DropdownMenuLabel, DropdownMenuSeparator } from "@/components/ui/dropdown-menu";
|
|
||||||
import { ContextMenuItem } from "@/components/ui/context-menu";
|
import { ContextMenuItem } from "@/components/ui/context-menu";
|
||||||
|
import {
|
||||||
|
DropdownMenuItem,
|
||||||
|
DropdownMenuLabel,
|
||||||
|
DropdownMenuSeparator,
|
||||||
|
} from "@/components/ui/dropdown-menu";
|
||||||
|
|
||||||
export const EXPORT_FILE_EXTENSIONS: Record<string, string> = {
|
export const EXPORT_FILE_EXTENSIONS: Record<string, string> = {
|
||||||
pdf: "pdf",
|
pdf: "pdf",
|
||||||
|
|
@ -36,9 +40,7 @@ export function ExportDropdownItems({
|
||||||
<>
|
<>
|
||||||
{showAllFormats && (
|
{showAllFormats && (
|
||||||
<>
|
<>
|
||||||
<DropdownMenuLabel className="text-xs text-muted-foreground">
|
<DropdownMenuLabel className="text-xs text-muted-foreground">Documents</DropdownMenuLabel>
|
||||||
Documents
|
|
||||||
</DropdownMenuLabel>
|
|
||||||
<DropdownMenuItem onClick={handle("pdf")} disabled={exporting !== null}>
|
<DropdownMenuItem onClick={handle("pdf")} disabled={exporting !== null}>
|
||||||
{exporting === "pdf" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
|
{exporting === "pdf" && <Loader2 className="mr-2 h-3.5 w-3.5 animate-spin" />}
|
||||||
PDF (.pdf)
|
PDF (.pdf)
|
||||||
|
|
|
||||||
|
|
@ -287,13 +287,9 @@ function ApprovalCard({
|
||||||
? pendingEdits.end_datetime
|
? pendingEdits.end_datetime
|
||||||
: null,
|
: null,
|
||||||
new_location:
|
new_location:
|
||||||
pendingEdits.location !== (event?.location ?? "")
|
pendingEdits.location !== (event?.location ?? "") ? pendingEdits.location || null : null,
|
||||||
? pendingEdits.location || null
|
|
||||||
: null,
|
|
||||||
new_attendees:
|
new_attendees:
|
||||||
attendeesArr && attendeesArr.join(",") !== origAttendees.join(",")
|
attendeesArr && attendeesArr.join(",") !== origAttendees.join(",") ? attendeesArr : null,
|
||||||
? attendeesArr
|
|
||||||
: null,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import {
|
import {
|
||||||
BookOpen,
|
BookOpen,
|
||||||
Brain,
|
Brain,
|
||||||
Database,
|
|
||||||
FileText,
|
FileText,
|
||||||
Film,
|
Film,
|
||||||
Globe,
|
Globe,
|
||||||
|
|
@ -13,7 +12,6 @@ import {
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
|
|
||||||
const TOOL_ICONS: Record<string, LucideIcon> = {
|
const TOOL_ICONS: Record<string, LucideIcon> = {
|
||||||
search_knowledge_base: Database,
|
|
||||||
generate_podcast: Podcast,
|
generate_podcast: Podcast,
|
||||||
generate_video_presentation: Film,
|
generate_video_presentation: Film,
|
||||||
generate_report: FileText,
|
generate_report: FileText,
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,7 @@ export function useZeroDocumentTypeCounts(
|
||||||
): Record<string, number> | undefined {
|
): Record<string, number> | undefined {
|
||||||
const numericId = searchSpaceId != null ? Number(searchSpaceId) : null;
|
const numericId = searchSpaceId != null ? Number(searchSpaceId) : null;
|
||||||
|
|
||||||
const [zeroDocuments] = useQuery(
|
const [zeroDocuments] = useQuery(queries.documents.bySpace({ searchSpaceId: numericId ?? -1 }));
|
||||||
queries.documents.bySpace({ searchSpaceId: numericId ?? -1 })
|
|
||||||
);
|
|
||||||
|
|
||||||
return useMemo(() => {
|
return useMemo(() => {
|
||||||
if (!zeroDocuments || numericId == null) return undefined;
|
if (!zeroDocuments || numericId == null) return undefined;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue