Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-03 12:52:39 +02:00)
feat: enhance performance logging and caching in various components
- Introduced slow callback logging in FastAPI to identify blocking calls.
- Added performance logging for agent creation and tool loading processes.
- Implemented caching for MCP tools to reduce redundant server calls.
- Enhanced sandbox management with in-process caching for improved efficiency.
- Refactored several functions for better readability and performance tracking.
- Updated tests to ensure proper functionality of new features and optimizations.
parent: 2e99f1e853
commit: aabc24f82c
22 changed files with 637 additions and 200 deletions
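The recurring pattern across these changes is to time each pre-stream setup step with time.perf_counter() and to memoize expensive per-search-space work between chat messages. As a rough illustration of the TTL-cache shape applied to MCP tool loading in the diff below, here is a self-contained sketch; the discovery coroutine and function names are stand-ins for illustration, not SurfSense code.

# Sketch only: mirrors the TTL-cache shape used for MCP tools below.
import asyncio
import time

_CACHE_TTL_SECONDS = 300  # 5 minutes, same TTL as the real cache
_cache: dict[int, tuple[float, list[str]]] = {}  # search_space_id -> (cached_at, tools)


async def _expensive_tool_discovery(search_space_id: int) -> list[str]:
    # Placeholder for spawning MCP server processes and discovering their tools.
    await asyncio.sleep(0.1)
    return [f"tool_for_space_{search_space_id}"]


async def load_tools_cached(search_space_id: int) -> list[str]:
    """Return cached tools when fresh; otherwise rediscover and repopulate."""
    now = time.monotonic()
    entry = _cache.get(search_space_id)
    if entry is not None:
        cached_at, tools = entry
        if now - cached_at < _CACHE_TTL_SECONDS:
            return list(tools)  # copy so callers cannot mutate the cached list
    tools = await _expensive_tool_discovery(search_space_id)
    _cache[search_space_id] = (now, tools)
    return tools


def invalidate(search_space_id: int | None = None) -> None:
    """Drop one entry (or all of them) when connector configuration changes."""
    if search_space_id is not None:
        _cache.pop(search_space_id, None)
    else:
        _cache.clear()


if __name__ == "__main__":
    asyncio.run(load_tools_cached(1))  # first call pays the discovery cost
    asyncio.run(load_tools_cached(1))  # second call is served from the cache

Using time.monotonic() keeps the expiry check immune to wall-clock adjustments, and the invalidation hook lets connector create/update/delete routes drop stale entries immediately instead of waiting out the TTL.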

@@ -6,6 +6,9 @@ with configurable tools via the tools registry and configurable prompts
via NewLLMConfig.
"""

+ import asyncio
+ import logging
+ import time
from collections.abc import Sequence
from typing import Any

@@ -26,6 +29,8 @@ from app.agents.new_chat.tools.registry import build_tools_async
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService

+ _perf_log = logging.getLogger("surfsense.perf")
+
# =============================================================================
# Connector Type Mapping
# =============================================================================

@@ -210,29 +215,29 @@ async def create_surfsense_deep_agent(
            additional_tools=[my_custom_tool]
        )
    """
+   _t_agent_total = time.perf_counter()

    # Discover available connectors and document types for this search space
-   # This enables dynamic tool docstrings that inform the LLM about what's actually available
    available_connectors: list[str] | None = None
    available_document_types: list[str] | None = None

+   _t0 = time.perf_counter()
    try:
-       # Get enabled search source connectors for this search space
        connector_types = await connector_service.get_available_connectors(
            search_space_id
        )
        if connector_types:
-           # Convert enum values to strings and also include mapped document types
            available_connectors = _map_connectors_to_searchable_types(connector_types)

-           # Get document types that have at least one document indexed
            available_document_types = await connector_service.get_available_document_types(
                search_space_id
            )
    except Exception as e:
-       # Log but don't fail - fall back to all connectors if discovery fails
-       import logging
-
        logging.warning(f"Failed to discover available connectors/document types: {e}")
+   _perf_log.info(
+       "[create_agent] Connector/doc-type discovery in %.3fs",
+       time.perf_counter() - _t0,
+   )

    # Build dependencies dict for the tools registry
    visibility = thread_visibility or ChatVisibility.PRIVATE

@@ -274,14 +279,21 @@ async def create_surfsense_deep_agent(
        modified_disabled_tools.extend(linear_tools)

    # Build tools using the async registry (includes MCP tools)
+   _t0 = time.perf_counter()
    tools = await build_tools_async(
        dependencies=dependencies,
        enabled_tools=enabled_tools,
        disabled_tools=modified_disabled_tools,
        additional_tools=list(additional_tools) if additional_tools else None,
    )
+   _perf_log.info(
+       "[create_agent] build_tools_async in %.3fs (%d tools)",
+       time.perf_counter() - _t0,
+       len(tools),
+   )

    # Build system prompt based on agent_config
+   _t0 = time.perf_counter()
    _sandbox_enabled = sandbox_backend is not None
    if agent_config is not None:
        system_prompt = build_configurable_system_prompt(

@@ -296,15 +308,18 @@ async def create_surfsense_deep_agent(
            thread_visibility=thread_visibility,
            sandbox_enabled=_sandbox_enabled,
        )
+   _perf_log.info(
+       "[create_agent] System prompt built in %.3fs", time.perf_counter() - _t0
+   )

    # Build optional kwargs for the deep agent
    deep_agent_kwargs: dict[str, Any] = {}
    if sandbox_backend is not None:
        deep_agent_kwargs["backend"] = sandbox_backend

-   # Create the deep agent with system prompt and checkpointer
-   # Note: TodoListMiddleware (write_todos) is included by default in create_deep_agent
-   agent = create_deep_agent(
+   _t0 = time.perf_counter()
+   agent = await asyncio.to_thread(
+       create_deep_agent,
        model=llm,
        tools=tools,
        system_prompt=system_prompt,

@@ -312,5 +327,13 @@ async def create_surfsense_deep_agent(
        checkpointer=checkpointer,
        **deep_agent_kwargs,
    )
+   _perf_log.info(
+       "[create_agent] Graph compiled (create_deep_agent) in %.3fs",
+       time.perf_counter() - _t0,
+   )

+   _perf_log.info(
+       "[create_agent] Total agent creation in %.3fs",
+       time.perf_counter() - _t_agent_total,
+   )
    return agent

@@ -12,6 +12,7 @@ the sandbox is deleted so they remain downloadable after cleanup.
from __future__ import annotations

import asyncio
+ import contextlib
import logging
import os
import shutil

@@ -56,6 +57,7 @@ class _TimeoutAwareSandbox(DaytonaSandbox):


_daytona_client: Daytona | None = None
+ _sandbox_cache: dict[str, _TimeoutAwareSandbox] = {}
THREAD_LABEL_KEY = "surfsense_thread"


@@ -126,8 +128,8 @@ def _find_or_create(thread_id: str) -> _TimeoutAwareSandbox:
async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
    """Get or create a sandbox for a conversation thread.

-   Uses the thread_id as a label so the same sandbox persists
-   across multiple messages within the same conversation.
+   Uses an in-process cache keyed by thread_id so subsequent messages
+   in the same conversation reuse the sandbox object without an API call.

    Args:
        thread_id: The conversation thread identifier.

@@ -135,11 +137,19 @@ async def get_or_create_sandbox(thread_id: int | str) -> _TimeoutAwareSandbox:
    Returns:
        DaytonaSandbox connected to the sandbox.
    """
-   return await asyncio.to_thread(_find_or_create, str(thread_id))
+   key = str(thread_id)
+   cached = _sandbox_cache.get(key)
+   if cached is not None:
+       logger.info("Reusing cached sandbox for thread %s", key)
+       return cached
+   sandbox = await asyncio.to_thread(_find_or_create, key)
+   _sandbox_cache[key] = sandbox
+   return sandbox


async def delete_sandbox(thread_id: int | str) -> None:
    """Delete the sandbox for a conversation thread."""
+   _sandbox_cache.pop(str(thread_id), None)

    def _delete() -> None:
        client = _get_client()

@@ -147,7 +157,9 @@ async def delete_sandbox(thread_id: int | str) -> None:
        try:
            sandbox = client.find_one(labels=labels)
        except DaytonaError:
-           logger.debug("No sandbox to delete for thread %s (already removed)", thread_id)
+           logger.debug(
+               "No sandbox to delete for thread %s (already removed)", thread_id
+           )
            return
        try:
            client.delete(sandbox)

@@ -166,6 +178,7 @@ async def delete_sandbox(thread_id: int | str) -> None:
# Local file persistence
# ---------------------------------------------------------------------------


def _get_sandbox_files_dir() -> Path:
    return Path(os.environ.get("SANDBOX_FILES_DIR", "sandbox_files"))


@@ -206,6 +219,7 @@ async def persist_and_delete_sandbox(
    Per-file errors are logged but do **not** prevent the sandbox from
    being deleted — freeing Daytona storage is the priority.
    """
+   _sandbox_cache.pop(str(thread_id), None)

    def _persist_and_delete() -> None:
        client = _get_client()

@@ -229,10 +243,8 @@ async def persist_and_delete_sandbox(
                sandbox.id,
                exc_info=True,
            )
-           try:
+           with contextlib.suppress(Exception):
                client.delete(sandbox)
-           except Exception:
-               pass
            return

        for path in sandbox_file_paths:
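The sandbox change above applies the same idea without a TTL: a module-level dict keyed by the thread id short-circuits the remote find-or-create call, and both delete paths pop the entry so a torn-down sandbox is never handed back. A condensed sketch of that shape, with the provider call stubbed out (this is illustrative, not the Daytona SDK):

# Sketch of the in-process sandbox cache shape; "remote" calls are stand-ins.
import asyncio


class Sandbox:
    def __init__(self, thread_id: str) -> None:
        self.thread_id = thread_id


_sandbox_cache: dict[str, Sandbox] = {}


def _create_remote_sandbox(thread_id: str) -> Sandbox:
    # Placeholder for the blocking find-or-create call against the provider.
    return Sandbox(thread_id)


async def get_or_create_sandbox(thread_id: int | str) -> Sandbox:
    key = str(thread_id)
    cached = _sandbox_cache.get(key)
    if cached is not None:
        return cached  # reuse within the same process, no provider round-trip
    sandbox = await asyncio.to_thread(_create_remote_sandbox, key)
    _sandbox_cache[key] = sandbox
    return sandbox


async def delete_sandbox(thread_id: int | str) -> None:
    # Drop the cache entry first so later calls cannot reuse a dying sandbox,
    # then delete the remote sandbox as in the diff above.
    _sandbox_cache.pop(str(thread_id), None)


if __name__ == "__main__":
    asyncio.run(get_or_create_sandbox(42))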

@@ -11,6 +11,7 @@ This implements real MCP protocol support similar to Cursor's implementation.
"""

import logging
+ import time
from typing import Any

from langchain_core.tools import StructuredTool

@@ -25,6 +26,9 @@ from app.db import SearchSourceConnector, SearchSourceConnectorType

logger = logging.getLogger(__name__)

+ _MCP_CACHE_TTL_SECONDS = 300  # 5 minutes
+ _mcp_tools_cache: dict[int, tuple[float, list[StructuredTool]]] = {}


def _create_dynamic_input_model_from_schema(
    tool_name: str,

@@ -355,6 +359,19 @@ async def _load_http_mcp_tools(
    return tools


+ def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
+     """Invalidate cached MCP tools.
+
+     Args:
+         search_space_id: If provided, only invalidate for this search space.
+             If None, invalidate all cached MCP tools.
+     """
+     if search_space_id is not None:
+         _mcp_tools_cache.pop(search_space_id, None)
+     else:
+         _mcp_tools_cache.clear()
+
+
async def load_mcp_tools(
    session: AsyncSession,
    search_space_id: int,

@@ -364,6 +381,9 @@ async def load_mcp_tools(
    This discovers tools dynamically from MCP servers using the protocol.
    Supports both stdio (local process) and HTTP (remote server) transports.

+   Results are cached per search space for up to 5 minutes to avoid
+   re-spawning MCP server processes on every chat message.
+
    Args:
        session: Database session
        search_space_id: User's search space ID

@@ -372,8 +392,20 @@ async def load_mcp_tools(
        List of LangChain StructuredTool instances

    """
+   now = time.monotonic()
+   cached = _mcp_tools_cache.get(search_space_id)
+   if cached is not None:
+       cached_at, cached_tools = cached
+       if now - cached_at < _MCP_CACHE_TTL_SECONDS:
+           logger.info(
+               "Using cached MCP tools for search space %s (%d tools, age=%.0fs)",
+               search_space_id,
+               len(cached_tools),
+               now - cached_at,
+           )
+           return list(cached_tools)
+
    try:
-       # Fetch all MCP connectors for this search space
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.connector_type

@@ -385,27 +417,22 @@ async def load_mcp_tools(
        tools: list[StructuredTool] = []
        for connector in result.scalars():
            try:
-               # Early validation: Extract and validate connector config
                config = connector.config or {}
                server_config = config.get("server_config", {})

-               # Validate server_config exists and is a dict
                if not server_config or not isinstance(server_config, dict):
                    logger.warning(
                        f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
                    )
                    continue

-               # Determine transport type
                transport = server_config.get("transport", "stdio")

                if transport in ("streamable-http", "http", "sse"):
-                   # HTTP-based MCP server
                    connector_tools = await _load_http_mcp_tools(
                        connector.id, connector.name, server_config
                    )
                else:
-                   # stdio-based MCP server (default)
                    connector_tools = await _load_stdio_mcp_tools(
                        connector.id, connector.name, server_config
                    )

@@ -417,6 +444,7 @@ async def load_mcp_tools(
                    f"Failed to load tools from MCP connector {connector.id}: {e!s}"
                )

+       _mcp_tools_cache[search_space_id] = (now, tools)
        logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
        return tools

@@ -444,8 +444,18 @@ async def build_tools_async(
        List of configured tool instances ready for the agent, including MCP tools.

    """
-   # Build standard tools
+   import time
+
+   _perf_log = logging.getLogger("surfsense.perf")
+   _perf_log.setLevel(logging.DEBUG)
+
+   _t0 = time.perf_counter()
    tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
+   _perf_log.info(
+       "[build_tools_async] Built-in tools in %.3fs (%d tools)",
+       time.perf_counter() - _t0,
+       len(tools),
+   )

    # Load MCP tools if requested and dependencies are available
    if (

@@ -454,10 +464,16 @@ async def build_tools_async(
        and "search_space_id" in dependencies
    ):
        try:
+           _t0 = time.perf_counter()
            mcp_tools = await load_mcp_tools(
                dependencies["db_session"],
                dependencies["search_space_id"],
            )
+           _perf_log.info(
+               "[build_tools_async] MCP tools loaded in %.3fs (%d tools)",
+               time.perf_counter() - _t0,
+               len(mcp_tools),
+           )
            tools.extend(mcp_tools)
            logging.info(
                f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",

@@ -175,8 +175,39 @@ def rate_limit_password_reset(request: Request):
    )


+ def _enable_slow_callback_logging(threshold_sec: float = 0.5) -> None:
+     """Monkey-patch the event loop to warn whenever a callback blocks longer than *threshold_sec*.
+
+     This helps pinpoint synchronous code that freezes the entire FastAPI server.
+     Only active when the PERF_DEBUG env var is set (to avoid overhead in production).
+     """
+     import os
+
+     if not os.environ.get("PERF_DEBUG"):
+         return
+
+     _slow_log = logging.getLogger("surfsense.perf.slow")
+     _slow_log.setLevel(logging.WARNING)
+     if not _slow_log.handlers:
+         _h = logging.StreamHandler()
+         _h.setFormatter(logging.Formatter("%(asctime)s [SLOW-CALLBACK] %(message)s"))
+         _slow_log.addHandler(_h)
+     _slow_log.propagate = False
+
+     loop = asyncio.get_running_loop()
+     loop.slow_callback_duration = threshold_sec  # type: ignore[attr-defined]
+     loop.set_debug(True)
+     _slow_log.warning(
+         "Event-loop slow-callback detector ENABLED (threshold=%.1fs). "
+         "Set PERF_DEBUG='' to disable.",
+         threshold_sec,
+     )
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
+   # Enable slow-callback detection (set PERF_DEBUG=1 env var to activate)
+   _enable_slow_callback_logging(threshold_sec=0.5)
    # Not needed if you setup a migration system like Alembic
    await create_db_and_tables()
    # Setup LangGraph checkpointer tables for conversation persistence
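The detector above leans on asyncio's own debug facilities rather than custom timing: once loop.set_debug(True) is on, any callback that runs longer than loop.slow_callback_duration is reported through the standard asyncio logger. A minimal standalone illustration, independent of FastAPI and SurfSense:

# Standalone illustration of asyncio's built-in slow-callback detection.
import asyncio
import logging
import time

logging.basicConfig(level=logging.WARNING)  # asyncio emits the warning via logging


def blocking_callback() -> None:
    time.sleep(0.3)  # synchronous sleep blocks the event loop


async def main() -> None:
    loop = asyncio.get_running_loop()
    loop.set_debug(True)
    loop.slow_callback_duration = 0.1  # warn on anything slower than 100 ms
    loop.call_soon(blocking_callback)
    await asyncio.sleep(0.5)  # give the callback time to run and be reported


if __name__ == "__main__":
    asyncio.run(main())

Running this should emit a warning along the lines of "Executing ... took ... seconds" for the blocking callback, which is the same signal the PERF_DEBUG-gated hook above surfaces in the server logs.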

@@ -5,6 +5,7 @@ from app.db import DocumentType

class ConnectorDocument(BaseModel):
    """Canonical data transfer object produced by connector adapters and consumed by the indexing pipeline."""

    title: str
    source_markdown: str
    unique_id: str

@@ -3,5 +3,7 @@ from app.config import config

def chunk_text(text: str, use_code_chunker: bool = False) -> list[str]:
    """Chunk a text string using the configured chunker and return the chunk texts."""
-   chunker = config.code_chunker_instance if use_code_chunker else config.chunker_instance
+   chunker = (
+       config.code_chunker_instance if use_code_chunker else config.chunker_instance
+   )
    return [c.text for c in chunker.chunk(text)]

@@ -2,7 +2,9 @@ from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.utils.document_converters import optimize_content_for_context_window


- async def summarize_document(source_markdown: str, llm, metadata: dict | None = None) -> str:
+ async def summarize_document(
+     source_markdown: str, llm, metadata: dict | None = None
+ ) -> str:
    """Generate a text summary of a document using an LLM, prefixed with metadata when provided."""
    model_name = getattr(llm, "model", "gpt-3.5-turbo")
    optimized_content = optimize_content_for_context_window(

@@ -12,7 +12,7 @@ from litellm.exceptions import (
    Timeout,
    UnprocessableEntityError,
)
- from sqlalchemy.exc import IntegrityError
+ from sqlalchemy.exc import IntegrityError as IntegrityError

# Tuples for use directly in except clauses.
RETRYABLE_LLM_ERRORS = (

@@ -53,10 +53,14 @@ class PipelineMessages:
    LLM_PERMISSION = "LLM request denied. Check your account permissions."
    LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
    LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
-   LLM_UNPROCESSABLE = "Document exceeds the LLM context window even after optimization."
+   LLM_UNPROCESSABLE = (
+       "Document exceeds the LLM context window even after optimization."
+   )
    LLM_RESPONSE = "LLM returned an invalid response."

-   EMBEDDING_FAILED = "Embedding failed. Check your embedding model configuration or service."
+   EMBEDDING_FAILED = (
+       "Embedding failed. Check your embedding model configuration or service."
+   )
    EMBEDDING_MODEL = "Embedding model files are missing or corrupted."
    EMBEDDING_MEMORY = "Not enough memory to embed this document."

@@ -24,7 +24,9 @@ class LogMessages:
    # index
    INDEX_STARTED = "Document indexing started."
    INDEX_SUCCESS = "Document indexed successfully."
-   LLM_RETRYABLE = "Retryable LLM error — document marked failed, will retry on next sync."
+   LLM_RETRYABLE = (
+       "Retryable LLM error — document marked failed, will retry on next sync."
+   )
    LLM_PERMANENT = "Permanent LLM error — document marked failed."
    EMBEDDING_FAILED = "Embedding error — document marked failed."
    CHUNKING_OVERFLOW = "Chunking overflow — document marked failed."

@@ -52,7 +54,9 @@ def _build_message(msg: str, ctx: PipelineLogContext, **extra) -> str:
    return msg


- def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra) -> None:
+ def _safe_log(
+     level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extra
+ ) -> None:
    # Logging must never raise — a broken log call inside an except block would
    # chain with the original exception and mask it entirely.
    try:

@@ -64,6 +68,7 @@ def _safe_log(level_fn, msg: str, ctx: PipelineLogContext, exc_info=None, **extr

# ── prepare_for_indexing ──────────────────────────────────────────────────────


def log_document_queued(ctx: PipelineLogContext) -> None:
    _safe_log(logger.info, LogMessages.DOCUMENT_QUEUED, ctx)


@@ -77,7 +82,9 @@ def log_document_requeued(ctx: PipelineLogContext) -> None:


def log_doc_skipped_unknown(ctx: PipelineLogContext, exc: Exception) -> None:
-   _safe_log(logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc)
+   _safe_log(
+       logger.warning, LogMessages.DOC_SKIPPED_UNKNOWN, ctx, exc_info=exc, error=exc
+   )


def log_race_condition(ctx: PipelineLogContext) -> None:

@@ -90,6 +97,7 @@ def log_batch_aborted(ctx: PipelineLogContext, exc: Exception) -> None:

# ── index ─────────────────────────────────────────────────────────────────────


def log_index_started(ctx: PipelineLogContext) -> None:
    _safe_log(logger.info, LogMessages.INDEX_STARTED, ctx)

@@ -10,6 +10,8 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- POST /threads/{thread_id}/messages - Append message
"""

+ import asyncio
+ import logging
from datetime import UTC, datetime

from fastapi import APIRouter, Depends, HTTPException, Request

@@ -52,10 +54,8 @@ from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.rbac import check_permission

- import asyncio
- import logging
-
_logger = logging.getLogger(__name__)
+ _background_tasks: set[asyncio.Task] = set()

router = APIRouter()


@@ -75,15 +75,25 @@ def _try_delete_sandbox(thread_id: int) -> None:
        try:
            await delete_sandbox(thread_id)
        except Exception:
-           _logger.warning("Background sandbox delete failed for thread %s", thread_id, exc_info=True)
+           _logger.warning(
+               "Background sandbox delete failed for thread %s",
+               thread_id,
+               exc_info=True,
+           )
        try:
            delete_local_sandbox_files(thread_id)
        except Exception:
-           _logger.warning("Local sandbox file cleanup failed for thread %s", thread_id, exc_info=True)
+           _logger.warning(
+               "Local sandbox file cleanup failed for thread %s",
+               thread_id,
+               exc_info=True,
+           )

    try:
        loop = asyncio.get_running_loop()
-       loop.create_task(_bg())
+       task = loop.create_task(_bg())
+       _background_tasks.add(task)
+       task.add_done_callback(_background_tasks.discard)
    except RuntimeError:
        pass
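The _background_tasks set plus add_done_callback(discard) bookkeeping, used here and again in the streaming task module below, follows the standard asyncio guidance that the event loop keeps only weak references to tasks, so a fire-and-forget task needs a strong reference until it completes. A minimal sketch of the pattern (the cleanup job is a placeholder, not SurfSense code):

# Minimal sketch of fire-and-forget task bookkeeping.
import asyncio

_background_tasks: set[asyncio.Task] = set()


async def _cleanup_job(item_id: int) -> None:
    await asyncio.sleep(0.1)  # placeholder for sandbox/file cleanup work
    print(f"cleaned up {item_id}")


def schedule_cleanup(item_id: int) -> None:
    """Schedule cleanup without awaiting it, keeping a strong reference to the task."""
    task = asyncio.get_running_loop().create_task(_cleanup_job(item_id))
    _background_tasks.add(task)  # prevent the task from being garbage-collected early
    task.add_done_callback(_background_tasks.discard)  # drop the reference when done


async def main() -> None:
    schedule_cleanup(7)
    await asyncio.sleep(0.2)  # let the background task finish before exiting


if __name__ == "__main__":
    asyncio.run(main())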

@@ -87,7 +87,7 @@ async def download_sandbox_file(
    # Fall back to live sandbox download
    try:
        sandbox = await get_or_create_sandbox(thread_id)
-       raw_sandbox = sandbox._sandbox # noqa: SLF001
+       raw_sandbox = sandbox._sandbox
        content: bytes = await asyncio.to_thread(raw_sandbox.fs.download_file, path)
    except Exception as exc:
        logger.warning("Sandbox file download failed for %s: %s", path, exc)

@@ -2735,7 +2735,10 @@ async def create_mcp_connector(
            f"for user {user.id} in search space {search_space_id}"
        )

-       # Convert to read schema
+       from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
+
+       invalidate_mcp_tools_cache(search_space_id)
+
        connector_read = SearchSourceConnectorRead.model_validate(db_connector)
        return MCPConnectorRead.from_connector(connector_read)

@@ -2910,6 +2913,10 @@ async def update_mcp_connector(

        logger.info(f"Updated MCP connector {connector_id}")

+       from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
+
+       invalidate_mcp_tools_cache(connector.search_space_id)
+
        connector_read = SearchSourceConnectorRead.model_validate(connector)
        return MCPConnectorRead.from_connector(connector_read)

@@ -2960,9 +2967,14 @@ async def delete_mcp_connector(
                "You don't have permission to delete this connector",
            )

+       search_space_id = connector.search_space_id
        await session.delete(connector)
        await session.commit()

+       from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
+
+       invalidate_mcp_tools_cache(search_space_id)
+
        logger.info(f"Deleted MCP connector {connector_id}")

    except HTTPException:
@ -13,14 +13,17 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
from sqlalchemy import func
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
|
||||||
from app.agents.new_chat.checkpointer import get_checkpointer
|
from app.agents.new_chat.checkpointer import get_checkpointer
|
||||||
|
|
@ -31,10 +34,17 @@ from app.agents.new_chat.llm_config import (
|
||||||
load_agent_config,
|
load_agent_config,
|
||||||
load_llm_config_from_yaml,
|
load_llm_config_from_yaml,
|
||||||
)
|
)
|
||||||
|
from app.agents.new_chat.sandbox import (
|
||||||
|
get_or_create_sandbox,
|
||||||
|
is_sandbox_enabled,
|
||||||
|
)
|
||||||
from app.db import (
|
from app.db import (
|
||||||
ChatVisibility,
|
ChatVisibility,
|
||||||
Document,
|
Document,
|
||||||
|
NewChatMessage,
|
||||||
|
NewChatThread,
|
||||||
Report,
|
Report,
|
||||||
|
SearchSourceConnectorType,
|
||||||
SurfsenseDocsDocument,
|
SurfsenseDocsDocument,
|
||||||
async_session_maker,
|
async_session_maker,
|
||||||
)
|
)
|
||||||
|
|
@ -47,6 +57,16 @@ from app.services.connector_service import ConnectorService
|
||||||
from app.services.new_streaming_service import VercelStreamingService
|
from app.services.new_streaming_service import VercelStreamingService
|
||||||
from app.utils.content_utils import bootstrap_history_from_db
|
from app.utils.content_utils import bootstrap_history_from_db
|
||||||
|
|
||||||
|
_perf_log = logging.getLogger("surfsense.perf")
|
||||||
|
_perf_log.setLevel(logging.DEBUG)
|
||||||
|
if not _perf_log.handlers:
|
||||||
|
_h = logging.StreamHandler()
|
||||||
|
_h.setFormatter(logging.Formatter("%(asctime)s [PERF] %(message)s"))
|
||||||
|
_perf_log.addHandler(_h)
|
||||||
|
_perf_log.propagate = False
|
||||||
|
|
||||||
|
_background_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
|
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -877,7 +897,9 @@ async def _stream_agent_events(
|
||||||
output_text = om.group(1) if om else ""
|
output_text = om.group(1) if om else ""
|
||||||
thread_id_str = config.get("configurable", {}).get("thread_id", "")
|
thread_id_str = config.get("configurable", {}).get("thread_id", "")
|
||||||
|
|
||||||
for sf_match in re.finditer(r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE):
|
for sf_match in re.finditer(
|
||||||
|
r"^SANDBOX_FILE:\s*(.+)$", output_text, re.MULTILINE
|
||||||
|
):
|
||||||
fpath = sf_match.group(1).strip()
|
fpath = sf_match.group(1).strip()
|
||||||
if fpath and fpath not in result.sandbox_files:
|
if fpath and fpath not in result.sandbox_files:
|
||||||
result.sandbox_files.append(fpath)
|
result.sandbox_files.append(fpath)
|
||||||
|
|
@ -963,7 +985,10 @@ def _try_persist_and_delete_sandbox(
|
||||||
sandbox_files: list[str],
|
sandbox_files: list[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
|
"""Fire-and-forget: persist sandbox files locally then delete the sandbox."""
|
||||||
from app.agents.new_chat.sandbox import is_sandbox_enabled, persist_and_delete_sandbox
|
from app.agents.new_chat.sandbox import (
|
||||||
|
is_sandbox_enabled,
|
||||||
|
persist_and_delete_sandbox,
|
||||||
|
)
|
||||||
|
|
||||||
if not is_sandbox_enabled():
|
if not is_sandbox_enabled():
|
||||||
return
|
return
|
||||||
|
|
@ -980,7 +1005,9 @@ def _try_persist_and_delete_sandbox(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
loop.create_task(_run())
|
task = loop.create_task(_run())
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -1022,6 +1049,7 @@ async def stream_new_chat(
|
||||||
"""
|
"""
|
||||||
streaming_service = VercelStreamingService()
|
streaming_service = VercelStreamingService()
|
||||||
stream_result = StreamResult()
|
stream_result = StreamResult()
|
||||||
|
_t_total = time.perf_counter()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Mark AI as responding to this user for live collaboration
|
# Mark AI as responding to this user for live collaboration
|
||||||
|
|
@ -1030,6 +1058,7 @@ async def stream_new_chat(
|
||||||
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
|
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
|
||||||
agent_config: AgentConfig | None = None
|
agent_config: AgentConfig | None = None
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
if llm_config_id >= 0:
|
if llm_config_id >= 0:
|
||||||
# Positive ID: Load from NewLLMConfig database table
|
# Positive ID: Load from NewLLMConfig database table
|
||||||
agent_config = await load_agent_config(
|
agent_config = await load_agent_config(
|
||||||
|
|
@ -1060,6 +1089,11 @@ async def stream_new_chat(
|
||||||
llm = create_chat_litellm_from_config(llm_config)
|
llm = create_chat_litellm_from_config(llm_config)
|
||||||
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
|
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
|
||||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] LLM config loaded in %.3fs (config_id=%s)",
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
llm_config_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not llm:
|
if not llm:
|
||||||
yield streaming_service.format_error("Failed to create LLM instance")
|
yield streaming_service.format_error("Failed to create LLM instance")
|
||||||
|
|
@ -1067,28 +1101,29 @@ async def stream_new_chat(
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create connector service
|
# Create connector service
|
||||||
|
_t0 = time.perf_counter()
|
||||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||||
|
|
||||||
# Get Firecrawl API key from webcrawler connector if configured
|
|
||||||
from app.db import SearchSourceConnectorType
|
|
||||||
|
|
||||||
firecrawl_api_key = None
|
firecrawl_api_key = None
|
||||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
||||||
)
|
)
|
||||||
if webcrawler_connector and webcrawler_connector.config:
|
if webcrawler_connector and webcrawler_connector.config:
|
||||||
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
|
||||||
|
_perf_log.info(
|
||||||
# Get the PostgreSQL checkpointer for persistent conversation memory
|
"[stream_new_chat] Connector service + firecrawl key in %.3fs",
|
||||||
checkpointer = await get_checkpointer()
|
time.perf_counter() - _t0,
|
||||||
|
|
||||||
# Optionally provision a sandboxed code execution environment
|
|
||||||
sandbox_backend = None
|
|
||||||
from app.agents.new_chat.sandbox import (
|
|
||||||
get_or_create_sandbox,
|
|
||||||
is_sandbox_enabled,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get the PostgreSQL checkpointer for persistent conversation memory
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
checkpointer = await get_checkpointer()
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] Checkpointer ready in %.3fs", time.perf_counter() - _t0
|
||||||
|
)
|
||||||
|
|
||||||
|
sandbox_backend = None
|
||||||
|
_t0 = time.perf_counter()
|
||||||
if is_sandbox_enabled():
|
if is_sandbox_enabled():
|
||||||
try:
|
try:
|
||||||
sandbox_backend = await get_or_create_sandbox(chat_id)
|
sandbox_backend = await get_or_create_sandbox(chat_id)
|
||||||
|
|
@ -1097,8 +1132,14 @@ async def stream_new_chat(
|
||||||
"Sandbox creation failed, continuing without execute tool: %s",
|
"Sandbox creation failed, continuing without execute tool: %s",
|
||||||
sandbox_err,
|
sandbox_err,
|
||||||
)
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] Sandbox provisioning in %.3fs (enabled=%s)",
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
sandbox_backend is not None,
|
||||||
|
)
|
||||||
|
|
||||||
visibility = thread_visibility or ChatVisibility.PRIVATE
|
visibility = thread_visibility or ChatVisibility.PRIVATE
|
||||||
|
_t0 = time.perf_counter()
|
||||||
agent = await create_surfsense_deep_agent(
|
agent = await create_surfsense_deep_agent(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
|
|
@ -1112,19 +1153,20 @@ async def stream_new_chat(
|
||||||
thread_visibility=visibility,
|
thread_visibility=visibility,
|
||||||
sandbox_backend=sandbox_backend,
|
sandbox_backend=sandbox_backend,
|
||||||
)
|
)
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||||
|
)
|
||||||
|
|
||||||
# Build input with message history
|
# Build input with message history
|
||||||
langchain_messages = []
|
langchain_messages = []
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
|
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
|
||||||
if needs_history_bootstrap:
|
if needs_history_bootstrap:
|
||||||
langchain_messages = await bootstrap_history_from_db(
|
langchain_messages = await bootstrap_history_from_db(
|
||||||
session, chat_id, thread_visibility=visibility
|
session, chat_id, thread_visibility=visibility
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clear the flag so we don't bootstrap again on next message
|
|
||||||
from app.db import NewChatThread
|
|
||||||
|
|
||||||
thread_result = await session.execute(
|
thread_result = await session.execute(
|
||||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||||
)
|
)
|
||||||
|
|
@ -1136,11 +1178,9 @@ async def stream_new_chat(
|
||||||
# Fetch mentioned documents if any (with chunks for proper citations)
|
# Fetch mentioned documents if any (with chunks for proper citations)
|
||||||
mentioned_documents: list[Document] = []
|
mentioned_documents: list[Document] = []
|
||||||
if mentioned_document_ids:
|
if mentioned_document_ids:
|
||||||
from sqlalchemy.orm import selectinload as doc_selectinload
|
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(Document)
|
select(Document)
|
||||||
.options(doc_selectinload(Document.chunks))
|
.options(selectinload(Document.chunks))
|
||||||
.filter(
|
.filter(
|
||||||
Document.id.in_(mentioned_document_ids),
|
Document.id.in_(mentioned_document_ids),
|
||||||
Document.search_space_id == search_space_id,
|
Document.search_space_id == search_space_id,
|
||||||
|
|
@ -1151,8 +1191,6 @@ async def stream_new_chat(
|
||||||
# Fetch mentioned SurfSense docs if any
|
# Fetch mentioned SurfSense docs if any
|
||||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
|
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
|
||||||
if mentioned_surfsense_doc_ids:
|
if mentioned_surfsense_doc_ids:
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(SurfsenseDocsDocument)
|
select(SurfsenseDocsDocument)
|
||||||
.options(selectinload(SurfsenseDocsDocument.chunks))
|
.options(selectinload(SurfsenseDocsDocument.chunks))
|
||||||
|
|
@ -1236,6 +1274,11 @@ async def stream_new_chat(
|
||||||
"search_space_id": search_space_id,
|
"search_space_id": search_space_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] History bootstrap + doc/report queries in %.3fs",
|
||||||
|
time.perf_counter() - _t0,
|
||||||
|
)
|
||||||
|
|
||||||
# All pre-streaming DB reads are done. Commit to release the
|
# All pre-streaming DB reads are done. Commit to release the
|
||||||
# transaction and its ACCESS SHARE locks so we don't block DDL
|
# transaction and its ACCESS SHARE locks so we don't block DDL
|
||||||
# (e.g. migrations) for the entire duration of LLM streaming.
|
# (e.g. migrations) for the entire duration of LLM streaming.
|
||||||
|
|
@ -1243,6 +1286,12 @@ async def stream_new_chat(
|
||||||
# short-lived transactions (or use isolated sessions).
|
# short-lived transactions (or use isolated sessions).
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||||
|
time.perf_counter() - _t_total,
|
||||||
|
chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Configure LangGraph with thread_id for memory
|
# Configure LangGraph with thread_id for memory
|
||||||
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
# If checkpoint_id is provided, fork from that checkpoint (for edit/reload)
|
||||||
configurable = {"thread_id": str(chat_id)}
|
configurable = {"thread_id": str(chat_id)}
|
||||||
|
|
@ -1304,6 +1353,8 @@ async def stream_new_chat(
|
||||||
items=initial_items,
|
items=initial_items,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_t_stream_start = time.perf_counter()
|
||||||
|
_first_event_logged = False
|
||||||
async for sse in _stream_agent_events(
|
async for sse in _stream_agent_events(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
config=config,
|
config=config,
|
||||||
|
|
@ -1315,8 +1366,23 @@ async def stream_new_chat(
|
||||||
initial_step_title=initial_title,
|
initial_step_title=initial_title,
|
||||||
initial_step_items=initial_items,
|
initial_step_items=initial_items,
|
||||||
):
|
):
|
||||||
|
if not _first_event_logged:
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] First agent event in %.3fs (time since stream start), "
|
||||||
|
"%.3fs (total since request start) (chat_id=%s)",
|
||||||
|
time.perf_counter() - _t_stream_start,
|
||||||
|
time.perf_counter() - _t_total,
|
||||||
|
chat_id,
|
||||||
|
)
|
||||||
|
_first_event_logged = True
|
||||||
yield sse
|
yield sse
|
||||||
|
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
|
||||||
|
time.perf_counter() - _t_stream_start,
|
||||||
|
chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
if stream_result.is_interrupted:
|
if stream_result.is_interrupted:
|
||||||
yield streaming_service.format_finish_step()
|
yield streaming_service.format_finish_step()
|
||||||
yield streaming_service.format_finish()
|
yield streaming_service.format_finish()
|
||||||
|
|
@ -1325,12 +1391,6 @@ async def stream_new_chat(
|
||||||
|
|
||||||
accumulated_text = stream_result.accumulated_text
|
accumulated_text = stream_result.accumulated_text
|
||||||
|
|
||||||
# Generate LLM title for new chats after first response
|
|
||||||
# Check if this is the first assistant response by counting existing assistant messages
|
|
||||||
from sqlalchemy import func
|
|
||||||
|
|
||||||
from app.db import NewChatMessage, NewChatThread
|
|
||||||
|
|
||||||
assistant_count_result = await session.execute(
|
assistant_count_result = await session.execute(
|
||||||
select(func.count(NewChatMessage.id)).filter(
|
select(func.count(NewChatMessage.id)).filter(
|
||||||
NewChatMessage.thread_id == chat_id,
|
NewChatMessage.thread_id == chat_id,
|
||||||
|
|
@ -1431,12 +1491,14 @@ async def stream_resume_chat(
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
streaming_service = VercelStreamingService()
|
streaming_service = VercelStreamingService()
|
||||||
stream_result = StreamResult()
|
stream_result = StreamResult()
|
||||||
|
_t_total = time.perf_counter()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if user_id:
|
if user_id:
|
||||||
await set_ai_responding(session, chat_id, UUID(user_id))
|
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||||
|
|
||||||
agent_config: AgentConfig | None = None
|
agent_config: AgentConfig | None = None
|
||||||
|
_t0 = time.perf_counter()
|
||||||
if llm_config_id >= 0:
|
if llm_config_id >= 0:
|
||||||
agent_config = await load_agent_config(
|
agent_config = await load_agent_config(
|
||||||
session=session,
|
session=session,
|
||||||
|
|
@ -1460,31 +1522,37 @@ async def stream_resume_chat(
|
||||||
return
|
return
|
||||||
llm = create_chat_litellm_from_config(llm_config)
|
llm = create_chat_litellm_from_config(llm_config)
|
||||||
agent_config = AgentConfig.from_yaml_config(llm_config)
|
agent_config = AgentConfig.from_yaml_config(llm_config)
|
||||||
|
_perf_log.info(
|
||||||
|
"[stream_resume] LLM config loaded in %.3fs", time.perf_counter() - _t0
|
||||||
|
)
|
||||||
|
|
||||||
if not llm:
|
if not llm:
|
||||||
yield streaming_service.format_error("Failed to create LLM instance")
|
yield streaming_service.format_error("Failed to create LLM instance")
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
_t0 = time.perf_counter()
|
||||||
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
connector_service = ConnectorService(session, search_space_id=search_space_id)
|
||||||
|
|
||||||
from app.db import SearchSourceConnectorType
|
|
||||||
|
|
||||||
firecrawl_api_key = None
|
firecrawl_api_key = None
|
||||||
webcrawler_connector = await connector_service.get_connector_by_type(
|
webcrawler_connector = await connector_service.get_connector_by_type(
|
||||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
|
         SearchSourceConnectorType.WEBCRAWLER_CONNECTOR, search_space_id
     )
     if webcrawler_connector and webcrawler_connector.config:
         firecrawl_api_key = webcrawler_connector.config.get("FIRECRAWL_API_KEY")
-    checkpointer = await get_checkpointer()
-    sandbox_backend = None
-    from app.agents.new_chat.sandbox import (
-        get_or_create_sandbox,
-        is_sandbox_enabled,
+    _perf_log.info(
+        "[stream_resume] Connector service + firecrawl key in %.3fs",
+        time.perf_counter() - _t0,
     )

+    _t0 = time.perf_counter()
+    checkpointer = await get_checkpointer()
+    _perf_log.info(
+        "[stream_resume] Checkpointer ready in %.3fs", time.perf_counter() - _t0
+    )
+
+    sandbox_backend = None
+    _t0 = time.perf_counter()
     if is_sandbox_enabled():
         try:
             sandbox_backend = await get_or_create_sandbox(chat_id)
@@ -1493,9 +1561,15 @@ async def stream_resume_chat(
                 "Sandbox creation failed, continuing without execute tool: %s",
                 sandbox_err,
             )
+    _perf_log.info(
+        "[stream_resume] Sandbox provisioning in %.3fs (enabled=%s)",
+        time.perf_counter() - _t0,
+        sandbox_backend is not None,
+    )

     visibility = thread_visibility or ChatVisibility.PRIVATE

+    _t0 = time.perf_counter()
     agent = await create_surfsense_deep_agent(
         llm=llm,
         search_space_id=search_space_id,
@@ -1509,10 +1583,19 @@ async def stream_resume_chat(
         thread_visibility=visibility,
         sandbox_backend=sandbox_backend,
     )
+    _perf_log.info(
+        "[stream_resume] Agent created in %.3fs", time.perf_counter() - _t0
+    )

     # Release the transaction before streaming (same rationale as stream_new_chat).
     await session.commit()

+    _perf_log.info(
+        "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
+        time.perf_counter() - _t_total,
+        chat_id,
+    )
+
     from langgraph.types import Command

     config = {
@@ -1523,6 +1606,8 @@ async def stream_resume_chat(
     yield streaming_service.format_message_start()
     yield streaming_service.format_start_step()

+    _t_stream_start = time.perf_counter()
+    _first_event_logged = False
     async for sse in _stream_agent_events(
         agent=agent,
         config=config,
@@ -1531,7 +1616,20 @@ async def stream_resume_chat(
         result=stream_result,
         step_prefix="thinking-resume",
     ):
+        if not _first_event_logged:
+            _perf_log.info(
+                "[stream_resume] First agent event in %.3fs (stream), %.3fs (total) (chat_id=%s)",
+                time.perf_counter() - _t_stream_start,
+                time.perf_counter() - _t_total,
+                chat_id,
+            )
+            _first_event_logged = True
         yield sse
+    _perf_log.info(
+        "[stream_resume] Agent stream completed in %.3fs (chat_id=%s)",
+        time.perf_counter() - _t_stream_start,
+        chat_id,
+    )
     if stream_result.is_interrupted:
         yield streaming_service.format_finish_step()
         yield streaming_service.format_finish()
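
The hunks above repeat one timing idiom: snapshot time.perf_counter() before a step, then report the elapsed seconds through the dedicated surfsense.perf logger. A minimal, self-contained sketch of that idiom follows for reference; the helper name and step label are illustrative only and are not part of the commit.

import asyncio
import logging
import time

_perf_log = logging.getLogger("surfsense.perf")


async def timed_step(label: str, coro):
    # Illustrative helper: await an arbitrary coroutine and log how long it took,
    # mirroring the perf_counter()-based pattern applied throughout the diff above.
    t0 = time.perf_counter()
    result = await coro
    _perf_log.info("[%s] completed in %.3fs", label, time.perf_counter() - t0)
    return result


async def _demo():
    # Hypothetical usage: time a simulated I/O-bound step.
    await timed_step("demo-step", asyncio.sleep(0.1))


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(_demo())
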
@@ -33,4 +33,5 @@ def make_connector_document():
         }
         defaults.update(overrides)
         return ConnectorDocument(**defaults)

     return _make
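
The fixture touched in this hunk uses the common pytest factory pattern: the fixture returns an inner builder so each test overrides only the fields it cares about. A small standalone sketch of that pattern, with placeholder names rather than the project's real defaults, is shown below.

import pytest


@pytest.fixture
def make_record():
    # Illustrative factory fixture: merge per-test overrides into a default payload.
    def _make(**overrides):
        defaults = {"title": "Test Document", "unique_id": "id-1"}
        defaults.update(overrides)
        return defaults

    return _make


def test_override_only_what_matters(make_record):
    record = make_record(title="Changed")
    assert record["title"] == "Changed"
    assert record["unique_id"] == "id-1"
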
@@ -1,4 +1,3 @@
-
 import os
 import uuid
 from unittest.mock import AsyncMock, MagicMock
@@ -9,14 +8,21 @@ from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
 from sqlalchemy.pool import NullPool

-from app.db import Base, SearchSpace, SearchSourceConnector, SearchSourceConnectorType
-from app.db import User
-from app.db import DocumentType
+from app.db import (
+    Base,
+    DocumentType,
+    SearchSourceConnector,
+    SearchSourceConnectorType,
+    SearchSpace,
+    User,
+)
 from app.indexing_pipeline.connector_document import ConnectorDocument

 _EMBEDDING_DIM = 1024  # must match the Vector() dimension used in DB column creation

-_DEFAULT_TEST_DB = "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
+_DEFAULT_TEST_DB = (
+    "postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense_test"
+)
 TEST_DATABASE_URL = os.environ.get("TEST_DATABASE_URL", _DEFAULT_TEST_DB)


@@ -80,7 +86,9 @@ async def db_user(db_session: AsyncSession) -> User:


 @pytest_asyncio.fixture
-async def db_connector(db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace") -> SearchSourceConnector:
+async def db_connector(
+    db_session: AsyncSession, db_user: User, db_search_space: "SearchSpace"
+) -> SearchSourceConnector:
     connector = SearchSourceConnector(
         name="Test Connector",
         connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
@@ -147,6 +155,7 @@ def patched_chunk_text(monkeypatch) -> MagicMock:
 @pytest.fixture
 def make_connector_document(db_connector, db_user):
     """Integration-scoped override: uses real DB connector and user IDs."""
+
     def _make(**overrides):
         defaults = {
             "title": "Test Document",
@@ -159,6 +168,5 @@ def make_connector_document(db_connector, db_user):
         }
         defaults.update(overrides)
         return ConnectorDocument(**defaults)

     return _make
-
@@ -7,7 +7,9 @@ from app.indexing_pipeline.adapters.file_upload_adapter import index_uploaded_fi
 pytestmark = pytest.mark.integration


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
     """Document status is READY after successful indexing."""
     await index_uploaded_file(
@@ -28,7 +30,9 @@ async def test_sets_status_ready(db_session, db_search_space, db_user, mocker):
     assert DocumentStatus.is_state(document.status, DocumentStatus.READY)


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
     """Document content is set to the LLM-generated summary."""
     await index_uploaded_file(
@@ -49,7 +53,9 @@ async def test_content_is_summary(db_session, db_search_space, db_user, mocker):
     assert document.content == "Mocked summary."


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker):
     """Chunks derived from the source markdown are persisted in the DB."""
     await index_uploaded_file(
@@ -76,7 +82,9 @@ async def test_chunks_written_to_db(db_session, db_search_space, db_user, mocker
     assert chunks[0].content == "Test chunk content."


-@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
+)
 async def test_raises_on_indexing_failure(db_session, db_search_space, db_user, mocker):
     """RuntimeError is raised when the indexing step fails so the caller can fire a failure notification."""
     with pytest.raises(RuntimeError):
@@ -7,9 +7,14 @@ from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineServ
 pytestmark = pytest.mark.integration


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_sets_status_ready(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """Document status is READY after successful indexing."""
     connector_doc = make_connector_document(search_space_id=db_search_space.id)
@@ -21,15 +26,22 @@ async def test_sets_status_ready(

     await service.index(document, connector_doc, llm=mocker.Mock())

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_content_is_summary_when_should_summarize_true(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """Document content is set to the LLM-generated summary when should_summarize=True."""
     connector_doc = make_connector_document(search_space_id=db_search_space.id)
@@ -41,15 +53,21 @@ async def test_content_is_summary_when_should_summarize_true(

     await service.index(document, connector_doc, llm=mocker.Mock())

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded.content == "Mocked summary."


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_content_is_source_markdown_when_should_summarize_false(
-    db_session, db_search_space, make_connector_document,
+    db_session,
+    db_search_space,
+    make_connector_document,
 ):
     """Document content is set to source_markdown verbatim when should_summarize=False."""
     connector_doc = make_connector_document(
@@ -65,15 +83,22 @@ async def test_content_is_source_markdown_when_should_summarize_false(

     await service.index(document, connector_doc, llm=None)

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded.content == "## Raw content"


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_chunks_written_to_db(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """Chunks derived from source_markdown are persisted in the DB."""
     connector_doc = make_connector_document(search_space_id=db_search_space.id)
@@ -94,9 +119,14 @@ async def test_chunks_written_to_db(
     assert chunks[0].content == "Test chunk content."


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_embedding_written_to_db(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """Document embedding vector is persisted in the DB after indexing."""
     connector_doc = make_connector_document(search_space_id=db_search_space.id)
@@ -108,16 +138,23 @@ async def test_embedding_written_to_db(

     await service.index(document, connector_doc, llm=mocker.Mock())

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded.embedding is not None
     assert len(reloaded.embedding) == 1024


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_updated_at_advances_after_indexing(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """updated_at timestamp is later after indexing than it was at prepare time."""
     connector_doc = make_connector_document(search_space_id=db_search_space.id)
@@ -127,20 +164,28 @@ async def test_updated_at_advances_after_indexing(
     document = prepared[0]
     document_id = document.id

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     updated_at_pending = result.scalars().first().updated_at

     await service.index(document, connector_doc, llm=mocker.Mock())

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     updated_at_ready = result.scalars().first().updated_at

     assert updated_at_ready > updated_at_pending


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_no_llm_falls_back_to_source_markdown(
-    db_session, db_search_space, make_connector_document,
+    db_session,
+    db_search_space,
+    make_connector_document,
 ):
     """When llm=None and no fallback_summary, content falls back to source_markdown."""
     connector_doc = make_connector_document(
@@ -156,16 +201,22 @@ async def test_no_llm_falls_back_to_source_markdown(

     await service.index(document, connector_doc, llm=None)

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
     assert reloaded.content == "## Fallback content"


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_fallback_summary_used_when_llm_unavailable(
-    db_session, db_search_space, make_connector_document,
+    db_session,
+    db_search_space,
+    make_connector_document,
 ):
     """fallback_summary is used as content when llm=None and should_summarize=True."""
     connector_doc = make_connector_document(
@@ -181,16 +232,23 @@ async def test_fallback_summary_used_when_llm_unavailable(

     await service.index(prepared[0], connector_doc, llm=None)

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert DocumentStatus.is_state(reloaded.status, DocumentStatus.READY)
     assert reloaded.content == "Short pre-built summary."


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_reindex_replaces_old_chunks(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """Re-indexing a document replaces its old chunks rather than appending."""
     connector_doc = make_connector_document(
@@ -220,9 +278,14 @@ async def test_reindex_replaces_old_chunks(
     assert len(chunks) == 1


-@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
+)
 async def test_llm_error_sets_status_failed(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """Document status is FAILED when the LLM raises during indexing."""
     connector_doc = make_connector_document(search_space_id=db_search_space.id)
@@ -234,15 +297,22 @@ async def test_llm_error_sets_status_failed(

     await service.index(document, connector_doc, llm=mocker.Mock())

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert DocumentStatus.is_state(reloaded.status, DocumentStatus.FAILED)


-@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
+)
 async def test_llm_error_leaves_no_partial_data(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """A failed indexing attempt leaves no partial embedding or chunks in the DB."""
     connector_doc = make_connector_document(search_space_id=db_search_space.id)
@@ -254,7 +324,9 @@ async def test_llm_error_leaves_no_partial_data(

     await service.index(document, connector_doc, llm=mocker.Mock())

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded.embedding is None
@@ -2,7 +2,9 @@ import pytest
 from sqlalchemy import select

 from app.db import Document, DocumentStatus
-from app.indexing_pipeline.document_hashing import compute_content_hash as real_compute_content_hash
+from app.indexing_pipeline.document_hashing import (
+    compute_content_hash as real_compute_content_hash,
+)
 from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService

 pytestmark = pytest.mark.integration
@@ -20,7 +22,9 @@ async def test_new_document_is_persisted_with_pending_status(
     assert len(results) == 1
     document_id = results[0].id

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded is not None
@@ -28,9 +32,14 @@ async def test_new_document_is_persisted_with_pending_status(
     assert reloaded.source_markdown == doc.source_markdown


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_unchanged_ready_document_is_skipped(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """A READY document with unchanged content is not returned for re-indexing."""
     doc = make_connector_document(search_space_id=db_search_space.id)
@@ -46,24 +55,35 @@ async def test_unchanged_ready_document_is_skipped(
     assert results == []


-@pytest.mark.usefixtures("patched_summarize", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize", "patched_embed_text", "patched_chunk_text"
+)
 async def test_title_only_change_updates_title_in_db(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """A title-only change updates the DB title without re-queuing the document."""
-    original = make_connector_document(search_space_id=db_search_space.id, title="Original Title")
+    original = make_connector_document(
+        search_space_id=db_search_space.id, title="Original Title"
+    )
     service = IndexingPipelineService(session=db_session)

     prepared = await service.prepare_for_indexing([original])
     document_id = prepared[0].id
     await service.index(prepared[0], original, llm=mocker.Mock())

-    renamed = make_connector_document(search_space_id=db_search_space.id, title="Updated Title")
+    renamed = make_connector_document(
+        search_space_id=db_search_space.id, title="Updated Title"
+    )
     results = await service.prepare_for_indexing([renamed])

     assert results == []

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded.title == "Updated Title"
@@ -73,19 +93,25 @@ async def test_changed_content_is_returned_for_reprocessing(
     db_session, db_search_space, make_connector_document
 ):
     """A document with changed content is returned for re-indexing with updated markdown."""
-    original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")
+    original = make_connector_document(
+        search_space_id=db_search_space.id, source_markdown="## v1"
+    )
     service = IndexingPipelineService(session=db_session)

     first = await service.prepare_for_indexing([original])
     original_id = first[0].id

-    updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")
+    updated = make_connector_document(
+        search_space_id=db_search_space.id, source_markdown="## v2"
+    )
     results = await service.prepare_for_indexing([updated])

     assert len(results) == 1
     assert results[0].id == original_id

-    result = await db_session.execute(select(Document).filter(Document.id == original_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == original_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded.source_markdown == "## v2"
@@ -97,9 +123,24 @@ async def test_all_documents_in_batch_are_persisted(
 ):
     """All documents in a batch are persisted and returned."""
     docs = [
-        make_connector_document(search_space_id=db_search_space.id, unique_id="id-1", title="Doc 1", source_markdown="## Content 1"),
-        make_connector_document(search_space_id=db_search_space.id, unique_id="id-2", title="Doc 2", source_markdown="## Content 2"),
-        make_connector_document(search_space_id=db_search_space.id, unique_id="id-3", title="Doc 3", source_markdown="## Content 3"),
+        make_connector_document(
+            search_space_id=db_search_space.id,
+            unique_id="id-1",
+            title="Doc 1",
+            source_markdown="## Content 1",
+        ),
+        make_connector_document(
+            search_space_id=db_search_space.id,
+            unique_id="id-2",
+            title="Doc 2",
+            source_markdown="## Content 2",
+        ),
+        make_connector_document(
+            search_space_id=db_search_space.id,
+            unique_id="id-3",
+            title="Doc 3",
+            source_markdown="## Content 3",
+        ),
     ]
     service = IndexingPipelineService(session=db_session)

@@ -107,7 +148,9 @@ async def test_all_documents_in_batch_are_persisted(

     assert len(results) == 3

-    result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
+    result = await db_session.execute(
+        select(Document).filter(Document.search_space_id == db_search_space.id)
+    )
     rows = result.scalars().all()

     assert len(rows) == 3
@@ -124,7 +167,9 @@ async def test_duplicate_in_batch_is_persisted_once(

     assert len(results) == 1

-    result = await db_session.execute(select(Document).filter(Document.search_space_id == db_search_space.id))
+    result = await db_session.execute(
+        select(Document).filter(Document.search_space_id == db_search_space.id)
+    )
     rows = result.scalars().all()

     assert len(rows) == 1
@@ -143,7 +188,9 @@ async def test_created_by_id_is_persisted(
     results = await service.prepare_for_indexing([doc])
     document_id = results[0].id

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert str(reloaded.created_by_id) == str(db_user.id)
@@ -170,7 +217,9 @@ async def test_metadata_is_updated_when_content_changes(
     )
     await service.prepare_for_indexing([updated])

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded.document_metadata == {"status": "done"}
@@ -180,19 +229,27 @@ async def test_updated_at_advances_when_title_only_changes(
     db_session, db_search_space, make_connector_document
 ):
     """updated_at advances even when only the title changes."""
-    original = make_connector_document(search_space_id=db_search_space.id, title="Old Title")
+    original = make_connector_document(
+        search_space_id=db_search_space.id, title="Old Title"
+    )
     service = IndexingPipelineService(session=db_session)

     first = await service.prepare_for_indexing([original])
     document_id = first[0].id

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     updated_at_v1 = result.scalars().first().updated_at

-    renamed = make_connector_document(search_space_id=db_search_space.id, title="New Title")
+    renamed = make_connector_document(
+        search_space_id=db_search_space.id, title="New Title"
+    )
     await service.prepare_for_indexing([renamed])

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     updated_at_v2 = result.scalars().first().updated_at

     assert updated_at_v2 > updated_at_v1
@@ -202,19 +259,27 @@ async def test_updated_at_advances_when_content_changes(
     db_session, db_search_space, make_connector_document
 ):
     """updated_at advances when document content changes."""
-    original = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v1")
+    original = make_connector_document(
+        search_space_id=db_search_space.id, source_markdown="## v1"
+    )
     service = IndexingPipelineService(session=db_session)

     first = await service.prepare_for_indexing([original])
     document_id = first[0].id

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     updated_at_v1 = result.scalars().first().updated_at

-    updated = make_connector_document(search_space_id=db_search_space.id, source_markdown="## v2")
+    updated = make_connector_document(
+        search_space_id=db_search_space.id, source_markdown="## v2"
+    )
     await service.prepare_for_indexing([updated])

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
     updated_at_v2 = result.scalars().first().updated_at

     assert updated_at_v2 > updated_at_v1
@@ -273,9 +338,14 @@ async def test_same_content_from_different_source_is_skipped(
     assert len(result.scalars().all()) == 1


-@pytest.mark.usefixtures("patched_summarize_raises", "patched_embed_text", "patched_chunk_text")
+@pytest.mark.usefixtures(
+    "patched_summarize_raises", "patched_embed_text", "patched_chunk_text"
+)
 async def test_failed_document_with_unchanged_content_is_requeued(
-    db_session, db_search_space, make_connector_document, mocker,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    mocker,
 ):
     """A FAILED document with unchanged content is re-queued as PENDING on the next run."""
     doc = make_connector_document(search_space_id=db_search_space.id)
@@ -286,8 +356,12 @@ async def test_failed_document_with_unchanged_content_is_requeued(
     document_id = prepared[0].id
     await service.index(prepared[0], doc, llm=mocker.Mock())

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
-    assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.FAILED)
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
+    assert DocumentStatus.is_state(
+        result.scalars().first().status, DocumentStatus.FAILED
+    )

     # Next run: same content, pipeline must re-queue the failed document
     results = await service.prepare_for_indexing([doc])
@@ -295,8 +369,12 @@ async def test_failed_document_with_unchanged_content_is_requeued(
     assert len(results) == 1
     assert results[0].id == document_id

-    result = await db_session.execute(select(Document).filter(Document.id == document_id))
-    assert DocumentStatus.is_state(result.scalars().first().status, DocumentStatus.PENDING)
+    result = await db_session.execute(
+        select(Document).filter(Document.id == document_id)
+    )
+    assert DocumentStatus.is_state(
+        result.scalars().first().status, DocumentStatus.PENDING
+    )


 async def test_title_and_content_change_updates_both_and_returns_document(
@@ -323,16 +401,20 @@ async def test_title_and_content_change_updates_both_and_returns_document(
     assert len(results) == 1
     assert results[0].id == original_id

-    result = await db_session.execute(select(Document).filter(Document.id == original_id))
+    result = await db_session.execute(
+        select(Document).filter(Document.id == original_id)
+    )
     reloaded = result.scalars().first()

     assert reloaded.title == "Updated Title"
     assert reloaded.source_markdown == "## v2"


 async def test_one_bad_document_in_batch_does_not_prevent_others_from_being_persisted(
-    db_session, db_search_space, make_connector_document, monkeypatch,
+    db_session,
+    db_search_space,
+    make_connector_document,
+    monkeypatch,
 ):
     """
     A per-document error during prepare_for_indexing must be isolated.
@@ -1,6 +1,7 @@
-import pytest
 from unittest.mock import AsyncMock, MagicMock

+import pytest
+

 @pytest.fixture
 def patched_summarizer_chain(monkeypatch):
@@ -21,7 +22,9 @@ def patched_summarizer_chain(monkeypatch):
 def patched_chunker_instance(monkeypatch):
     mock = MagicMock()
     mock.chunk.return_value = [MagicMock(text="prose chunk")]
-    monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.chunker_instance", mock)
+    monkeypatch.setattr(
+        "app.indexing_pipeline.document_chunker.config.chunker_instance", mock
+    )
     return mock


@@ -29,5 +32,7 @@ def patched_chunker_instance(monkeypatch):
 def patched_code_chunker_instance(monkeypatch):
     mock = MagicMock()
     mock.chunk.return_value = [MagicMock(text="code chunk")]
-    monkeypatch.setattr("app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock)
+    monkeypatch.setattr(
+        "app.indexing_pipeline.document_chunker.config.code_chunker_instance", mock
+    )
     return mock
@@ -1,7 +1,10 @@
 import pytest

 from app.db import DocumentType
-from app.indexing_pipeline.document_hashing import compute_content_hash, compute_unique_identifier_hash
+from app.indexing_pipeline.document_hashing import (
+    compute_content_hash,
+    compute_unique_identifier_hash,
+)

 pytestmark = pytest.mark.unit

@@ -10,21 +13,31 @@ def test_different_unique_id_produces_different_hash(make_connector_document):
     """Two documents with different unique_ids produce different identifier hashes."""
     doc_a = make_connector_document(unique_id="id-001")
     doc_b = make_connector_document(unique_id="id-002")
-    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)
+    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
+        doc_b
+    )


-def test_different_search_space_produces_different_identifier_hash(make_connector_document):
+def test_different_search_space_produces_different_identifier_hash(
+    make_connector_document,
+):
     """Same document in different search spaces produces different identifier hashes."""
     doc_a = make_connector_document(search_space_id=1)
     doc_b = make_connector_document(search_space_id=2)
-    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)
+    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
+        doc_b
+    )


-def test_different_document_type_produces_different_identifier_hash(make_connector_document):
+def test_different_document_type_produces_different_identifier_hash(
+    make_connector_document,
+):
     """Same unique_id with different document types produces different identifier hashes."""
     doc_a = make_connector_document(document_type=DocumentType.CLICKUP_CONNECTOR)
     doc_b = make_connector_document(document_type=DocumentType.NOTION_CONNECTOR)
-    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(doc_b)
+    assert compute_unique_identifier_hash(doc_a) != compute_unique_identifier_hash(
+        doc_b
+    )


 def test_same_content_same_space_produces_same_content_hash(make_connector_document):
@@ -34,7 +47,9 @@ def test_same_content_same_space_produces_same_content_hash(make_connector_docum
     assert compute_content_hash(doc_a) == compute_content_hash(doc_b)


-def test_same_content_different_space_produces_different_content_hash(make_connector_document):
+def test_same_content_different_space_produces_different_content_hash(
+    make_connector_document,
+):
     """Identical content in different search spaces produces different content hashes."""
     doc_a = make_connector_document(source_markdown="Hello world", search_space_id=1)
     doc_b = make_connector_document(source_markdown="Hello world", search_space_id=2)
@@ -1,6 +1,7 @@
-import pytest
 from unittest.mock import MagicMock

+import pytest
+
 from app.indexing_pipeline.document_summarizer import summarize_document

 pytestmark = pytest.mark.unit
@@ -38,5 +39,3 @@ async def test_with_metadata_omits_empty_fields_from_output():

     assert "Alice" in result
     assert "description" not in result.lower()
-
-