mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
Merge upstream/dev into feat/vision-autocomplete
This commit is contained in:
commit
d7315e7f27
142 changed files with 9440 additions and 3390 deletions
|
|
@ -42,9 +42,7 @@ def upgrade() -> None:
|
|||
if not exists:
|
||||
table_list = ", ".join(TABLES)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE {table_list}"
|
||||
)
|
||||
sa.text(f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE {table_list}")
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,123 @@
|
|||
"""optimize zero_publication with column lists
|
||||
|
||||
Recreates the zero_publication using column lists for the documents
|
||||
table so that large text columns (content, source_markdown,
|
||||
blocknote_document, etc.) are excluded from WAL replication.
|
||||
This prevents RangeError: Invalid string length in zero-cache's
|
||||
change-streamer when documents have very large content.
|
||||
|
||||
Also resets REPLICA IDENTITY to DEFAULT on tables that had it set
|
||||
to FULL for the old Electric SQL setup (migration 66/75/76).
|
||||
With DEFAULT (primary-key) identity, column-list publications
|
||||
only need to include the PK — not every column.
|
||||
|
||||
IMPORTANT — before AND after running this migration:
|
||||
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||
2. Run: alembic upgrade head
|
||||
3. Delete / reset the zero-cache data volume
|
||||
4. Restart zero-cache (it will do a fresh initial sync)
|
||||
|
||||
Revision ID: 117
|
||||
Revises: 116
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "117"
|
||||
down_revision: str | None = "116"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
PUBLICATION_NAME = "zero_publication"
|
||||
|
||||
TABLES_WITH_FULL_IDENTITY = [
|
||||
"documents",
|
||||
"notifications",
|
||||
"search_source_connectors",
|
||||
"new_chat_messages",
|
||||
"chat_comments",
|
||||
"chat_session_state",
|
||||
]
|
||||
|
||||
DOCUMENT_COLS = [
|
||||
"id",
|
||||
"title",
|
||||
"document_type",
|
||||
"search_space_id",
|
||||
"folder_id",
|
||||
"created_by_id",
|
||||
"status",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
]
|
||||
|
||||
PUBLICATION_DDL_FULL = f"""\
|
||||
CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE
|
||||
notifications, documents, folders,
|
||||
search_source_connectors, new_chat_messages,
|
||||
chat_comments, chat_session_state
|
||||
"""
|
||||
|
||||
|
||||
def _terminate_blocked_pids(conn, table: str) -> None:
|
||||
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT pg_terminate_backend(l.pid) "
|
||||
"FROM pg_locks l "
|
||||
"JOIN pg_class c ON c.oid = l.relation "
|
||||
"WHERE c.relname = :tbl "
|
||||
" AND l.pid != pg_backend_pid()"
|
||||
),
|
||||
{"tbl": table},
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
||||
|
||||
for tbl in sorted(TABLES_WITH_FULL_IDENTITY):
|
||||
_terminate_blocked_pids(conn, tbl)
|
||||
conn.execute(sa.text(f'LOCK TABLE "{tbl}" IN ACCESS EXCLUSIVE MODE'))
|
||||
|
||||
for tbl in TABLES_WITH_FULL_IDENTITY:
|
||||
conn.execute(sa.text(f'ALTER TABLE "{tbl}" REPLICA IDENTITY DEFAULT'))
|
||||
|
||||
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||
|
||||
has_zero_ver = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = 'documents' AND column_name = '_0_version'"
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
cols = DOCUMENT_COLS + (['"_0_version"'] if has_zero_ver else [])
|
||||
col_list = ", ".join(cols)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
f"CREATE PUBLICATION {PUBLICATION_NAME} FOR TABLE "
|
||||
f"notifications, "
|
||||
f"documents ({col_list}), "
|
||||
f"folders, "
|
||||
f"search_source_connectors, "
|
||||
f"new_chat_messages, "
|
||||
f"chat_comments, "
|
||||
f"chat_session_state"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text(f"DROP PUBLICATION IF EXISTS {PUBLICATION_NAME}"))
|
||||
conn.execute(sa.text(PUBLICATION_DDL_FULL))
|
||||
for tbl in TABLES_WITH_FULL_IDENTITY:
|
||||
conn.execute(sa.text(f'ALTER TABLE "{tbl}" REPLICA IDENTITY FULL'))
|
||||
|
|
@ -0,0 +1,149 @@
|
|||
"""Add LOCAL_FOLDER_FILE document type, folder metadata, and document_versions table
|
||||
|
||||
Revision ID: 118
|
||||
Revises: 117
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "118"
|
||||
down_revision: str | None = "117"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
PUBLICATION_NAME = "zero_publication"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Add LOCAL_FOLDER_FILE to documenttype enum
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||
WHERE t.typname = 'documenttype' AND e.enumlabel = 'LOCAL_FOLDER_FILE'
|
||||
) THEN
|
||||
ALTER TYPE documenttype ADD VALUE 'LOCAL_FOLDER_FILE';
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Add JSONB metadata column to folders table
|
||||
col_exists = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = 'folders' AND column_name = 'metadata'"
|
||||
)
|
||||
).fetchone()
|
||||
if not col_exists:
|
||||
op.add_column(
|
||||
"folders",
|
||||
sa.Column("metadata", sa.dialects.postgresql.JSONB, nullable=True),
|
||||
)
|
||||
|
||||
# Create document_versions table
|
||||
table_exists = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.tables WHERE table_name = 'document_versions'"
|
||||
)
|
||||
).fetchone()
|
||||
if not table_exists:
|
||||
op.create_table(
|
||||
"document_versions",
|
||||
sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
|
||||
sa.Column("document_id", sa.Integer(), nullable=False),
|
||||
sa.Column("version_number", sa.Integer(), nullable=False),
|
||||
sa.Column("source_markdown", sa.Text(), nullable=True),
|
||||
sa.Column("content_hash", sa.String(), nullable=False),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_id"],
|
||||
["documents.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"document_id",
|
||||
"version_number",
|
||||
name="uq_document_version",
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_document_versions_document_id "
|
||||
"ON document_versions (document_id)"
|
||||
)
|
||||
op.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ix_document_versions_created_at "
|
||||
"ON document_versions (created_at)"
|
||||
)
|
||||
|
||||
# Add document_versions to Zero publication
|
||||
pub_exists = conn.execute(
|
||||
sa.text("SELECT 1 FROM pg_publication WHERE pubname = :name"),
|
||||
{"name": PUBLICATION_NAME},
|
||||
).fetchone()
|
||||
if pub_exists:
|
||||
already_in_pub = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM pg_publication_tables "
|
||||
"WHERE pubname = :name AND tablename = 'document_versions'"
|
||||
),
|
||||
{"name": PUBLICATION_NAME},
|
||||
).fetchone()
|
||||
if not already_in_pub:
|
||||
op.execute(
|
||||
f"ALTER PUBLICATION {PUBLICATION_NAME} ADD TABLE document_versions"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Remove from publication
|
||||
pub_exists = conn.execute(
|
||||
sa.text("SELECT 1 FROM pg_publication WHERE pubname = :name"),
|
||||
{"name": PUBLICATION_NAME},
|
||||
).fetchone()
|
||||
if pub_exists:
|
||||
already_in_pub = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM pg_publication_tables "
|
||||
"WHERE pubname = :name AND tablename = 'document_versions'"
|
||||
),
|
||||
{"name": PUBLICATION_NAME},
|
||||
).fetchone()
|
||||
if already_in_pub:
|
||||
op.execute(
|
||||
f"ALTER PUBLICATION {PUBLICATION_NAME} DROP TABLE document_versions"
|
||||
)
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS ix_document_versions_created_at")
|
||||
op.execute("DROP INDEX IF EXISTS ix_document_versions_document_id")
|
||||
op.execute("DROP TABLE IF EXISTS document_versions")
|
||||
|
||||
# Drop metadata column from folders
|
||||
col_exists = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = 'folders' AND column_name = 'metadata'"
|
||||
)
|
||||
).fetchone()
|
||||
if col_exists:
|
||||
op.drop_column("folders", "metadata")
|
||||
|
|
@ -17,10 +17,10 @@ depends_on: str | Sequence[str] | None = None
|
|||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Add the new_llm_configs table that combines LLM model settings with prompt configuration.
|
||||
Add the new_llm_configs table that combines model settings with prompt configuration.
|
||||
|
||||
This table includes:
|
||||
- LLM model configuration (provider, model_name, api_key, etc.)
|
||||
- Model configuration (provider, model_name, api_key, etc.)
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
"""
|
||||
|
|
@ -41,7 +41,7 @@ def upgrade() -> None:
|
|||
name VARCHAR(100) NOT NULL,
|
||||
description VARCHAR(500),
|
||||
|
||||
-- LLM Model Configuration (same as llm_configs, excluding language)
|
||||
-- Model Configuration (same as llm_configs, excluding language)
|
||||
provider litellmprovider NOT NULL,
|
||||
custom_provider VARCHAR(100),
|
||||
model_name VARCHAR(100) NOT NULL,
|
||||
|
|
|
|||
|
|
@ -159,6 +159,7 @@ async def create_surfsense_deep_agent(
|
|||
additional_tools: Sequence[BaseTool] | None = None,
|
||||
firecrawl_api_key: str | None = None,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
):
|
||||
"""
|
||||
Create a SurfSense deep agent with configurable tools and prompts.
|
||||
|
|
@ -451,6 +452,7 @@ async def create_surfsense_deep_agent(
|
|||
search_space_id=search_space_id,
|
||||
available_connectors=available_connectors,
|
||||
available_document_types=available_document_types,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
),
|
||||
SurfSenseFilesystemMiddleware(
|
||||
search_space_id=search_space_id,
|
||||
|
|
|
|||
|
|
@ -66,6 +66,16 @@ the `<chunk_index>`, identify chunks marked `matched="true"`, then use
|
|||
those sections instead of reading the entire file sequentially.
|
||||
|
||||
Use `<chunk id='...'>` values as citation IDs in your answers.
|
||||
|
||||
## User-Mentioned Documents
|
||||
|
||||
When the `ls` output tags a file with `[MENTIONED BY USER — read deeply]`,
|
||||
the user **explicitly selected** that document. These files are your highest-
|
||||
priority sources:
|
||||
1. **Always read them thoroughly** — scan the full `<chunk_index>`, then read
|
||||
all major sections, not just matched chunks.
|
||||
2. **Prefer their content** over other search results when answering.
|
||||
3. **Cite from them first** whenever applicable.
|
||||
"""
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -28,7 +28,13 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.utils import parse_date_or_datetime, resolve_date_range
|
||||
from app.db import NATIVE_TO_LEGACY_DOCTYPE, Document, Folder, shielded_async_session
|
||||
from app.db import (
|
||||
NATIVE_TO_LEGACY_DOCTYPE,
|
||||
Chunk,
|
||||
Document,
|
||||
Folder,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.retriever.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.utils.document_converters import embed_texts
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
|
@ -430,21 +436,36 @@ async def _get_folder_paths(
|
|||
def _build_synthetic_ls(
|
||||
existing_files: dict[str, Any] | None,
|
||||
new_files: dict[str, Any],
|
||||
*,
|
||||
mentioned_paths: set[str] | None = None,
|
||||
) -> tuple[AIMessage, ToolMessage]:
|
||||
"""Build a synthetic ls("/documents") tool-call + result for the LLM context.
|
||||
|
||||
Paths are listed with *new* (rank-ordered) files first, then existing files
|
||||
that were already in state from prior turns.
|
||||
Mentioned files are listed first. A separate header tells the LLM which
|
||||
files the user explicitly selected; the path list itself stays clean so
|
||||
paths can be passed directly to ``read_file`` without stripping tags.
|
||||
"""
|
||||
_mentioned = mentioned_paths or set()
|
||||
merged: dict[str, Any] = {**(existing_files or {}), **new_files}
|
||||
doc_paths = [
|
||||
p for p, v in merged.items() if p.startswith("/documents/") and v is not None
|
||||
]
|
||||
|
||||
new_set = set(new_files)
|
||||
new_paths = [p for p in doc_paths if p in new_set]
|
||||
mentioned_list = [p for p in doc_paths if p in _mentioned]
|
||||
new_non_mentioned = [p for p in doc_paths if p in new_set and p not in _mentioned]
|
||||
old_paths = [p for p in doc_paths if p not in new_set]
|
||||
ordered = new_paths + old_paths
|
||||
ordered = mentioned_list + new_non_mentioned + old_paths
|
||||
|
||||
parts: list[str] = []
|
||||
if mentioned_list:
|
||||
parts.append(
|
||||
"USER-MENTIONED documents (read these thoroughly before answering):"
|
||||
)
|
||||
for p in mentioned_list:
|
||||
parts.append(f" {p}")
|
||||
parts.append("")
|
||||
parts.append(str(ordered) if ordered else "No documents found.")
|
||||
|
||||
tool_call_id = f"auto_ls_{uuid.uuid4().hex[:12]}"
|
||||
ai_msg = AIMessage(
|
||||
|
|
@ -452,7 +473,7 @@ def _build_synthetic_ls(
|
|||
tool_calls=[{"name": "ls", "args": {"path": "/documents"}, "id": tool_call_id}],
|
||||
)
|
||||
tool_msg = ToolMessage(
|
||||
content=str(ordered) if ordered else "No documents found.",
|
||||
content="\n".join(parts),
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
return ai_msg, tool_msg
|
||||
|
|
@ -524,12 +545,92 @@ async def search_knowledge_base(
|
|||
return results[:top_k]
|
||||
|
||||
|
||||
async def fetch_mentioned_documents(
|
||||
*,
|
||||
document_ids: list[int],
|
||||
search_space_id: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch explicitly mentioned documents with *all* their chunks.
|
||||
|
||||
Returns the same dict structure as ``search_knowledge_base`` so results
|
||||
can be merged directly into ``build_scoped_filesystem``. Unlike search
|
||||
results, every chunk is included (no top-K limiting) and none are marked
|
||||
as ``matched`` since the entire document is relevant by virtue of the
|
||||
user's explicit mention.
|
||||
"""
|
||||
if not document_ids:
|
||||
return []
|
||||
|
||||
async with shielded_async_session() as session:
|
||||
doc_result = await session.execute(
|
||||
select(Document).where(
|
||||
Document.id.in_(document_ids),
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
docs = {doc.id: doc for doc in doc_result.scalars().all()}
|
||||
|
||||
if not docs:
|
||||
return []
|
||||
|
||||
chunk_result = await session.execute(
|
||||
select(Chunk.id, Chunk.content, Chunk.document_id)
|
||||
.where(Chunk.document_id.in_(list(docs.keys())))
|
||||
.order_by(Chunk.document_id, Chunk.id)
|
||||
)
|
||||
chunks_by_doc: dict[int, list[dict[str, Any]]] = {doc_id: [] for doc_id in docs}
|
||||
for row in chunk_result.all():
|
||||
if row.document_id in chunks_by_doc:
|
||||
chunks_by_doc[row.document_id].append(
|
||||
{"chunk_id": row.id, "content": row.content}
|
||||
)
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for doc_id in document_ids:
|
||||
doc = docs.get(doc_id)
|
||||
if doc is None:
|
||||
continue
|
||||
metadata = doc.document_metadata or {}
|
||||
results.append(
|
||||
{
|
||||
"document_id": doc.id,
|
||||
"content": "",
|
||||
"score": 1.0,
|
||||
"chunks": chunks_by_doc.get(doc.id, []),
|
||||
"matched_chunk_ids": [],
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"document_type": (
|
||||
doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None
|
||||
),
|
||||
"metadata": metadata,
|
||||
},
|
||||
"source": (
|
||||
doc.document_type.value
|
||||
if getattr(doc, "document_type", None)
|
||||
else None
|
||||
),
|
||||
"_user_mentioned": True,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def build_scoped_filesystem(
|
||||
*,
|
||||
documents: Sequence[dict[str, Any]],
|
||||
search_space_id: int,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Build a StateBackend-compatible files dict from search results."""
|
||||
) -> tuple[dict[str, dict[str, str]], dict[int, str]]:
|
||||
"""Build a StateBackend-compatible files dict from search results.
|
||||
|
||||
Returns ``(files, doc_id_to_path)`` so callers can reliably map a
|
||||
document id back to its filesystem path without guessing by title.
|
||||
Paths are collision-proof: when two documents resolve to the same
|
||||
path the doc-id is appended to disambiguate.
|
||||
"""
|
||||
async with shielded_async_session() as session:
|
||||
folder_paths = await _get_folder_paths(session, search_space_id)
|
||||
doc_ids = [
|
||||
|
|
@ -551,6 +652,7 @@ async def build_scoped_filesystem(
|
|||
}
|
||||
|
||||
files: dict[str, dict[str, str]] = {}
|
||||
doc_id_to_path: dict[int, str] = {}
|
||||
for document in documents:
|
||||
doc_meta = document.get("document") or {}
|
||||
title = str(doc_meta.get("title") or "untitled")
|
||||
|
|
@ -559,6 +661,9 @@ async def build_scoped_filesystem(
|
|||
base_folder = folder_paths.get(folder_id, "/documents")
|
||||
file_name = _safe_filename(title)
|
||||
path = f"{base_folder}/{file_name}"
|
||||
if path in files:
|
||||
stem = file_name.removesuffix(".xml")
|
||||
path = f"{base_folder}/{stem} ({doc_id}).xml"
|
||||
matched_ids = set(document.get("matched_chunk_ids") or [])
|
||||
xml_content = _build_document_xml(document, matched_chunk_ids=matched_ids)
|
||||
files[path] = {
|
||||
|
|
@ -567,7 +672,9 @@ async def build_scoped_filesystem(
|
|||
"created_at": "",
|
||||
"modified_at": "",
|
||||
}
|
||||
return files
|
||||
if isinstance(doc_id, int):
|
||||
doc_id_to_path[doc_id] = path
|
||||
return files, doc_id_to_path
|
||||
|
||||
|
||||
class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
||||
|
|
@ -583,12 +690,14 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
available_connectors: list[str] | None = None,
|
||||
available_document_types: list[str] | None = None,
|
||||
top_k: int = 10,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
self.llm = llm
|
||||
self.search_space_id = search_space_id
|
||||
self.available_connectors = available_connectors
|
||||
self.available_document_types = available_document_types
|
||||
self.top_k = top_k
|
||||
self.mentioned_document_ids = mentioned_document_ids or []
|
||||
|
||||
async def _plan_search_inputs(
|
||||
self,
|
||||
|
|
@ -680,6 +789,18 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
user_text=user_text,
|
||||
)
|
||||
|
||||
# --- 1. Fetch mentioned documents (user-selected, all chunks) ---
|
||||
mentioned_results: list[dict[str, Any]] = []
|
||||
if self.mentioned_document_ids:
|
||||
mentioned_results = await fetch_mentioned_documents(
|
||||
document_ids=self.mentioned_document_ids,
|
||||
search_space_id=self.search_space_id,
|
||||
)
|
||||
# Clear after first turn so they are not re-fetched on subsequent
|
||||
# messages within the same agent instance.
|
||||
self.mentioned_document_ids = []
|
||||
|
||||
# --- 2. Run KB hybrid search ---
|
||||
search_results = await search_knowledge_base(
|
||||
query=planned_query,
|
||||
search_space_id=self.search_space_id,
|
||||
|
|
@ -689,19 +810,50 @@ class KnowledgeBaseSearchMiddleware(AgentMiddleware): # type: ignore[type-arg]
|
|||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
new_files = await build_scoped_filesystem(
|
||||
documents=search_results,
|
||||
|
||||
# --- 3. Merge: mentioned first, then search (dedup by doc id) ---
|
||||
seen_doc_ids: set[int] = set()
|
||||
merged: list[dict[str, Any]] = []
|
||||
for doc in mentioned_results:
|
||||
doc_id = (doc.get("document") or {}).get("id")
|
||||
if doc_id is not None:
|
||||
seen_doc_ids.add(doc_id)
|
||||
merged.append(doc)
|
||||
for doc in search_results:
|
||||
doc_id = (doc.get("document") or {}).get("id")
|
||||
if doc_id is not None and doc_id in seen_doc_ids:
|
||||
continue
|
||||
merged.append(doc)
|
||||
|
||||
# --- 4. Build scoped filesystem ---
|
||||
new_files, doc_id_to_path = await build_scoped_filesystem(
|
||||
documents=merged,
|
||||
search_space_id=self.search_space_id,
|
||||
)
|
||||
|
||||
ai_msg, tool_msg = _build_synthetic_ls(existing_files, new_files)
|
||||
# Identify which paths belong to user-mentioned documents using
|
||||
# the authoritative doc_id -> path mapping (no title guessing).
|
||||
mentioned_doc_ids = {
|
||||
(d.get("document") or {}).get("id") for d in mentioned_results
|
||||
}
|
||||
mentioned_paths = {
|
||||
doc_id_to_path[did] for did in mentioned_doc_ids if did in doc_id_to_path
|
||||
}
|
||||
|
||||
ai_msg, tool_msg = _build_synthetic_ls(
|
||||
existing_files,
|
||||
new_files,
|
||||
mentioned_paths=mentioned_paths,
|
||||
)
|
||||
|
||||
if t0 is not None:
|
||||
_perf_log.info(
|
||||
"[kb_fs_middleware] completed in %.3fs query=%r optimized=%r new_files=%d total=%d",
|
||||
"[kb_fs_middleware] completed in %.3fs query=%r optimized=%r "
|
||||
"mentioned=%d new_files=%d total=%d",
|
||||
asyncio.get_event_loop().time() - t0,
|
||||
user_text[:80],
|
||||
planned_query[:120],
|
||||
len(mentioned_results),
|
||||
len(new_files),
|
||||
len(new_files) + len(existing_files or {}),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
# - Configure router_settings below to customize the load balancing behavior
|
||||
#
|
||||
# Structure matches NewLLMConfig:
|
||||
# - LLM model configuration (provider, model_name, api_key, etc.)
|
||||
# - Model configuration (provider, model_name, api_key, etc.)
|
||||
# - Prompt configuration (system_instructions, citations_enabled)
|
||||
|
||||
# Router Settings for Auto Mode
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class DocumentType(StrEnum):
|
|||
COMPOSIO_GOOGLE_DRIVE_CONNECTOR = "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"
|
||||
COMPOSIO_GMAIL_CONNECTOR = "COMPOSIO_GMAIL_CONNECTOR"
|
||||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
LOCAL_FOLDER_FILE = "LOCAL_FOLDER_FILE"
|
||||
|
||||
|
||||
# Native Google document types → their legacy Composio equivalents.
|
||||
|
|
@ -955,6 +956,7 @@ class Folder(BaseModel, TimestampMixin):
|
|||
onupdate=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
folder_metadata = Column("metadata", JSONB, nullable=True)
|
||||
|
||||
parent = relationship("Folder", remote_side="Folder.id", backref="children")
|
||||
search_space = relationship("SearchSpace", back_populates="folders")
|
||||
|
|
@ -1039,6 +1041,26 @@ class Document(BaseModel, TimestampMixin):
|
|||
)
|
||||
|
||||
|
||||
class DocumentVersion(BaseModel, TimestampMixin):
|
||||
__tablename__ = "document_versions"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("document_id", "version_number", name="uq_document_version"),
|
||||
)
|
||||
|
||||
document_id = Column(
|
||||
Integer,
|
||||
ForeignKey("documents.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
version_number = Column(Integer, nullable=False)
|
||||
source_markdown = Column(Text, nullable=True)
|
||||
content_hash = Column(String, nullable=False)
|
||||
title = Column(String, nullable=True)
|
||||
|
||||
document = relationship("Document", backref="versions")
|
||||
|
||||
|
||||
class Chunk(BaseModel, TimestampMixin):
|
||||
__tablename__ = "chunks"
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ class PipelineMessages:
|
|||
|
||||
LLM_AUTH = "LLM authentication failed. Check your API key."
|
||||
LLM_PERMISSION = "LLM request denied. Check your account permissions."
|
||||
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
|
||||
LLM_NOT_FOUND = "Model not found. Check your model configuration."
|
||||
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
|
||||
LLM_UNPROCESSABLE = (
|
||||
"Document exceeds the LLM context window even after optimization."
|
||||
|
|
@ -67,7 +67,7 @@ class PipelineMessages:
|
|||
LLM_RESPONSE = "LLM returned an invalid response."
|
||||
LLM_AUTH = "LLM authentication failed. Check your API key."
|
||||
LLM_PERMISSION = "LLM request denied. Check your account permissions."
|
||||
LLM_NOT_FOUND = "LLM model not found. Check your model configuration."
|
||||
LLM_NOT_FOUND = "Model not found. Check your model configuration."
|
||||
LLM_BAD_REQUEST = "LLM rejected the request. Document content may be invalid."
|
||||
LLM_UNPROCESSABLE = (
|
||||
"Document exceeds the LLM context window even after optimization."
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ router.include_router(confluence_add_connector_router)
|
|||
router.include_router(clickup_add_connector_router)
|
||||
router.include_router(dropbox_add_connector_router)
|
||||
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
|
||||
router.include_router(model_list_router) # Dynamic LLM model catalogue from OpenRouter
|
||||
router.include_router(model_list_router) # Dynamic model catalogue from OpenRouter
|
||||
router.include_router(logs_router)
|
||||
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
||||
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
# Force asyncio to use standard event loop before unstructured imports
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Query, UploadFile
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
|
@ -10,6 +11,8 @@ from app.db import (
|
|||
Chunk,
|
||||
Document,
|
||||
DocumentType,
|
||||
DocumentVersion,
|
||||
Folder,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
|
|
@ -17,6 +20,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
ChunkRead,
|
||||
DocumentRead,
|
||||
DocumentsCreate,
|
||||
DocumentStatusBatchResponse,
|
||||
|
|
@ -26,6 +30,7 @@ from app.schemas import (
|
|||
DocumentTitleSearchResponse,
|
||||
DocumentUpdate,
|
||||
DocumentWithChunksRead,
|
||||
FolderRead,
|
||||
PaginatedResponse,
|
||||
)
|
||||
from app.services.task_dispatcher import TaskDispatcher, get_task_dispatcher
|
||||
|
|
@ -45,9 +50,7 @@ os.environ["UNSTRUCTURED_HAS_PATCHED_LOOP"] = "1"
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
MAX_FILES_PER_UPLOAD = 10
|
||||
MAX_FILE_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB per file
|
||||
MAX_TOTAL_SIZE_BYTES = 200 * 1024 * 1024 # 200 MB total
|
||||
MAX_FILE_SIZE_BYTES = 500 * 1024 * 1024 # 500 MB per file
|
||||
|
||||
|
||||
@router.post("/documents")
|
||||
|
|
@ -156,13 +159,6 @@ async def create_documents_file_upload(
|
|||
if not files:
|
||||
raise HTTPException(status_code=400, detail="No files provided")
|
||||
|
||||
if len(files) > MAX_FILES_PER_UPLOAD:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Too many files. Maximum {MAX_FILES_PER_UPLOAD} files per upload.",
|
||||
)
|
||||
|
||||
total_size = 0
|
||||
for file in files:
|
||||
file_size = file.size or 0
|
||||
if file_size > MAX_FILE_SIZE_BYTES:
|
||||
|
|
@ -171,14 +167,6 @@ async def create_documents_file_upload(
|
|||
detail=f"File '{file.filename}' ({file_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_FILE_SIZE_BYTES // (1024 * 1024)} MB per-file limit.",
|
||||
)
|
||||
total_size += file_size
|
||||
|
||||
if total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
# ===== Read all files concurrently to avoid blocking the event loop =====
|
||||
async def _read_and_save(file: UploadFile) -> tuple[str, str, int]:
|
||||
|
|
@ -206,16 +194,6 @@ async def create_documents_file_upload(
|
|||
|
||||
saved_files = await asyncio.gather(*(_read_and_save(f) for f in files))
|
||||
|
||||
actual_total_size = sum(size for _, _, size in saved_files)
|
||||
if actual_total_size > MAX_TOTAL_SIZE_BYTES:
|
||||
for temp_path, _, _ in saved_files:
|
||||
os.unlink(temp_path)
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Total upload size ({actual_total_size / (1024 * 1024):.1f} MB) "
|
||||
f"exceeds the {MAX_TOTAL_SIZE_BYTES // (1024 * 1024)} MB limit.",
|
||||
)
|
||||
|
||||
# ===== PHASE 1: Create pending documents for all files =====
|
||||
created_documents: list[Document] = []
|
||||
files_to_process: list[tuple[Document, str, str]] = []
|
||||
|
|
@ -451,13 +429,15 @@ async def read_documents(
|
|||
reason=doc.status.get("reason"),
|
||||
)
|
||||
|
||||
raw_content = doc.content or ""
|
||||
api_documents.append(
|
||||
DocumentRead(
|
||||
id=doc.id,
|
||||
title=doc.title,
|
||||
document_type=doc.document_type,
|
||||
document_metadata=doc.document_metadata,
|
||||
content=doc.content,
|
||||
content="",
|
||||
content_preview=raw_content[:300],
|
||||
content_hash=doc.content_hash,
|
||||
unique_identifier_hash=doc.unique_identifier_hash,
|
||||
created_at=doc.created_at,
|
||||
|
|
@ -609,13 +589,15 @@ async def search_documents(
|
|||
reason=doc.status.get("reason"),
|
||||
)
|
||||
|
||||
raw_content = doc.content or ""
|
||||
api_documents.append(
|
||||
DocumentRead(
|
||||
id=doc.id,
|
||||
title=doc.title,
|
||||
document_type=doc.document_type,
|
||||
document_metadata=doc.document_metadata,
|
||||
content=doc.content,
|
||||
content="",
|
||||
content_preview=raw_content[:300],
|
||||
content_hash=doc.content_hash,
|
||||
unique_identifier_hash=doc.unique_identifier_hash,
|
||||
created_at=doc.created_at,
|
||||
|
|
@ -884,16 +866,19 @@ async def get_document_type_counts(
|
|||
@router.get("/documents/by-chunk/{chunk_id}", response_model=DocumentWithChunksRead)
|
||||
async def get_document_by_chunk_id(
|
||||
chunk_id: int,
|
||||
chunk_window: int = Query(
|
||||
5, ge=0, description="Number of chunks before/after the cited chunk to include"
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Retrieves a document based on a chunk ID, including all its chunks ordered by creation time.
|
||||
Requires DOCUMENTS_READ permission for the search space.
|
||||
The document's embedding and chunk embeddings are excluded from the response.
|
||||
Retrieves a document based on a chunk ID, including a window of chunks around the cited one.
|
||||
Uses SQL-level pagination to avoid loading all chunks into memory.
|
||||
"""
|
||||
try:
|
||||
# First, get the chunk and verify it exists
|
||||
from sqlalchemy import and_, func, or_
|
||||
|
||||
chunk_result = await session.execute(select(Chunk).filter(Chunk.id == chunk_id))
|
||||
chunk = chunk_result.scalars().first()
|
||||
|
||||
|
|
@ -902,11 +887,8 @@ async def get_document_by_chunk_id(
|
|||
status_code=404, detail=f"Chunk with id {chunk_id} not found"
|
||||
)
|
||||
|
||||
# Get the associated document
|
||||
document_result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
.filter(Document.id == chunk.document_id)
|
||||
select(Document).filter(Document.id == chunk.document_id)
|
||||
)
|
||||
document = document_result.scalars().first()
|
||||
|
||||
|
|
@ -916,7 +898,6 @@ async def get_document_by_chunk_id(
|
|||
detail="Document not found",
|
||||
)
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
|
|
@ -925,10 +906,38 @@ async def get_document_by_chunk_id(
|
|||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
# Sort chunks by creation time
|
||||
sorted_chunks = sorted(document.chunks, key=lambda x: x.created_at)
|
||||
total_result = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(Chunk)
|
||||
.filter(Chunk.document_id == document.id)
|
||||
)
|
||||
total_chunks = total_result.scalar() or 0
|
||||
|
||||
cited_idx_result = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(Chunk)
|
||||
.filter(
|
||||
Chunk.document_id == document.id,
|
||||
or_(
|
||||
Chunk.created_at < chunk.created_at,
|
||||
and_(Chunk.created_at == chunk.created_at, Chunk.id < chunk.id),
|
||||
),
|
||||
)
|
||||
)
|
||||
cited_idx = cited_idx_result.scalar() or 0
|
||||
|
||||
start = max(0, cited_idx - chunk_window)
|
||||
end = min(total_chunks, cited_idx + chunk_window + 1)
|
||||
|
||||
windowed_result = await session.execute(
|
||||
select(Chunk)
|
||||
.filter(Chunk.document_id == document.id)
|
||||
.order_by(Chunk.created_at, Chunk.id)
|
||||
.offset(start)
|
||||
.limit(end - start)
|
||||
)
|
||||
windowed_chunks = windowed_result.scalars().all()
|
||||
|
||||
# Return the document with its chunks
|
||||
return DocumentWithChunksRead(
|
||||
id=document.id,
|
||||
title=document.title,
|
||||
|
|
@ -940,7 +949,9 @@ async def get_document_by_chunk_id(
|
|||
created_at=document.created_at,
|
||||
updated_at=document.updated_at,
|
||||
search_space_id=document.search_space_id,
|
||||
chunks=sorted_chunks,
|
||||
chunks=windowed_chunks,
|
||||
total_chunks=total_chunks,
|
||||
chunk_start_index=start,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -950,6 +961,108 @@ async def get_document_by_chunk_id(
|
|||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/watched-folders", response_model=list[FolderRead])
|
||||
async def get_watched_folders(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Return root folders that are marked as watched (metadata->>'watched' = 'true')."""
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
folders = (
|
||||
(
|
||||
await session.execute(
|
||||
select(Folder).where(
|
||||
Folder.search_space_id == search_space_id,
|
||||
Folder.parent_id.is_(None),
|
||||
Folder.folder_metadata.isnot(None),
|
||||
Folder.folder_metadata["watched"].astext == "true",
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
return folders
|
||||
|
||||
|
||||
@router.get(
|
||||
"/documents/{document_id}/chunks",
|
||||
response_model=PaginatedResponse[ChunkRead],
|
||||
)
|
||||
async def get_document_chunks_paginated(
|
||||
document_id: int,
|
||||
page: int = Query(0, ge=0),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
start_offset: int | None = Query(
|
||||
None, ge=0, description="Direct offset; overrides page * page_size"
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Paginated chunk loading for a document.
|
||||
Supports both page-based and offset-based access.
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import func
|
||||
|
||||
doc_result = await session.execute(
|
||||
select(Document).filter(Document.id == document_id)
|
||||
)
|
||||
document = doc_result.scalars().first()
|
||||
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
document.search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
total_result = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(Chunk)
|
||||
.filter(Chunk.document_id == document_id)
|
||||
)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
offset = start_offset if start_offset is not None else page * page_size
|
||||
chunks_result = await session.execute(
|
||||
select(Chunk)
|
||||
.filter(Chunk.document_id == document_id)
|
||||
.order_by(Chunk.created_at, Chunk.id)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
return PaginatedResponse(
|
||||
items=chunks,
|
||||
total=total,
|
||||
page=offset // page_size if page_size else page,
|
||||
page_size=page_size,
|
||||
has_more=(offset + len(chunks)) < total,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch chunks: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}", response_model=DocumentRead)
|
||||
async def read_document(
|
||||
document_id: int,
|
||||
|
|
@ -980,13 +1093,14 @@ async def read_document(
|
|||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
# Convert database object to API-friendly format
|
||||
raw_content = document.content or ""
|
||||
return DocumentRead(
|
||||
id=document.id,
|
||||
title=document.title,
|
||||
document_type=document.document_type,
|
||||
document_metadata=document.document_metadata,
|
||||
content=document.content,
|
||||
content=raw_content,
|
||||
content_preview=raw_content[:300],
|
||||
content_hash=document.content_hash,
|
||||
unique_identifier_hash=document.unique_identifier_hash,
|
||||
created_at=document.created_at,
|
||||
|
|
@ -1135,3 +1249,297 @@ async def delete_document(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete document: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# ====================================================================
|
||||
# Version History Endpoints
|
||||
# ====================================================================
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/versions")
|
||||
async def list_document_versions(
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""List all versions for a document, ordered by version_number descending."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
).scalar_one_or_none()
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, document.search_space_id, Permission.DOCUMENTS_READ.value
|
||||
)
|
||||
|
||||
versions = (
|
||||
(
|
||||
await session.execute(
|
||||
select(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document_id)
|
||||
.order_by(DocumentVersion.version_number.desc())
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"version_number": v.version_number,
|
||||
"title": v.title,
|
||||
"content_hash": v.content_hash,
|
||||
"created_at": v.created_at.isoformat() if v.created_at else None,
|
||||
}
|
||||
for v in versions
|
||||
]
|
||||
|
||||
|
||||
@router.get("/documents/{document_id}/versions/{version_number}")
|
||||
async def get_document_version(
|
||||
document_id: int,
|
||||
version_number: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get full version content including source_markdown."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
).scalar_one_or_none()
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, document.search_space_id, Permission.DOCUMENTS_READ.value
|
||||
)
|
||||
|
||||
version = (
|
||||
await session.execute(
|
||||
select(DocumentVersion).where(
|
||||
DocumentVersion.document_id == document_id,
|
||||
DocumentVersion.version_number == version_number,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
return {
|
||||
"version_number": version.version_number,
|
||||
"title": version.title,
|
||||
"content_hash": version.content_hash,
|
||||
"source_markdown": version.source_markdown,
|
||||
"created_at": version.created_at.isoformat() if version.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/documents/{document_id}/versions/{version_number}/restore")
|
||||
async def restore_document_version(
|
||||
document_id: int,
|
||||
version_number: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Restore a previous version: snapshot current state, then overwrite document content."""
|
||||
document = (
|
||||
await session.execute(select(Document).where(Document.id == document_id))
|
||||
).scalar_one_or_none()
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
await check_permission(
|
||||
session, user, document.search_space_id, Permission.DOCUMENTS_UPDATE.value
|
||||
)
|
||||
|
||||
version = (
|
||||
await session.execute(
|
||||
select(DocumentVersion).where(
|
||||
DocumentVersion.document_id == document_id,
|
||||
DocumentVersion.version_number == version_number,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if not version:
|
||||
raise HTTPException(status_code=404, detail="Version not found")
|
||||
|
||||
# Snapshot current state before restoring
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
await create_version_snapshot(session, document)
|
||||
|
||||
# Restore the version's content onto the document
|
||||
document.source_markdown = version.source_markdown
|
||||
document.title = version.title or document.title
|
||||
document.content_needs_reindexing = True
|
||||
await session.commit()
|
||||
|
||||
from app.tasks.celery_tasks.document_reindex_tasks import reindex_document_task
|
||||
|
||||
reindex_document_task.delay(document_id, str(user.id))
|
||||
|
||||
return {
|
||||
"message": f"Restored version {version_number}",
|
||||
"document_id": document_id,
|
||||
"restored_version": version_number,
|
||||
}
|
||||
|
||||
|
||||
# ===== Local folder indexing endpoints =====
|
||||
|
||||
|
||||
class FolderIndexRequest(PydanticBaseModel):
|
||||
folder_path: str
|
||||
folder_name: str
|
||||
search_space_id: int
|
||||
exclude_patterns: list[str] | None = None
|
||||
file_extensions: list[str] | None = None
|
||||
root_folder_id: int | None = None
|
||||
enable_summary: bool = False
|
||||
|
||||
|
||||
class FolderIndexFilesRequest(PydanticBaseModel):
|
||||
folder_path: str
|
||||
folder_name: str
|
||||
search_space_id: int
|
||||
target_file_paths: list[str]
|
||||
root_folder_id: int | None = None
|
||||
enable_summary: bool = False
|
||||
|
||||
|
||||
@router.post("/documents/folder-index")
|
||||
async def folder_index(
|
||||
request: FolderIndexRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Full-scan index of a local folder. Creates the root Folder row synchronously
|
||||
and dispatches the heavy indexing work to a Celery task.
|
||||
Returns the root_folder_id so the desktop can persist it.
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
if not app_config.is_self_hosted():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Local folder indexing is only available in self-hosted mode",
|
||||
)
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create documents in this search space",
|
||||
)
|
||||
|
||||
watched_metadata = {
|
||||
"watched": True,
|
||||
"folder_path": request.folder_path,
|
||||
"exclude_patterns": request.exclude_patterns,
|
||||
"file_extensions": request.file_extensions,
|
||||
}
|
||||
|
||||
root_folder_id = request.root_folder_id
|
||||
if root_folder_id:
|
||||
existing = (
|
||||
await session.execute(select(Folder).where(Folder.id == root_folder_id))
|
||||
).scalar_one_or_none()
|
||||
if not existing:
|
||||
root_folder_id = None
|
||||
else:
|
||||
existing.folder_metadata = watched_metadata
|
||||
await session.commit()
|
||||
|
||||
if not root_folder_id:
|
||||
root_folder = Folder(
|
||||
name=request.folder_name,
|
||||
search_space_id=request.search_space_id,
|
||||
created_by_id=str(user.id),
|
||||
position="a0",
|
||||
folder_metadata=watched_metadata,
|
||||
)
|
||||
session.add(root_folder)
|
||||
await session.flush()
|
||||
root_folder_id = root_folder.id
|
||||
await session.commit()
|
||||
|
||||
from app.tasks.celery_tasks.document_tasks import index_local_folder_task
|
||||
|
||||
index_local_folder_task.delay(
|
||||
search_space_id=request.search_space_id,
|
||||
user_id=str(user.id),
|
||||
folder_path=request.folder_path,
|
||||
folder_name=request.folder_name,
|
||||
exclude_patterns=request.exclude_patterns,
|
||||
file_extensions=request.file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=request.enable_summary,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Folder indexing started",
|
||||
"status": "processing",
|
||||
"root_folder_id": root_folder_id,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/documents/folder-index-files")
|
||||
async def folder_index_files(
|
||||
request: FolderIndexFilesRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Index multiple files within a watched folder (batched chokidar trigger).
|
||||
Validates that all target_file_paths are under folder_path.
|
||||
Dispatches a single Celery task that processes them in parallel.
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
if not app_config.is_self_hosted():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Local folder indexing is only available in self-hosted mode",
|
||||
)
|
||||
|
||||
if not request.target_file_paths:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="target_file_paths must not be empty"
|
||||
)
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
request.search_space_id,
|
||||
Permission.DOCUMENTS_CREATE.value,
|
||||
"You don't have permission to create documents in this search space",
|
||||
)
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
for fp in request.target_file_paths:
|
||||
try:
|
||||
Path(fp).relative_to(request.folder_path)
|
||||
except ValueError as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"target_file_path {fp} must be inside folder_path",
|
||||
) from err
|
||||
|
||||
from app.tasks.celery_tasks.document_tasks import index_local_folder_task
|
||||
|
||||
index_local_folder_task.delay(
|
||||
search_space_id=request.search_space_id,
|
||||
user_id=str(user.id),
|
||||
folder_path=request.folder_path,
|
||||
folder_name=request.folder_name,
|
||||
target_file_paths=request.target_file_paths,
|
||||
root_folder_id=request.root_folder_id,
|
||||
enable_summary=request.enable_summary,
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Batch indexing started for {len(request.target_file_paths)} file(s)",
|
||||
"status": "processing",
|
||||
"file_count": len(request.target_file_paths),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,11 +15,10 @@ import pypandoc
|
|||
import typst
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import Document, DocumentType, Permission, User, get_async_session
|
||||
from app.db import Chunk, Document, DocumentType, Permission, User, get_async_session
|
||||
from app.routes.reports_routes import (
|
||||
_FILE_EXTENSIONS,
|
||||
_MEDIA_TYPES,
|
||||
|
|
@ -44,6 +43,9 @@ router = APIRouter()
|
|||
async def get_editor_content(
|
||||
search_space_id: int,
|
||||
document_id: int,
|
||||
max_length: int | None = Query(
|
||||
None, description="Truncate source_markdown to this many characters"
|
||||
),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
|
|
@ -65,9 +67,7 @@ async def get_editor_content(
|
|||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
.filter(
|
||||
select(Document).filter(
|
||||
Document.id == document_id,
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
|
|
@ -77,80 +77,152 @@ async def get_editor_content(
|
|||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Priority 1: Return source_markdown if it exists (check `is not None` to allow empty strings)
|
||||
if document.source_markdown is not None:
|
||||
count_result = await session.execute(
|
||||
select(func.count()).select_from(Chunk).filter(Chunk.document_id == document_id)
|
||||
)
|
||||
chunk_count = count_result.scalar() or 0
|
||||
|
||||
def _build_response(md: str) -> dict:
|
||||
size_bytes = len(md.encode("utf-8"))
|
||||
truncated = False
|
||||
output_md = md
|
||||
if max_length is not None and size_bytes > max_length:
|
||||
output_md = md[:max_length]
|
||||
truncated = True
|
||||
return {
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"document_type": document.document_type.value,
|
||||
"source_markdown": document.source_markdown,
|
||||
"source_markdown": output_md,
|
||||
"content_size_bytes": size_bytes,
|
||||
"chunk_count": chunk_count,
|
||||
"truncated": truncated,
|
||||
"updated_at": document.updated_at.isoformat()
|
||||
if document.updated_at
|
||||
else None,
|
||||
}
|
||||
|
||||
# Priority 2: Lazy-migrate from blocknote_document (pure Python, no external deps)
|
||||
if document.source_markdown is not None:
|
||||
return _build_response(document.source_markdown)
|
||||
|
||||
if document.blocknote_document:
|
||||
from app.utils.blocknote_to_markdown import blocknote_to_markdown
|
||||
|
||||
markdown = blocknote_to_markdown(document.blocknote_document)
|
||||
if markdown:
|
||||
# Persist the migration so we don't repeat it
|
||||
document.source_markdown = markdown
|
||||
await session.commit()
|
||||
return {
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"document_type": document.document_type.value,
|
||||
"source_markdown": markdown,
|
||||
"updated_at": document.updated_at.isoformat()
|
||||
if document.updated_at
|
||||
else None,
|
||||
}
|
||||
return _build_response(markdown)
|
||||
|
||||
# Priority 3: For NOTE type with no content, return empty markdown
|
||||
if document.document_type == DocumentType.NOTE:
|
||||
empty_markdown = ""
|
||||
document.source_markdown = empty_markdown
|
||||
await session.commit()
|
||||
return {
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"document_type": document.document_type.value,
|
||||
"source_markdown": empty_markdown,
|
||||
"updated_at": document.updated_at.isoformat()
|
||||
if document.updated_at
|
||||
else None,
|
||||
}
|
||||
return _build_response(empty_markdown)
|
||||
|
||||
# Priority 4: Reconstruct from chunks
|
||||
chunks = sorted(document.chunks, key=lambda c: c.id)
|
||||
chunk_contents_result = await session.execute(
|
||||
select(Chunk.content)
|
||||
.filter(Chunk.document_id == document_id)
|
||||
.order_by(Chunk.id)
|
||||
)
|
||||
chunk_contents = chunk_contents_result.scalars().all()
|
||||
|
||||
if not chunks:
|
||||
if not chunk_contents:
|
||||
doc_status = document.status or {}
|
||||
state = (
|
||||
doc_status.get("state", "ready")
|
||||
if isinstance(doc_status, dict)
|
||||
else "ready"
|
||||
)
|
||||
if state in ("pending", "processing"):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This document is still being processed. Please wait a moment and try again.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This document has no content and cannot be edited. Please re-upload to enable editing.",
|
||||
detail="This document has no viewable content yet. It may still be syncing. Try again in a few seconds, or re-upload if the issue persists.",
|
||||
)
|
||||
|
||||
markdown_content = "\n\n".join(chunk.content for chunk in chunks)
|
||||
markdown_content = "\n\n".join(chunk_contents)
|
||||
|
||||
if not markdown_content.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This document has empty content and cannot be edited.",
|
||||
detail="This document appears to be empty. Try re-uploading or editing it to add content.",
|
||||
)
|
||||
|
||||
# Persist the lazy migration
|
||||
document.source_markdown = markdown_content
|
||||
await session.commit()
|
||||
|
||||
return {
|
||||
"document_id": document.id,
|
||||
"title": document.title,
|
||||
"document_type": document.document_type.value,
|
||||
"source_markdown": markdown_content,
|
||||
"updated_at": document.updated_at.isoformat() if document.updated_at else None,
|
||||
}
|
||||
return _build_response(markdown_content)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/documents/{document_id}/download-markdown"
|
||||
)
|
||||
async def download_document_markdown(
|
||||
search_space_id: int,
|
||||
document_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Download the full document content as a .md file.
|
||||
Reconstructs markdown from source_markdown or chunks.
|
||||
"""
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.DOCUMENTS_READ.value,
|
||||
"You don't have permission to read documents in this search space",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(Document).filter(
|
||||
Document.id == document_id,
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
document = result.scalars().first()
|
||||
|
||||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
markdown: str | None = document.source_markdown
|
||||
if markdown is None and document.blocknote_document:
|
||||
from app.utils.blocknote_to_markdown import blocknote_to_markdown
|
||||
|
||||
markdown = blocknote_to_markdown(document.blocknote_document)
|
||||
if markdown is None:
|
||||
chunk_contents_result = await session.execute(
|
||||
select(Chunk.content)
|
||||
.filter(Chunk.document_id == document_id)
|
||||
.order_by(Chunk.id)
|
||||
)
|
||||
chunk_contents = chunk_contents_result.scalars().all()
|
||||
if chunk_contents:
|
||||
markdown = "\n\n".join(chunk_contents)
|
||||
|
||||
if not markdown or not markdown.strip():
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Document has no content to download"
|
||||
)
|
||||
|
||||
safe_title = (
|
||||
"".join(
|
||||
c if c.isalnum() or c in " -_" else "_"
|
||||
for c in (document.title or "document")
|
||||
).strip()[:80]
|
||||
or "document"
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(markdown.encode("utf-8")),
|
||||
media_type="text/markdown; charset=utf-8",
|
||||
headers={"Content-Disposition": f'attachment; filename="{safe_title}.md"'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/search-spaces/{search_space_id}/documents/{document_id}/save")
|
||||
|
|
@ -258,9 +330,7 @@ async def export_document(
|
|||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
.filter(
|
||||
select(Document).filter(
|
||||
Document.id == document_id,
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
|
|
@ -269,16 +339,20 @@ async def export_document(
|
|||
if not document:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Resolve markdown content (same priority as editor-content endpoint)
|
||||
markdown_content: str | None = document.source_markdown
|
||||
if markdown_content is None and document.blocknote_document:
|
||||
from app.utils.blocknote_to_markdown import blocknote_to_markdown
|
||||
|
||||
markdown_content = blocknote_to_markdown(document.blocknote_document)
|
||||
if markdown_content is None:
|
||||
chunks = sorted(document.chunks, key=lambda c: c.id)
|
||||
if chunks:
|
||||
markdown_content = "\n\n".join(chunk.content for chunk in chunks)
|
||||
chunk_contents_result = await session.execute(
|
||||
select(Chunk.content)
|
||||
.filter(Chunk.document_id == document_id)
|
||||
.order_by(Chunk.id)
|
||||
)
|
||||
chunk_contents = chunk_contents_result.scalars().all()
|
||||
if chunk_contents:
|
||||
markdown_content = "\n\n".join(chunk_contents)
|
||||
|
||||
if not markdown_content or not markdown_content.strip():
|
||||
raise HTTPException(status_code=400, detail="Document has no content to export")
|
||||
|
|
|
|||
|
|
@ -192,6 +192,33 @@ async def get_folder_breadcrumb(
|
|||
) from e
|
||||
|
||||
|
||||
@router.patch("/folders/{folder_id}/watched")
|
||||
async def stop_watching_folder(
|
||||
folder_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Clear the watched flag from a folder's metadata."""
|
||||
folder = await session.get(Folder, folder_id)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="Folder not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
folder.search_space_id,
|
||||
Permission.DOCUMENTS_UPDATE.value,
|
||||
"You don't have permission to update folders in this search space",
|
||||
)
|
||||
|
||||
if folder.folder_metadata and isinstance(folder.folder_metadata, dict):
|
||||
updated = {**folder.folder_metadata, "watched": False}
|
||||
folder.folder_metadata = updated
|
||||
await session.commit()
|
||||
|
||||
return {"message": "Folder watch status updated"}
|
||||
|
||||
|
||||
@router.put("/folders/{folder_id}", response_model=FolderRead)
|
||||
async def update_folder(
|
||||
folder_id: int,
|
||||
|
|
@ -340,7 +367,7 @@ async def delete_folder(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Delete a folder and cascade-delete subfolders. Documents are async-deleted via Celery."""
|
||||
"""Mark documents for deletion and dispatch Celery to delete docs first, then folders."""
|
||||
try:
|
||||
folder = await session.get(Folder, folder_id)
|
||||
if not folder:
|
||||
|
|
@ -372,30 +399,29 @@ async def delete_folder(
|
|||
)
|
||||
await session.commit()
|
||||
|
||||
await session.execute(Folder.__table__.delete().where(Folder.id == folder_id))
|
||||
await session.commit()
|
||||
try:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
delete_folder_documents_task,
|
||||
)
|
||||
|
||||
if document_ids:
|
||||
try:
|
||||
from app.tasks.celery_tasks.document_tasks import (
|
||||
delete_folder_documents_task,
|
||||
)
|
||||
|
||||
delete_folder_documents_task.delay(document_ids)
|
||||
except Exception as err:
|
||||
delete_folder_documents_task.delay(
|
||||
document_ids, folder_subtree_ids=list(subtree_ids)
|
||||
)
|
||||
except Exception as err:
|
||||
if document_ids:
|
||||
await session.execute(
|
||||
Document.__table__.update()
|
||||
.where(Document.id.in_(document_ids))
|
||||
.values(status={"state": "ready"})
|
||||
)
|
||||
await session.commit()
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Folder deleted but document cleanup could not be queued. Documents have been restored.",
|
||||
) from err
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Could not queue folder deletion. Documents have been restored.",
|
||||
) from err
|
||||
|
||||
return {
|
||||
"message": "Folder deleted successfully",
|
||||
"message": "Folder deletion started",
|
||||
"documents_queued_for_deletion": len(document_ids),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
API route for fetching the available LLM models catalogue.
|
||||
API route for fetching the available models catalogue.
|
||||
|
||||
Serves a dynamically-updated list sourced from the OpenRouter public API,
|
||||
with a local JSON fallback when the API is unreachable.
|
||||
|
|
@ -30,7 +30,7 @@ async def list_available_models(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Return all available LLM models grouped by provider.
|
||||
Return all available models grouped by provider.
|
||||
|
||||
The list is sourced from the OpenRouter public API and cached for 1 hour.
|
||||
If the API is unreachable, a local fallback file is used instead.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
API routes for NewLLMConfig CRUD operations.
|
||||
|
||||
NewLLMConfig combines LLM model settings with prompt configuration:
|
||||
NewLLMConfig combines model settings with prompt configuration:
|
||||
- LLM provider, model, API key, etc.
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
|
|
|
|||
|
|
@ -55,23 +55,12 @@ from app.schemas import (
|
|||
)
|
||||
from app.services.composio_service import ComposioService, get_composio_service
|
||||
from app.services.notification_service import NotificationService
|
||||
from app.tasks.connector_indexers import (
|
||||
index_airtable_records,
|
||||
index_clickup_tasks,
|
||||
index_confluence_pages,
|
||||
index_crawled_urls,
|
||||
index_discord_messages,
|
||||
index_elasticsearch_documents,
|
||||
index_github_repos,
|
||||
index_google_calendar_events,
|
||||
index_google_gmail_messages,
|
||||
index_jira_issues,
|
||||
index_linear_issues,
|
||||
index_luma_events,
|
||||
index_notion_pages,
|
||||
index_slack_messages,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
|
||||
# NOTE: connector indexer functions are imported lazily inside each
|
||||
# ``run_*_indexing`` helper to break a circular import cycle:
|
||||
# connector_indexers.__init__ → airtable_indexer → airtable_history
|
||||
# → app.routes.__init__ → this file → connector_indexers (not ready yet)
|
||||
from app.utils.connector_naming import ensure_unique_connector_name
|
||||
from app.utils.indexing_locks import (
|
||||
acquire_connector_indexing_lock,
|
||||
|
|
@ -1378,6 +1367,8 @@ async def run_slack_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_slack_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -1824,6 +1815,8 @@ async def run_notion_indexing_with_new_session(
|
|||
Create a new session and run the Notion indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_notion_pages
|
||||
|
||||
async with async_session_maker() as session:
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
|
|
@ -1858,6 +1851,8 @@ async def run_notion_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_notion_pages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -1910,6 +1905,8 @@ async def run_github_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_github_repos
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -1961,6 +1958,8 @@ async def run_linear_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_linear_issues
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2011,6 +2010,8 @@ async def run_discord_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_discord_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2113,6 +2114,8 @@ async def run_jira_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_jira_issues
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2166,6 +2169,8 @@ async def run_confluence_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_confluence_pages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2217,6 +2222,8 @@ async def run_clickup_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_clickup_tasks
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2268,6 +2275,8 @@ async def run_airtable_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_airtable_records
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2321,6 +2330,8 @@ async def run_google_calendar_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_google_calendar_events
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2370,6 +2381,7 @@ async def run_google_gmail_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_google_gmail_messages
|
||||
|
||||
# Create a wrapper function that calls index_google_gmail_messages with max_messages
|
||||
async def gmail_indexing_wrapper(
|
||||
|
|
@ -2836,6 +2848,8 @@ async def run_luma_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_luma_events
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2888,6 +2902,8 @@ async def run_elasticsearch_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_elasticsearch_documents
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
@ -2938,6 +2954,8 @@ async def run_web_page_indexing(
|
|||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_crawled_urls
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
|
|
|
|||
|
|
@ -53,25 +53,26 @@ class DocumentRead(BaseModel):
|
|||
title: str
|
||||
document_type: DocumentType
|
||||
document_metadata: dict
|
||||
content: str # Changed to string to match frontend
|
||||
content: str = ""
|
||||
content_preview: str = ""
|
||||
content_hash: str
|
||||
unique_identifier_hash: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime | None
|
||||
search_space_id: int
|
||||
folder_id: int | None = None
|
||||
created_by_id: UUID | None = None # User who created/uploaded this document
|
||||
created_by_id: UUID | None = None
|
||||
created_by_name: str | None = None
|
||||
created_by_email: str | None = None
|
||||
status: DocumentStatusSchema | None = (
|
||||
None # Processing status (ready, processing, failed)
|
||||
)
|
||||
status: DocumentStatusSchema | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class DocumentWithChunksRead(DocumentRead):
|
||||
chunks: list[ChunkRead] = []
|
||||
total_chunks: int = 0
|
||||
chunk_start_index: int = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Pydantic schemas for folder CRUD, move, and reorder operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
|
@ -34,6 +35,9 @@ class FolderRead(BaseModel):
|
|||
created_by_id: UUID | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
metadata: dict[str, Any] | None = Field(
|
||||
default=None, validation_alias="folder_metadata"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
Pydantic schemas for the NewLLMConfig API.
|
||||
|
||||
NewLLMConfig combines LLM model settings with prompt configuration:
|
||||
NewLLMConfig combines model settings with prompt configuration:
|
||||
- LLM provider, model, API key, etc.
|
||||
- Configurable system instructions
|
||||
- Citation toggle
|
||||
|
|
@ -26,7 +26,7 @@ class NewLLMConfigBase(BaseModel):
|
|||
None, max_length=500, description="Optional description"
|
||||
)
|
||||
|
||||
# LLM Model Configuration
|
||||
# Model Configuration
|
||||
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
|
||||
custom_provider: str | None = Field(
|
||||
None, max_length=100, description="Custom provider name when provider is CUSTOM"
|
||||
|
|
@ -71,7 +71,7 @@ class NewLLMConfigUpdate(BaseModel):
|
|||
name: str | None = Field(None, max_length=100)
|
||||
description: str | None = Field(None, max_length=500)
|
||||
|
||||
# LLM Model Configuration
|
||||
# Model Configuration
|
||||
provider: LiteLLMProvider | None = None
|
||||
custom_provider: str | None = Field(None, max_length=100)
|
||||
model_name: str | None = Field(None, max_length=100)
|
||||
|
|
@ -106,7 +106,7 @@ class NewLLMConfigPublic(BaseModel):
|
|||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# LLM Model Configuration (no api_key)
|
||||
# Model Configuration (no api_key)
|
||||
provider: LiteLLMProvider
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
|
|
@ -149,7 +149,7 @@ class GlobalNewLLMConfigRead(BaseModel):
|
|||
name: str
|
||||
description: str | None = None
|
||||
|
||||
# LLM Model Configuration (no api_key)
|
||||
# Model Configuration (no api_key)
|
||||
provider: str # String because YAML doesn't enforce enum, "AUTO" for Auto mode
|
||||
custom_provider: str | None = None
|
||||
model_name: str
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
Service for fetching and caching the available LLM model list.
|
||||
Service for fetching and caching the available model list.
|
||||
|
||||
Uses the OpenRouter public API as the primary source, with a local
|
||||
fallback JSON file when the API is unreachable.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Celery tasks for document processing."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from uuid import UUID
|
||||
|
|
@ -10,6 +11,7 @@ from app.config import config
|
|||
from app.services.notification_service import NotificationService
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
from app.tasks.connector_indexers.local_folder_indexer import index_local_folder
|
||||
from app.tasks.document_processors import (
|
||||
add_extension_received_document,
|
||||
add_youtube_video_document,
|
||||
|
|
@ -141,21 +143,30 @@ async def _delete_document_background(document_id: int) -> None:
|
|||
retry_backoff_max=300,
|
||||
max_retries=5,
|
||||
)
|
||||
def delete_folder_documents_task(self, document_ids: list[int]):
|
||||
"""Celery task to batch-delete documents orphaned by folder deletion."""
|
||||
def delete_folder_documents_task(
|
||||
self,
|
||||
document_ids: list[int],
|
||||
folder_subtree_ids: list[int] | None = None,
|
||||
):
|
||||
"""Celery task to delete documents first, then the folder rows."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(_delete_folder_documents(document_ids))
|
||||
loop.run_until_complete(
|
||||
_delete_folder_documents(document_ids, folder_subtree_ids)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _delete_folder_documents(document_ids: list[int]) -> None:
|
||||
"""Delete chunks in batches, then document rows for each orphaned document."""
|
||||
async def _delete_folder_documents(
|
||||
document_ids: list[int],
|
||||
folder_subtree_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Delete chunks in batches, then document rows, then folder rows."""
|
||||
from sqlalchemy import delete as sa_delete, select
|
||||
|
||||
from app.db import Chunk, Document
|
||||
from app.db import Chunk, Document, Folder
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
batch_size = 500
|
||||
|
|
@ -177,6 +188,12 @@ async def _delete_folder_documents(document_ids: list[int]) -> None:
|
|||
await session.delete(doc)
|
||||
await session.commit()
|
||||
|
||||
if folder_subtree_ids:
|
||||
await session.execute(
|
||||
sa_delete(Folder).where(Folder.id.in_(folder_subtree_ids))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="delete_search_space_background",
|
||||
|
|
@ -1243,3 +1260,154 @@ async def _process_circleback_meeting(
|
|||
heartbeat_task.cancel()
|
||||
if notification:
|
||||
_stop_heartbeat(notification.id)
|
||||
|
||||
|
||||
# ===== Local folder indexing task =====
|
||||
|
||||
|
||||
@celery_app.task(name="index_local_folder", bind=True)
|
||||
def index_local_folder_task(
|
||||
self,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
folder_path: str,
|
||||
folder_name: str,
|
||||
exclude_patterns: list[str] | None = None,
|
||||
file_extensions: list[str] | None = None,
|
||||
root_folder_id: int | None = None,
|
||||
enable_summary: bool = False,
|
||||
target_file_paths: list[str] | None = None,
|
||||
):
|
||||
"""Celery task to index a local folder. Config is passed directly — no connector row."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_local_folder_async(
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_local_folder_async(
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
folder_path: str,
|
||||
folder_name: str,
|
||||
exclude_patterns: list[str] | None = None,
|
||||
file_extensions: list[str] | None = None,
|
||||
root_folder_id: int | None = None,
|
||||
enable_summary: bool = False,
|
||||
target_file_paths: list[str] | None = None,
|
||||
):
|
||||
"""Run local folder indexing with notification + heartbeat."""
|
||||
is_batch = bool(target_file_paths)
|
||||
is_full_scan = not target_file_paths
|
||||
file_count = len(target_file_paths) if target_file_paths else None
|
||||
|
||||
if is_batch:
|
||||
doc_name = f"{folder_name} ({file_count} file{'s' if file_count != 1 else ''})"
|
||||
else:
|
||||
doc_name = folder_name
|
||||
|
||||
notification = None
|
||||
notification_id: int | None = None
|
||||
heartbeat_task = None
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
try:
|
||||
notification = (
|
||||
await NotificationService.document_processing.notify_processing_started(
|
||||
session=session,
|
||||
user_id=UUID(user_id),
|
||||
document_type="LOCAL_FOLDER_FILE",
|
||||
document_name=doc_name,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
)
|
||||
notification_id = notification.id
|
||||
_start_heartbeat(notification_id)
|
||||
heartbeat_task = asyncio.create_task(_run_heartbeat_loop(notification_id))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create notification for local folder indexing",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _heartbeat_progress(completed_count: int) -> None:
|
||||
"""Refresh heartbeat and optionally update notification progress."""
|
||||
if notification:
|
||||
with contextlib.suppress(Exception):
|
||||
await NotificationService.document_processing.notify_processing_progress(
|
||||
session=session,
|
||||
notification=notification,
|
||||
stage="indexing",
|
||||
stage_message=f"Syncing files ({completed_count}/{file_count or '?'})",
|
||||
)
|
||||
|
||||
try:
|
||||
_indexed, _skipped_or_failed, _rfid, err = await index_local_folder(
|
||||
session=session,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
folder_path=folder_path,
|
||||
folder_name=folder_name,
|
||||
exclude_patterns=exclude_patterns,
|
||||
file_extensions=file_extensions,
|
||||
root_folder_id=root_folder_id,
|
||||
enable_summary=enable_summary,
|
||||
target_file_paths=target_file_paths,
|
||||
on_heartbeat_callback=_heartbeat_progress
|
||||
if (is_batch or is_full_scan)
|
||||
else None,
|
||||
)
|
||||
|
||||
if notification:
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
if err:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=err,
|
||||
)
|
||||
else:
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to update notification after local folder indexing",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Local folder indexing failed: {e}")
|
||||
if notification:
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.document_processing.notify_processing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
error_message=str(e)[:200],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
finally:
|
||||
if heartbeat_task:
|
||||
heartbeat_task.cancel()
|
||||
if notification_id is not None:
|
||||
_stop_heartbeat(notification_id)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,6 @@ from app.agents.new_chat.llm_config import (
|
|||
)
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
Document,
|
||||
NewChatMessage,
|
||||
NewChatThread,
|
||||
Report,
|
||||
|
|
@ -63,74 +62,6 @@ _perf_log = get_perf_logger()
|
|||
_background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def format_mentioned_documents_as_context(documents: list[Document]) -> str:
|
||||
"""
|
||||
Format mentioned documents as context for the agent.
|
||||
|
||||
Uses the same XML structure as knowledge_base.format_documents_for_context
|
||||
to ensure citations work properly with chunk IDs.
|
||||
"""
|
||||
if not documents:
|
||||
return ""
|
||||
|
||||
context_parts = ["<mentioned_documents>"]
|
||||
context_parts.append(
|
||||
"The user has explicitly mentioned the following documents from their knowledge base. "
|
||||
"These documents are directly relevant to the query and should be prioritized as primary sources. "
|
||||
"Use [citation:CHUNK_ID] format for citations (e.g., [citation:123])."
|
||||
)
|
||||
context_parts.append("")
|
||||
|
||||
for doc in documents:
|
||||
# Build metadata JSON
|
||||
metadata = doc.document_metadata or {}
|
||||
metadata_json = json.dumps(metadata, ensure_ascii=False)
|
||||
|
||||
# Get URL from metadata
|
||||
url = (
|
||||
metadata.get("url")
|
||||
or metadata.get("source")
|
||||
or metadata.get("page_url")
|
||||
or ""
|
||||
)
|
||||
|
||||
context_parts.append("<document>")
|
||||
context_parts.append("<document_metadata>")
|
||||
context_parts.append(f" <document_id>{doc.id}</document_id>")
|
||||
context_parts.append(
|
||||
f" <document_type>{doc.document_type.value}</document_type>"
|
||||
)
|
||||
context_parts.append(f" <title><![CDATA[{doc.title}]]></title>")
|
||||
context_parts.append(f" <url><![CDATA[{url}]]></url>")
|
||||
context_parts.append(
|
||||
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>"
|
||||
)
|
||||
context_parts.append("</document_metadata>")
|
||||
context_parts.append("")
|
||||
context_parts.append("<document_content>")
|
||||
|
||||
# Use chunks if available (preferred for proper citations)
|
||||
if hasattr(doc, "chunks") and doc.chunks:
|
||||
for chunk in doc.chunks:
|
||||
context_parts.append(
|
||||
f" <chunk id='{chunk.id}'><![CDATA[{chunk.content}]]></chunk>"
|
||||
)
|
||||
else:
|
||||
# Fallback to document content if chunks not loaded
|
||||
# Use document ID as chunk ID prefix for consistency
|
||||
context_parts.append(
|
||||
f" <chunk id='{doc.id}'><![CDATA[{doc.content}]]></chunk>"
|
||||
)
|
||||
|
||||
context_parts.append("</document_content>")
|
||||
context_parts.append("</document>")
|
||||
context_parts.append("")
|
||||
|
||||
context_parts.append("</mentioned_documents>")
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
|
||||
def format_mentioned_surfsense_docs_as_context(
|
||||
documents: list[SurfsenseDocsDocument],
|
||||
) -> str:
|
||||
|
|
@ -1317,6 +1248,7 @@ async def stream_new_chat(
|
|||
firecrawl_api_key=firecrawl_api_key,
|
||||
thread_visibility=visibility,
|
||||
disabled_tools=disabled_tools,
|
||||
mentioned_document_ids=mentioned_document_ids,
|
||||
)
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Agent created in %.3fs", time.perf_counter() - _t0
|
||||
|
|
@ -1340,18 +1272,9 @@ async def stream_new_chat(
|
|||
thread.needs_history_bootstrap = False
|
||||
await session.commit()
|
||||
|
||||
# Fetch mentioned documents if any (with chunks for proper citations)
|
||||
mentioned_documents: list[Document] = []
|
||||
if mentioned_document_ids:
|
||||
result = await session.execute(
|
||||
select(Document)
|
||||
.options(selectinload(Document.chunks))
|
||||
.filter(
|
||||
Document.id.in_(mentioned_document_ids),
|
||||
Document.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
mentioned_documents = list(result.scalars().all())
|
||||
# Mentioned KB documents are now handled by KnowledgeBaseSearchMiddleware
|
||||
# which merges them into the scoped filesystem with full document
|
||||
# structure. Only SurfSense docs and report context are inlined here.
|
||||
|
||||
# Fetch mentioned SurfSense docs if any
|
||||
mentioned_surfsense_docs: list[SurfsenseDocsDocument] = []
|
||||
|
|
@ -1379,15 +1302,10 @@ async def stream_new_chat(
|
|||
)
|
||||
recent_reports = list(recent_reports_result.scalars().all())
|
||||
|
||||
# Format the user query with context (mentioned documents + SurfSense docs)
|
||||
# Format the user query with context (SurfSense docs + reports only)
|
||||
final_query = user_query
|
||||
context_parts = []
|
||||
|
||||
if mentioned_documents:
|
||||
context_parts.append(
|
||||
format_mentioned_documents_as_context(mentioned_documents)
|
||||
)
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
context_parts.append(
|
||||
format_mentioned_surfsense_docs_as_context(mentioned_surfsense_docs)
|
||||
|
|
@ -1479,7 +1397,7 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_start_step()
|
||||
|
||||
# Initial thinking step - analyzing the request
|
||||
if mentioned_documents or mentioned_surfsense_docs:
|
||||
if mentioned_surfsense_docs:
|
||||
initial_title = "Analyzing referenced content"
|
||||
action_verb = "Analyzing"
|
||||
else:
|
||||
|
|
@ -1490,18 +1408,6 @@ async def stream_new_chat(
|
|||
query_text = user_query[:80] + ("..." if len(user_query) > 80 else "")
|
||||
processing_parts.append(query_text)
|
||||
|
||||
if mentioned_documents:
|
||||
doc_names = []
|
||||
for doc in mentioned_documents:
|
||||
title = doc.title
|
||||
if len(title) > 30:
|
||||
title = title[:27] + "..."
|
||||
doc_names.append(title)
|
||||
if len(doc_names) == 1:
|
||||
processing_parts.append(f"[{doc_names[0]}]")
|
||||
else:
|
||||
processing_parts.append(f"[{len(doc_names)} documents]")
|
||||
|
||||
if mentioned_surfsense_docs:
|
||||
doc_names = []
|
||||
for doc in mentioned_surfsense_docs:
|
||||
|
|
@ -1527,7 +1433,7 @@ async def stream_new_chat(
|
|||
# These ORM objects (with eagerly-loaded chunks) can be very large.
|
||||
# They're only needed to build context strings already copied into
|
||||
# final_query / langchain_messages — release them before streaming.
|
||||
del mentioned_documents, mentioned_surfsense_docs, recent_reports
|
||||
del mentioned_surfsense_docs, recent_reports
|
||||
del langchain_messages, final_query
|
||||
|
||||
# Check if this is the first assistant response so we can generate
|
||||
|
|
|
|||
|
|
@ -42,9 +42,9 @@ from .jira_indexer import index_jira_issues
|
|||
|
||||
# Issue tracking and project management
|
||||
from .linear_indexer import index_linear_issues
|
||||
from .luma_indexer import index_luma_events
|
||||
|
||||
# Documentation and knowledge management
|
||||
from .luma_indexer import index_luma_events
|
||||
from .notion_indexer import index_notion_pages
|
||||
from .obsidian_indexer import index_obsidian_vault
|
||||
from .slack_indexer import index_slack_messages
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -12,16 +12,14 @@ Available processors:
|
|||
- YouTube processor: Process YouTube videos and extract transcripts
|
||||
"""
|
||||
|
||||
# URL crawler
|
||||
# Extension processor
|
||||
from .extension_processor import add_extension_received_document
|
||||
|
||||
# File processors
|
||||
from .file_processors import (
|
||||
# File processors (backward-compatible re-exports from _save)
|
||||
from ._save import (
|
||||
add_received_file_document_using_docling,
|
||||
add_received_file_document_using_llamacloud,
|
||||
add_received_file_document_using_unstructured,
|
||||
)
|
||||
from .extension_processor import add_extension_received_document
|
||||
|
||||
# Markdown processor
|
||||
from .markdown_processor import add_received_markdown_file_document
|
||||
|
|
@ -32,9 +30,9 @@ from .youtube_processor import add_youtube_video_document
|
|||
__all__ = [
|
||||
# Extension processing
|
||||
"add_extension_received_document",
|
||||
# File processing with different ETL services
|
||||
"add_received_file_document_using_docling",
|
||||
"add_received_file_document_using_llamacloud",
|
||||
# File processing with different ETL services
|
||||
"add_received_file_document_using_unstructured",
|
||||
# Markdown file processing
|
||||
"add_received_markdown_file_document",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
"""
|
||||
Constants for file document processing.
|
||||
|
||||
Centralizes file type classification, LlamaCloud retry configuration,
|
||||
and timeout calculation parameters.
|
||||
"""
|
||||
|
||||
import ssl
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File type classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MARKDOWN_EXTENSIONS = (".md", ".markdown", ".txt")
|
||||
AUDIO_EXTENSIONS = (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")
|
||||
DIRECT_CONVERT_EXTENSIONS = (".csv", ".tsv", ".html", ".htm")
|
||||
|
||||
|
||||
class FileCategory(Enum):
|
||||
MARKDOWN = "markdown"
|
||||
AUDIO = "audio"
|
||||
DIRECT_CONVERT = "direct_convert"
|
||||
DOCUMENT = "document"
|
||||
|
||||
|
||||
def classify_file(filename: str) -> FileCategory:
|
||||
"""Classify a file by its extension into a processing category."""
|
||||
lower = filename.lower()
|
||||
if lower.endswith(MARKDOWN_EXTENSIONS):
|
||||
return FileCategory.MARKDOWN
|
||||
if lower.endswith(AUDIO_EXTENSIONS):
|
||||
return FileCategory.AUDIO
|
||||
if lower.endswith(DIRECT_CONVERT_EXTENSIONS):
|
||||
return FileCategory.DIRECT_CONVERT
|
||||
return FileCategory.DOCUMENT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LlamaCloud retry configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
LLAMACLOUD_MAX_RETRIES = 5
|
||||
LLAMACLOUD_BASE_DELAY = 10 # seconds (exponential backoff base)
|
||||
LLAMACLOUD_MAX_DELAY = 120 # max delay between retries (2 minutes)
|
||||
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
|
||||
ssl.SSLError,
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.WriteError,
|
||||
httpx.WriteTimeout,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.LocalProtocolError,
|
||||
ConnectionError,
|
||||
ConnectionResetError,
|
||||
TimeoutError,
|
||||
OSError,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeout calculation constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
UPLOAD_BYTES_PER_SECOND_SLOW = (
|
||||
100 * 1024
|
||||
) # 100 KB/s (conservative for slow connections)
|
||||
MIN_UPLOAD_TIMEOUT = 120 # Minimum 2 minutes for any file
|
||||
MAX_UPLOAD_TIMEOUT = 1800 # Maximum 30 minutes for very large files
|
||||
BASE_JOB_TIMEOUT = 600 # 10 minutes base for job processing
|
||||
PER_PAGE_JOB_TIMEOUT = 60 # 1 minute per page for processing
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
"""
|
||||
Lossless file-to-markdown converters for text-based formats.
|
||||
|
||||
These converters handle file types that can be faithfully represented as
|
||||
markdown without any external ETL/OCR service:
|
||||
|
||||
- CSV / TSV → markdown table (stdlib ``csv``)
|
||||
- HTML / HTM → markdown (``markdownify``)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from markdownify import markdownify
|
||||
|
||||
# The stdlib csv module defaults to a 128 KB field-size limit which is too
|
||||
# small for real-world exports (e.g. chat logs, CRM dumps). We raise it once
|
||||
# at import time so every csv.reader call in this module can handle large fields.
|
||||
csv.field_size_limit(2**31 - 1)
|
||||
|
||||
|
||||
def _escape_pipe(cell: str) -> str:
|
||||
"""Escape literal pipe characters inside a markdown table cell."""
|
||||
return cell.replace("|", "\\|")
|
||||
|
||||
|
||||
def csv_to_markdown(file_path: str, *, delimiter: str = ",") -> str:
|
||||
"""Convert a CSV (or TSV) file to a markdown table.
|
||||
|
||||
The first row is treated as the header. An empty file returns an
|
||||
empty string so the caller can decide how to handle it.
|
||||
"""
|
||||
with open(file_path, encoding="utf-8", newline="") as fh:
|
||||
reader = csv.reader(fh, delimiter=delimiter)
|
||||
rows = list(reader)
|
||||
|
||||
if not rows:
|
||||
return ""
|
||||
|
||||
header, *body = rows
|
||||
col_count = len(header)
|
||||
|
||||
lines: list[str] = []
|
||||
|
||||
header_cells = [_escape_pipe(c.strip()) for c in header]
|
||||
lines.append("| " + " | ".join(header_cells) + " |")
|
||||
lines.append("| " + " | ".join(["---"] * col_count) + " |")
|
||||
|
||||
for row in body:
|
||||
padded = row + [""] * (col_count - len(row))
|
||||
cells = [_escape_pipe(c.strip()) for c in padded[:col_count]]
|
||||
lines.append("| " + " | ".join(cells) + " |")
|
||||
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def tsv_to_markdown(file_path: str) -> str:
|
||||
"""Convert a TSV file to a markdown table."""
|
||||
return csv_to_markdown(file_path, delimiter="\t")
|
||||
|
||||
|
||||
def html_to_markdown(file_path: str) -> str:
|
||||
"""Convert an HTML file to markdown via ``markdownify``."""
|
||||
html = Path(file_path).read_text(encoding="utf-8")
|
||||
return markdownify(html).strip()
|
||||
|
||||
|
||||
_CONVERTER_MAP: dict[str, Callable[..., str]] = {
|
||||
".csv": csv_to_markdown,
|
||||
".tsv": tsv_to_markdown,
|
||||
".html": html_to_markdown,
|
||||
".htm": html_to_markdown,
|
||||
}
|
||||
|
||||
|
||||
def convert_file_directly(file_path: str, filename: str) -> str:
|
||||
"""Dispatch to the appropriate lossless converter based on file extension.
|
||||
|
||||
Raises ``ValueError`` if the extension is not supported.
|
||||
"""
|
||||
suffix = Path(filename).suffix.lower()
|
||||
converter = _CONVERTER_MAP.get(suffix)
|
||||
if converter is None:
|
||||
raise ValueError(
|
||||
f"No direct converter for extension '{suffix}' (file: {filename})"
|
||||
)
|
||||
return converter(file_path)
|
||||
209
surfsense_backend/app/tasks/document_processors/_etl.py
Normal file
209
surfsense_backend/app/tasks/document_processors/_etl.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
"""
|
||||
ETL parsing strategies for different document processing services.
|
||||
|
||||
Provides parse functions for Unstructured, LlamaCloud, and Docling, along with
|
||||
LlamaCloud retry logic and dynamic timeout calculations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from logging import ERROR, getLogger
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import Log
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
|
||||
from ._constants import (
|
||||
LLAMACLOUD_BASE_DELAY,
|
||||
LLAMACLOUD_MAX_DELAY,
|
||||
LLAMACLOUD_MAX_RETRIES,
|
||||
LLAMACLOUD_RETRYABLE_EXCEPTIONS,
|
||||
PER_PAGE_JOB_TIMEOUT,
|
||||
)
|
||||
from ._helpers import calculate_job_timeout, calculate_upload_timeout
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LlamaCloud parsing with retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def parse_with_llamacloud_retry(
|
||||
file_path: str,
|
||||
estimated_pages: int,
|
||||
task_logger: TaskLoggingService | None = None,
|
||||
log_entry: Log | None = None,
|
||||
):
|
||||
"""
|
||||
Parse a file with LlamaCloud with retry logic for transient SSL/connection errors.
|
||||
|
||||
Uses dynamic timeout calculations based on file size and page count to handle
|
||||
very large files reliably.
|
||||
|
||||
Returns:
|
||||
LlamaParse result object
|
||||
|
||||
Raises:
|
||||
Exception: If all retries fail
|
||||
"""
|
||||
from llama_cloud_services import LlamaParse
|
||||
from llama_cloud_services.parse.utils import ResultType
|
||||
|
||||
file_size_bytes = os.path.getsize(file_path)
|
||||
file_size_mb = file_size_bytes / (1024 * 1024)
|
||||
|
||||
upload_timeout = calculate_upload_timeout(file_size_bytes)
|
||||
job_timeout = calculate_job_timeout(estimated_pages, file_size_bytes)
|
||||
|
||||
custom_timeout = httpx.Timeout(
|
||||
connect=120.0,
|
||||
read=upload_timeout,
|
||||
write=upload_timeout,
|
||||
pool=120.0,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"LlamaCloud upload configured: file_size={file_size_mb:.1f}MB, "
|
||||
f"pages={estimated_pages}, upload_timeout={upload_timeout:.0f}s, "
|
||||
f"job_timeout={job_timeout:.0f}s"
|
||||
)
|
||||
|
||||
last_exception = None
|
||||
attempt_errors: list[str] = []
|
||||
|
||||
for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
|
||||
parser = LlamaParse(
|
||||
api_key=app_config.LLAMA_CLOUD_API_KEY,
|
||||
num_workers=1,
|
||||
verbose=True,
|
||||
language="en",
|
||||
result_type=ResultType.MD,
|
||||
max_timeout=int(max(2000, job_timeout + upload_timeout)),
|
||||
job_timeout_in_seconds=job_timeout,
|
||||
job_timeout_extra_time_per_page_in_seconds=PER_PAGE_JOB_TIMEOUT,
|
||||
custom_client=custom_client,
|
||||
)
|
||||
result = await parser.aparse(file_path)
|
||||
|
||||
if attempt > 1:
|
||||
logging.info(
|
||||
f"LlamaCloud upload succeeded on attempt {attempt} after "
|
||||
f"{len(attempt_errors)} failures"
|
||||
)
|
||||
return result
|
||||
|
||||
except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
|
||||
last_exception = e
|
||||
error_type = type(e).__name__
|
||||
error_msg = str(e)[:200]
|
||||
attempt_errors.append(f"Attempt {attempt}: {error_type} - {error_msg}")
|
||||
|
||||
if attempt < LLAMACLOUD_MAX_RETRIES:
|
||||
base_delay = min(
|
||||
LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)),
|
||||
LLAMACLOUD_MAX_DELAY,
|
||||
)
|
||||
jitter = base_delay * 0.25 * (2 * random.random() - 1)
|
||||
delay = base_delay + jitter
|
||||
|
||||
if task_logger and log_entry:
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"LlamaCloud upload failed "
|
||||
f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}), "
|
||||
f"retrying in {delay:.0f}s",
|
||||
{
|
||||
"error_type": error_type,
|
||||
"error_message": error_msg,
|
||||
"attempt": attempt,
|
||||
"retry_delay": delay,
|
||||
"file_size_mb": round(file_size_mb, 1),
|
||||
"upload_timeout": upload_timeout,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
f"LlamaCloud upload failed "
|
||||
f"(attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): "
|
||||
f"{error_type}. File: {file_size_mb:.1f}MB. "
|
||||
f"Retrying in {delay:.0f}s..."
|
||||
)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
logging.error(
|
||||
f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} "
|
||||
f"attempts. File size: {file_size_mb:.1f}MB, "
|
||||
f"Pages: {estimated_pages}. "
|
||||
f"Errors: {'; '.join(attempt_errors)}"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
raise last_exception or RuntimeError(
|
||||
f"LlamaCloud parsing failed after {LLAMACLOUD_MAX_RETRIES} retries. "
|
||||
f"File size: {file_size_mb:.1f}MB"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-service parse functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def parse_with_unstructured(file_path: str):
|
||||
"""
|
||||
Parse a file using the Unstructured ETL service.
|
||||
|
||||
Returns:
|
||||
List of LangChain Document elements.
|
||||
"""
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
|
||||
loader = UnstructuredLoader(
|
||||
file_path,
|
||||
mode="elements",
|
||||
post_processors=[],
|
||||
languages=["eng"],
|
||||
include_orig_elements=False,
|
||||
include_metadata=False,
|
||||
strategy="auto",
|
||||
)
|
||||
return await loader.aload()
|
||||
|
||||
|
||||
async def parse_with_docling(file_path: str, filename: str) -> str:
|
||||
"""
|
||||
Parse a file using the Docling ETL service (via the Docling service wrapper).
|
||||
|
||||
Returns:
|
||||
Markdown content string.
|
||||
"""
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
||||
docling_service = create_docling_service()
|
||||
|
||||
pdfminer_logger = getLogger("pdfminer")
|
||||
original_level = pdfminer_logger.level
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="pdfminer")
|
||||
warnings.filterwarnings(
|
||||
"ignore", message=".*Cannot set gray non-stroke color.*"
|
||||
)
|
||||
warnings.filterwarnings("ignore", message=".*invalid float value.*")
|
||||
pdfminer_logger.setLevel(ERROR)
|
||||
|
||||
try:
|
||||
result = await docling_service.process_document(file_path, filename)
|
||||
finally:
|
||||
pdfminer_logger.setLevel(original_level)
|
||||
|
||||
return result["content"]
|
||||
218
surfsense_backend/app/tasks/document_processors/_helpers.py
Normal file
218
surfsense_backend/app/tasks/document_processors/_helpers.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
"""
|
||||
Document helper functions for deduplication, migration, and connector updates.
|
||||
|
||||
Provides reusable logic shared across file processors and ETL strategies.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
|
||||
from ._constants import (
|
||||
BASE_JOB_TIMEOUT,
|
||||
MAX_UPLOAD_TIMEOUT,
|
||||
MIN_UPLOAD_TIMEOUT,
|
||||
PER_PAGE_JOB_TIMEOUT,
|
||||
UPLOAD_BYTES_PER_SECOND_SLOW,
|
||||
)
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unique identifier helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_google_drive_unique_identifier(
|
||||
connector: dict | None,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
) -> tuple[str, str | None]:
|
||||
"""
|
||||
Get unique identifier hash, using file_id for Google Drive (stable across renames).
|
||||
|
||||
Returns:
|
||||
Tuple of (primary_hash, legacy_hash or None).
|
||||
For Google Drive: (file_id-based hash, filename-based hash for migration).
|
||||
For other sources: (filename-based hash, None).
|
||||
"""
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
metadata = connector.get("metadata", {})
|
||||
file_id = metadata.get("google_drive_file_id")
|
||||
|
||||
if file_id:
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
|
||||
)
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, filename, search_space_id
|
||||
)
|
||||
return primary_hash, legacy_hash
|
||||
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, filename, search_space_id
|
||||
)
|
||||
return primary_hash, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document deduplication and migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_existing_document_update(
|
||||
session: AsyncSession,
|
||||
existing_document: Document,
|
||||
content_hash: str,
|
||||
connector: dict | None,
|
||||
filename: str,
|
||||
primary_hash: str,
|
||||
) -> tuple[bool, Document | None]:
|
||||
"""
|
||||
Handle update logic for an existing document.
|
||||
|
||||
Returns:
|
||||
Tuple of (should_skip_processing, document_to_return):
|
||||
- (True, document): Content unchanged, return existing document
|
||||
- (False, None): Content changed, needs re-processing
|
||||
"""
|
||||
if existing_document.unique_identifier_hash != primary_hash:
|
||||
existing_document.unique_identifier_hash = primary_hash
|
||||
logging.info(f"Migrated document to file_id-based identifier: {filename}")
|
||||
|
||||
if existing_document.content_hash == content_hash:
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
connector_metadata = connector.get("metadata", {})
|
||||
new_name = connector_metadata.get("google_drive_file_name")
|
||||
doc_metadata = existing_document.document_metadata or {}
|
||||
old_name = doc_metadata.get("FILE_NAME") or doc_metadata.get(
|
||||
"google_drive_file_name"
|
||||
)
|
||||
|
||||
if new_name and old_name and old_name != new_name:
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
existing_document.title = new_name
|
||||
if not existing_document.document_metadata:
|
||||
existing_document.document_metadata = {}
|
||||
existing_document.document_metadata["FILE_NAME"] = new_name
|
||||
existing_document.document_metadata["google_drive_file_name"] = new_name
|
||||
flag_modified(existing_document, "document_metadata")
|
||||
await session.commit()
|
||||
logging.info(
|
||||
f"File renamed in Google Drive: '{old_name}' → '{new_name}' "
|
||||
f"(no re-processing needed)"
|
||||
)
|
||||
|
||||
logging.info(f"Document for file {filename} unchanged. Skipping.")
|
||||
return True, existing_document
|
||||
|
||||
# Content has changed — guard against content_hash collision before
|
||||
# expensive ETL processing.
|
||||
collision_doc = await check_duplicate_document(session, content_hash)
|
||||
if collision_doc and collision_doc.id != existing_document.id:
|
||||
logging.warning(
|
||||
"Content-hash collision for %s: identical content exists in "
|
||||
"document #%s (%s). Skipping re-processing.",
|
||||
filename,
|
||||
collision_doc.id,
|
||||
collision_doc.document_type,
|
||||
)
|
||||
if DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.PENDING
|
||||
) or DocumentStatus.is_state(
|
||||
existing_document.status, DocumentStatus.PROCESSING
|
||||
):
|
||||
await session.delete(existing_document)
|
||||
await session.commit()
|
||||
return True, None
|
||||
|
||||
return True, existing_document
|
||||
|
||||
logging.info(f"Content changed for file {filename}. Updating document.")
|
||||
return False, None
|
||||
|
||||
|
||||
async def find_existing_document_with_migration(
|
||||
session: AsyncSession,
|
||||
primary_hash: str,
|
||||
legacy_hash: str | None,
|
||||
content_hash: str | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Find existing document, checking primary hash, legacy hash, and content_hash.
|
||||
|
||||
Supports migration from filename-based to file_id-based hashing for
|
||||
Google Drive files, with content_hash fallback for cross-source dedup.
|
||||
"""
|
||||
existing_document = await check_document_by_unique_identifier(session, primary_hash)
|
||||
|
||||
if not existing_document and legacy_hash:
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
logging.info(
|
||||
"Found legacy document (filename-based hash), "
|
||||
"will migrate to file_id-based hash"
|
||||
)
|
||||
|
||||
if not existing_document and content_hash:
|
||||
existing_document = await check_duplicate_document(session, content_hash)
|
||||
if existing_document:
|
||||
logging.info(
|
||||
f"Found duplicate content from different source (content_hash match). "
|
||||
f"Original document ID: {existing_document.id}, "
|
||||
f"type: {existing_document.document_type}"
|
||||
)
|
||||
|
||||
return existing_document
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Connector helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def update_document_from_connector(
|
||||
document: Document | None,
|
||||
connector: dict | None,
|
||||
session: AsyncSession,
|
||||
) -> None:
|
||||
"""Update document type, metadata, and connector_id from connector info."""
|
||||
if not document or not connector:
|
||||
return
|
||||
if "type" in connector:
|
||||
document.document_type = connector["type"]
|
||||
if "metadata" in connector:
|
||||
if not document.document_metadata:
|
||||
document.document_metadata = connector["metadata"]
|
||||
else:
|
||||
merged = {**document.document_metadata, **connector["metadata"]}
|
||||
document.document_metadata = merged
|
||||
if "connector_id" in connector:
|
||||
document.connector_id = connector["connector_id"]
|
||||
await session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeout calculations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def calculate_upload_timeout(file_size_bytes: int) -> float:
|
||||
"""Calculate upload timeout based on file size (conservative for slow connections)."""
|
||||
estimated_time = (file_size_bytes / UPLOAD_BYTES_PER_SECOND_SLOW) * 1.5
|
||||
return max(MIN_UPLOAD_TIMEOUT, min(estimated_time, MAX_UPLOAD_TIMEOUT))
|
||||
|
||||
|
||||
def calculate_job_timeout(estimated_pages: int, file_size_bytes: int) -> float:
|
||||
"""Calculate job processing timeout based on page count and file size."""
|
||||
page_based_timeout = BASE_JOB_TIMEOUT + (estimated_pages * PER_PAGE_JOB_TIMEOUT)
|
||||
size_based_timeout = BASE_JOB_TIMEOUT + (file_size_bytes / (10 * 1024 * 1024)) * 60
|
||||
return max(page_based_timeout, size_based_timeout)
|
||||
285
surfsense_backend/app/tasks/document_processors/_save.py
Normal file
285
surfsense_backend/app/tasks/document_processors/_save.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""
|
||||
Unified document save/update logic for file processors.
|
||||
|
||||
Replaces the three nearly-identical ``add_received_file_document_using_*``
|
||||
functions with a single ``save_file_document`` function plus thin wrappers
|
||||
for backward compatibility.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentStatus, DocumentType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
embed_text,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
)
|
||||
|
||||
from ._helpers import (
|
||||
find_existing_document_with_migration,
|
||||
get_google_drive_unique_identifier,
|
||||
handle_existing_document_update,
|
||||
)
|
||||
from .base import get_current_timestamp, safe_set_chunks
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Summary generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _generate_summary(
|
||||
markdown_content: str,
|
||||
file_name: str,
|
||||
etl_service: str,
|
||||
user_llm,
|
||||
enable_summary: bool,
|
||||
) -> tuple[str, list[float]]:
|
||||
"""
|
||||
Generate a document summary and embedding.
|
||||
|
||||
Docling uses its own large-document summary strategy; other ETL services
|
||||
use the standard ``generate_document_summary`` helper.
|
||||
"""
|
||||
if not enable_summary:
|
||||
summary = f"File: {file_name}\n\n{markdown_content[:4000]}"
|
||||
return summary, embed_text(summary)
|
||||
|
||||
if etl_service == "DOCLING":
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
||||
docling_service = create_docling_service()
|
||||
summary_text = await docling_service.process_large_document_summary(
|
||||
content=markdown_content, llm=user_llm, document_title=file_name
|
||||
)
|
||||
|
||||
meta = {
|
||||
"file_name": file_name,
|
||||
"etl_service": etl_service,
|
||||
"document_type": "File Document",
|
||||
}
|
||||
parts = ["# DOCUMENT METADATA"]
|
||||
for key, value in meta.items():
|
||||
if value:
|
||||
formatted_key = key.replace("_", " ").title()
|
||||
parts.append(f"**{formatted_key}:** {value}")
|
||||
|
||||
enhanced = "\n".join(parts) + "\n\n# DOCUMENT SUMMARY\n\n" + summary_text
|
||||
return enhanced, embed_text(enhanced)
|
||||
|
||||
# Standard summary (Unstructured / LlamaCloud / others)
|
||||
meta = {
|
||||
"file_name": file_name,
|
||||
"etl_service": etl_service,
|
||||
"document_type": "File Document",
|
||||
}
|
||||
return await generate_document_summary(markdown_content, user_llm, meta)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unified save function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def save_file_document(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
markdown_content: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
etl_service: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Process and store a file document with deduplication and migration support.
|
||||
|
||||
Handles both creating new documents and updating existing ones. This is
|
||||
the single implementation behind the per-ETL-service wrapper functions.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
file_name: Name of the processed file
|
||||
markdown_content: Markdown content to store
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
etl_service: Name of the ETL service (UNSTRUCTURED, LLAMACLOUD, DOCLING)
|
||||
connector: Optional connector info for Google Drive files
|
||||
enable_summary: Whether to generate an AI summary
|
||||
|
||||
Returns:
|
||||
Document object if successful, None if duplicate detected
|
||||
"""
|
||||
try:
|
||||
primary_hash, legacy_hash = get_google_drive_unique_identifier(
|
||||
connector, file_name, search_space_id
|
||||
)
|
||||
content_hash = generate_content_hash(markdown_content, search_space_id)
|
||||
|
||||
existing_document = await find_existing_document_with_migration(
|
||||
session, primary_hash, legacy_hash, content_hash
|
||||
)
|
||||
|
||||
if existing_document:
|
||||
should_skip, doc = await handle_existing_document_update(
|
||||
session,
|
||||
existing_document,
|
||||
content_hash,
|
||||
connector,
|
||||
file_name,
|
||||
primary_hash,
|
||||
)
|
||||
if should_skip:
|
||||
return doc
|
||||
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} "
|
||||
f"in search space {search_space_id}"
|
||||
)
|
||||
|
||||
summary_content, summary_embedding = await _generate_summary(
|
||||
markdown_content, file_name, etl_service, user_llm, enable_summary
|
||||
)
|
||||
chunks = await create_document_chunks(markdown_content)
|
||||
doc_metadata = {"FILE_NAME": file_name, "ETL_SERVICE": etl_service}
|
||||
|
||||
if existing_document:
|
||||
existing_document.title = file_name
|
||||
existing_document.content = summary_content
|
||||
existing_document.content_hash = content_hash
|
||||
existing_document.embedding = summary_embedding
|
||||
existing_document.document_metadata = doc_metadata
|
||||
await safe_set_chunks(session, existing_document, chunks)
|
||||
existing_document.source_markdown = markdown_content
|
||||
existing_document.content_needs_reindexing = False
|
||||
existing_document.updated_at = get_current_timestamp()
|
||||
existing_document.status = DocumentStatus.ready()
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(existing_document)
|
||||
return existing_document
|
||||
|
||||
doc_type = DocumentType.FILE
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
doc_type = DocumentType.GOOGLE_DRIVE_FILE
|
||||
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file_name,
|
||||
document_type=doc_type,
|
||||
document_metadata=doc_metadata,
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks,
|
||||
content_hash=content_hash,
|
||||
unique_identifier_hash=primary_hash,
|
||||
source_markdown=markdown_content,
|
||||
content_needs_reindexing=False,
|
||||
updated_at=get_current_timestamp(),
|
||||
created_by_id=user_id,
|
||||
connector_id=connector.get("connector_id") if connector else None,
|
||||
status=DocumentStatus.ready(),
|
||||
)
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
return document
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
if "ix_documents_content_hash" in str(db_error):
|
||||
logging.warning(
|
||||
"content_hash collision during commit for %s (%s). Skipping.",
|
||||
file_name,
|
||||
etl_service,
|
||||
)
|
||||
return None
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(
|
||||
f"Failed to process file document using {etl_service}: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward-compatible wrapper functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def add_received_file_document_using_unstructured(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
unstructured_processed_elements: list[LangChainDocument],
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""Process and store a file document using the Unstructured service."""
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
|
||||
markdown_content = await convert_document_to_markdown(
|
||||
unstructured_processed_elements
|
||||
)
|
||||
return await save_file_document(
|
||||
session,
|
||||
file_name,
|
||||
markdown_content,
|
||||
search_space_id,
|
||||
user_id,
|
||||
"UNSTRUCTURED",
|
||||
connector,
|
||||
enable_summary,
|
||||
)
|
||||
|
||||
|
||||
async def add_received_file_document_using_llamacloud(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
llamacloud_markdown_document: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""Process and store document content parsed by LlamaCloud."""
|
||||
return await save_file_document(
|
||||
session,
|
||||
file_name,
|
||||
llamacloud_markdown_document,
|
||||
search_space_id,
|
||||
user_id,
|
||||
"LLAMACLOUD",
|
||||
connector,
|
||||
enable_summary,
|
||||
)
|
||||
|
||||
|
||||
async def add_received_file_document_using_docling(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
docling_markdown_document: str,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
connector: dict | None = None,
|
||||
enable_summary: bool = True,
|
||||
) -> Document | None:
|
||||
"""Process and store document content parsed by Docling."""
|
||||
return await save_file_document(
|
||||
session,
|
||||
file_name,
|
||||
docling_markdown_document,
|
||||
search_space_id,
|
||||
user_id,
|
||||
"DOCLING",
|
||||
connector,
|
||||
enable_summary,
|
||||
)
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -14,88 +14,19 @@ from app.utils.document_converters import (
|
|||
create_document_chunks,
|
||||
generate_content_hash,
|
||||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
|
||||
from ._helpers import (
|
||||
find_existing_document_with_migration,
|
||||
get_google_drive_unique_identifier,
|
||||
)
|
||||
from .base import (
|
||||
check_document_by_unique_identifier,
|
||||
check_duplicate_document,
|
||||
get_current_timestamp,
|
||||
safe_set_chunks,
|
||||
)
|
||||
|
||||
|
||||
def _get_google_drive_unique_identifier(
|
||||
connector: dict | None,
|
||||
filename: str,
|
||||
search_space_id: int,
|
||||
) -> tuple[str, str | None]:
|
||||
"""
|
||||
Get unique identifier hash for a file, with special handling for Google Drive.
|
||||
|
||||
For Google Drive files, uses file_id as the unique identifier (doesn't change on rename).
|
||||
For other files, uses filename.
|
||||
|
||||
Args:
|
||||
connector: Optional connector info dict with type and metadata
|
||||
filename: The filename (used for non-Google Drive files or as fallback)
|
||||
search_space_id: The search space ID
|
||||
|
||||
Returns:
|
||||
Tuple of (primary_hash, legacy_hash or None)
|
||||
"""
|
||||
if connector and connector.get("type") == DocumentType.GOOGLE_DRIVE_FILE:
|
||||
metadata = connector.get("metadata", {})
|
||||
file_id = metadata.get("google_drive_file_id")
|
||||
|
||||
if file_id:
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, file_id, search_space_id
|
||||
)
|
||||
legacy_hash = generate_unique_identifier_hash(
|
||||
DocumentType.GOOGLE_DRIVE_FILE, filename, search_space_id
|
||||
)
|
||||
return primary_hash, legacy_hash
|
||||
|
||||
primary_hash = generate_unique_identifier_hash(
|
||||
DocumentType.FILE, filename, search_space_id
|
||||
)
|
||||
return primary_hash, None
|
||||
|
||||
|
||||
async def _find_existing_document_with_migration(
|
||||
session: AsyncSession,
|
||||
primary_hash: str,
|
||||
legacy_hash: str | None,
|
||||
content_hash: str | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Find existing document, checking both new hash and legacy hash for migration,
|
||||
with fallback to content_hash for cross-source deduplication.
|
||||
"""
|
||||
existing_document = await check_document_by_unique_identifier(session, primary_hash)
|
||||
|
||||
if not existing_document and legacy_hash:
|
||||
existing_document = await check_document_by_unique_identifier(
|
||||
session, legacy_hash
|
||||
)
|
||||
if existing_document:
|
||||
logging.info(
|
||||
"Found legacy document (filename-based hash), will migrate to file_id-based hash"
|
||||
)
|
||||
|
||||
# Fallback: check by content_hash to catch duplicates from different sources
|
||||
if not existing_document and content_hash:
|
||||
existing_document = await check_duplicate_document(session, content_hash)
|
||||
if existing_document:
|
||||
logging.info(
|
||||
f"Found duplicate content from different source (content_hash match). "
|
||||
f"Original document ID: {existing_document.id}, type: {existing_document.document_type}"
|
||||
)
|
||||
|
||||
return existing_document
|
||||
|
||||
|
||||
async def _handle_existing_document_update(
|
||||
session: AsyncSession,
|
||||
existing_document: Document,
|
||||
|
|
@ -224,7 +155,7 @@ async def add_received_markdown_file_document(
|
|||
|
||||
try:
|
||||
# Generate unique identifier hash (uses file_id for Google Drive, filename for others)
|
||||
primary_hash, legacy_hash = _get_google_drive_unique_identifier(
|
||||
primary_hash, legacy_hash = get_google_drive_unique_identifier(
|
||||
connector, file_name, search_space_id
|
||||
)
|
||||
|
||||
|
|
@ -232,7 +163,7 @@ async def add_received_markdown_file_document(
|
|||
content_hash = generate_content_hash(file_in_markdown, search_space_id)
|
||||
|
||||
# Check if document exists (with migration support for Google Drive and content_hash fallback)
|
||||
existing_document = await _find_existing_document_with_migration(
|
||||
existing_document = await find_existing_document_with_migration(
|
||||
session, primary_hash, legacy_hash, content_hash
|
||||
)
|
||||
|
||||
|
|
|
|||
107
surfsense_backend/app/utils/document_versioning.py
Normal file
107
surfsense_backend/app/utils/document_versioning.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Document versioning: snapshot creation and cleanup.
|
||||
|
||||
Rules:
|
||||
- 30-minute debounce window: if the latest version was created < 30 min ago,
|
||||
overwrite it instead of creating a new row.
|
||||
- Maximum 20 versions per document.
|
||||
- Versions older than 90 days are cleaned up.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentVersion
|
||||
|
||||
MAX_VERSIONS_PER_DOCUMENT = 20
|
||||
DEBOUNCE_MINUTES = 30
|
||||
RETENTION_DAYS = 90
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
async def create_version_snapshot(
|
||||
session: AsyncSession,
|
||||
document: Document,
|
||||
) -> DocumentVersion | None:
|
||||
"""Snapshot the document's current state into a DocumentVersion row.
|
||||
|
||||
Returns the created/updated DocumentVersion, or None if nothing was done.
|
||||
"""
|
||||
now = _now()
|
||||
|
||||
latest = (
|
||||
await session.execute(
|
||||
select(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document.id)
|
||||
.order_by(DocumentVersion.version_number.desc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if latest is not None:
|
||||
age = now - latest.created_at.replace(tzinfo=UTC)
|
||||
if age < timedelta(minutes=DEBOUNCE_MINUTES):
|
||||
latest.source_markdown = document.source_markdown
|
||||
latest.content_hash = document.content_hash
|
||||
latest.title = document.title
|
||||
latest.created_at = now
|
||||
await session.flush()
|
||||
return latest
|
||||
|
||||
max_num = (
|
||||
await session.execute(
|
||||
select(func.coalesce(func.max(DocumentVersion.version_number), 0)).where(
|
||||
DocumentVersion.document_id == document.id
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
version = DocumentVersion(
|
||||
document_id=document.id,
|
||||
version_number=max_num + 1,
|
||||
source_markdown=document.source_markdown,
|
||||
content_hash=document.content_hash,
|
||||
title=document.title,
|
||||
created_at=now,
|
||||
)
|
||||
session.add(version)
|
||||
await session.flush()
|
||||
|
||||
# Cleanup: remove versions older than 90 days
|
||||
cutoff = now - timedelta(days=RETENTION_DAYS)
|
||||
await session.execute(
|
||||
delete(DocumentVersion).where(
|
||||
DocumentVersion.document_id == document.id,
|
||||
DocumentVersion.created_at < cutoff,
|
||||
)
|
||||
)
|
||||
|
||||
# Cleanup: cap at MAX_VERSIONS_PER_DOCUMENT
|
||||
count = (
|
||||
await session.execute(
|
||||
select(func.count())
|
||||
.select_from(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document.id)
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
if count > MAX_VERSIONS_PER_DOCUMENT:
|
||||
excess = count - MAX_VERSIONS_PER_DOCUMENT
|
||||
oldest_ids_result = await session.execute(
|
||||
select(DocumentVersion.id)
|
||||
.where(DocumentVersion.document_id == document.id)
|
||||
.order_by(DocumentVersion.version_number.asc())
|
||||
.limit(excess)
|
||||
)
|
||||
oldest_ids = [row[0] for row in oldest_ids_result.all()]
|
||||
if oldest_ids:
|
||||
await session.execute(
|
||||
delete(DocumentVersion).where(DocumentVersion.id.in_(oldest_ids))
|
||||
)
|
||||
|
||||
await session.flush()
|
||||
return version
|
||||
|
|
@ -2,12 +2,11 @@
|
|||
Integration tests for backend file upload limit enforcement.
|
||||
|
||||
These tests verify that the API rejects uploads that exceed:
|
||||
- Max files per upload (10)
|
||||
- Max per-file size (50 MB)
|
||||
- Max total upload size (200 MB)
|
||||
- Max per-file size (500 MB)
|
||||
|
||||
The limits mirror the frontend's DocumentUploadTab.tsx constants and are
|
||||
enforced server-side to protect against direct API calls.
|
||||
No file count or total size limits are enforced — the frontend batches
|
||||
uploads in groups of 5 and there is no cap on how many files a user can
|
||||
upload in a single session.
|
||||
|
||||
Prerequisites:
|
||||
- PostgreSQL + pgvector
|
||||
|
|
@ -24,60 +23,12 @@ pytestmark = pytest.mark.integration
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test A: File count limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileCountLimit:
|
||||
"""Uploading more than 10 files in a single request should be rejected."""
|
||||
|
||||
async def test_11_files_returns_413(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
):
|
||||
files = [
|
||||
("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
|
||||
for i in range(11)
|
||||
]
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=files,
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
assert "too many files" in resp.json()["detail"].lower()
|
||||
|
||||
async def test_10_files_accepted(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
files = [
|
||||
("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
|
||||
for i in range(10)
|
||||
]
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=files,
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
cleanup_doc_ids.extend(resp.json().get("document_ids", []))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test B: Per-file size limit
|
||||
# Test: Per-file size limit (500 MB)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPerFileSizeLimit:
|
||||
"""A single file exceeding 50 MB should be rejected."""
|
||||
"""A single file exceeding 500 MB should be rejected."""
|
||||
|
||||
async def test_oversized_file_returns_413(
|
||||
self,
|
||||
|
|
@ -85,7 +36,7 @@ class TestPerFileSizeLimit:
|
|||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
):
|
||||
oversized = io.BytesIO(b"\x00" * (50 * 1024 * 1024 + 1))
|
||||
oversized = io.BytesIO(b"\x00" * (500 * 1024 * 1024 + 1))
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
|
|
@ -102,11 +53,11 @@ class TestPerFileSizeLimit:
|
|||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
at_limit = io.BytesIO(b"\x00" * (50 * 1024 * 1024))
|
||||
at_limit = io.BytesIO(b"\x00" * (500 * 1024 * 1024))
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
headers=headers,
|
||||
files=[("files", ("exact50mb.txt", at_limit, "text/plain"))],
|
||||
files=[("files", ("exact500mb.txt", at_limit, "text/plain"))],
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
|
@ -114,26 +65,23 @@ class TestPerFileSizeLimit:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test C: Total upload size limit
|
||||
# Test: Multiple files accepted without count limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTotalSizeLimit:
|
||||
"""Multiple files whose combined size exceeds 200 MB should be rejected."""
|
||||
class TestNoFileCountLimit:
|
||||
"""Many files in a single request should be accepted."""
|
||||
|
||||
async def test_total_size_over_200mb_returns_413(
|
||||
async def test_many_files_accepted(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
headers: dict[str, str],
|
||||
search_space_id: int,
|
||||
cleanup_doc_ids: list[int],
|
||||
):
|
||||
chunk_size = 45 * 1024 * 1024 # 45 MB each
|
||||
files = [
|
||||
(
|
||||
"files",
|
||||
(f"chunk_{i}.txt", io.BytesIO(b"\x00" * chunk_size), "text/plain"),
|
||||
)
|
||||
for i in range(5) # 5 x 45 MB = 225 MB > 200 MB
|
||||
("files", (f"file_{i}.txt", io.BytesIO(b"test content"), "text/plain"))
|
||||
for i in range(20)
|
||||
]
|
||||
resp = await client.post(
|
||||
"/api/v1/documents/fileupload",
|
||||
|
|
@ -141,5 +89,5 @@ class TestTotalSizeLimit:
|
|||
files=files,
|
||||
data={"search_space_id": str(search_space_id)},
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
assert "total upload size" in resp.json()["detail"].lower()
|
||||
assert resp.status_code == 200
|
||||
cleanup_doc_ids.extend(resp.json().get("document_ids", []))
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
167
surfsense_backend/tests/integration/test_document_versioning.py
Normal file
167
surfsense_backend/tests/integration/test_document_versioning.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"""Integration tests for document versioning snapshot + cleanup."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Document, DocumentType, DocumentVersion, SearchSpace, User
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_document(
|
||||
db_session: AsyncSession, db_user: User, db_search_space: SearchSpace
|
||||
) -> Document:
|
||||
doc = Document(
|
||||
title="Test Doc",
|
||||
document_type=DocumentType.LOCAL_FOLDER_FILE,
|
||||
document_metadata={},
|
||||
content="Summary of test doc.",
|
||||
content_hash="abc123",
|
||||
unique_identifier_hash="local_folder:test-folder:test.md",
|
||||
source_markdown="# Test\n\nOriginal content.",
|
||||
search_space_id=db_search_space.id,
|
||||
created_by_id=db_user.id,
|
||||
)
|
||||
db_session.add(doc)
|
||||
await db_session.flush()
|
||||
return doc
|
||||
|
||||
|
||||
async def _version_count(session: AsyncSession, document_id: int) -> int:
|
||||
result = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document_id)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
async def _get_versions(
|
||||
session: AsyncSession, document_id: int
|
||||
) -> list[DocumentVersion]:
|
||||
result = await session.execute(
|
||||
select(DocumentVersion)
|
||||
.where(DocumentVersion.document_id == document_id)
|
||||
.order_by(DocumentVersion.version_number)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
class TestCreateVersionSnapshot:
|
||||
"""V1-V5: TDD slices for create_version_snapshot."""
|
||||
|
||||
async def test_v1_creates_first_version(self, db_session, db_document):
|
||||
"""V1: First snapshot creates version 1 with the document's current state."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
assert len(versions) == 1
|
||||
assert versions[0].version_number == 1
|
||||
assert versions[0].source_markdown == "# Test\n\nOriginal content."
|
||||
assert versions[0].content_hash == "abc123"
|
||||
assert versions[0].title == "Test Doc"
|
||||
assert versions[0].document_id == db_document.id
|
||||
|
||||
async def test_v2_creates_version_2_after_30_min(
|
||||
self, db_session, db_document, monkeypatch
|
||||
):
|
||||
"""V2: After 30+ minutes, a new version is created (not overwritten)."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
t0 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t0)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
# Simulate content change and time passing
|
||||
db_document.source_markdown = "# Test\n\nUpdated content."
|
||||
db_document.content_hash = "def456"
|
||||
t1 = t0 + timedelta(minutes=31)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t1)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
assert len(versions) == 2
|
||||
assert versions[0].version_number == 1
|
||||
assert versions[1].version_number == 2
|
||||
assert versions[1].source_markdown == "# Test\n\nUpdated content."
|
||||
|
||||
async def test_v3_overwrites_within_30_min(
|
||||
self, db_session, db_document, monkeypatch
|
||||
):
|
||||
"""V3: Within 30 minutes, the latest version is overwritten."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
t0 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t0)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
count_after_first = await _version_count(db_session, db_document.id)
|
||||
assert count_after_first == 1
|
||||
|
||||
# Simulate quick edit within 30 minutes
|
||||
db_document.source_markdown = "# Test\n\nQuick edit."
|
||||
db_document.content_hash = "quick123"
|
||||
t1 = t0 + timedelta(minutes=10)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t1)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
count_after_second = await _version_count(db_session, db_document.id)
|
||||
assert count_after_second == 1 # still 1, not 2
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
assert versions[0].source_markdown == "# Test\n\nQuick edit."
|
||||
assert versions[0].content_hash == "quick123"
|
||||
|
||||
async def test_v4_cleanup_90_day_old_versions(
|
||||
self, db_session, db_document, monkeypatch
|
||||
):
|
||||
"""V4: Versions older than 90 days are cleaned up."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
base = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
# Create 5 versions spread across time: 3 older than 90 days, 2 recent
|
||||
for i in range(5):
|
||||
db_document.source_markdown = f"Content v{i + 1}"
|
||||
db_document.content_hash = f"hash_{i + 1}"
|
||||
t = base + timedelta(days=i) if i < 3 else base + timedelta(days=100 + i)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda _t=t: _t)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
# Now trigger cleanup from a "current" time that makes the first 3 versions > 90 days old
|
||||
now = base + timedelta(days=200)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda: now)
|
||||
db_document.source_markdown = "Content v6"
|
||||
db_document.content_hash = "hash_6"
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
# The first 3 (old) should be cleaned up; versions 4, 5, 6 remain
|
||||
for v in versions:
|
||||
age = now - v.created_at.replace(tzinfo=UTC)
|
||||
assert age <= timedelta(days=90), f"Version {v.version_number} is too old"
|
||||
|
||||
async def test_v5_cap_at_20_versions(self, db_session, db_document, monkeypatch):
|
||||
"""V5: More than 20 versions triggers cap — oldest gets deleted."""
|
||||
from app.utils.document_versioning import create_version_snapshot
|
||||
|
||||
base = datetime(2025, 6, 1, 12, 0, 0, tzinfo=UTC)
|
||||
|
||||
# Create 21 versions (all within 90 days, each 31 min apart)
|
||||
for i in range(21):
|
||||
db_document.source_markdown = f"Content v{i + 1}"
|
||||
db_document.content_hash = f"hash_{i + 1}"
|
||||
t = base + timedelta(minutes=31 * i)
|
||||
monkeypatch.setattr("app.utils.document_versioning._now", lambda _t=t: _t)
|
||||
await create_version_snapshot(db_session, db_document)
|
||||
|
||||
versions = await _get_versions(db_session, db_document.id)
|
||||
assert len(versions) == 20
|
||||
# The lowest version_number should be 2 (version 1 was the oldest and got capped)
|
||||
assert versions[0].version_number == 2
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
"""Unit tests for scan_folder() pure logic — Tier 2 TDD slices (S1-S4)."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
class TestScanFolder:
|
||||
"""S1-S4: scan_folder() with real tmp_path filesystem."""
|
||||
|
||||
def test_s1_single_md_file(self, tmp_path: Path):
|
||||
"""S1: scan_folder on a dir with one .md file returns correct entry."""
|
||||
from app.tasks.connector_indexers.local_folder_indexer import scan_folder
|
||||
|
||||
md = tmp_path / "note.md"
|
||||
md.write_text("# Hello")
|
||||
|
||||
results = scan_folder(str(tmp_path))
|
||||
|
||||
assert len(results) == 1
|
||||
entry = results[0]
|
||||
assert entry["relative_path"] == "note.md"
|
||||
assert entry["size"] > 0
|
||||
assert "modified_at" in entry
|
||||
assert entry["path"] == str(md)
|
||||
|
||||
def test_s2_extension_filter(self, tmp_path: Path):
|
||||
"""S2: file_extensions filter returns only matching files."""
|
||||
from app.tasks.connector_indexers.local_folder_indexer import scan_folder
|
||||
|
||||
(tmp_path / "a.md").write_text("md")
|
||||
(tmp_path / "b.txt").write_text("txt")
|
||||
(tmp_path / "c.pdf").write_bytes(b"%PDF")
|
||||
|
||||
results = scan_folder(str(tmp_path), file_extensions=[".md"])
|
||||
names = {r["relative_path"] for r in results}
|
||||
|
||||
assert names == {"a.md"}
|
||||
|
||||
def test_s3_exclude_patterns(self, tmp_path: Path):
|
||||
"""S3: exclude_patterns skips files inside excluded directories."""
|
||||
from app.tasks.connector_indexers.local_folder_indexer import scan_folder
|
||||
|
||||
(tmp_path / "good.md").write_text("good")
|
||||
nm = tmp_path / "node_modules"
|
||||
nm.mkdir()
|
||||
(nm / "dep.js").write_text("module")
|
||||
git = tmp_path / ".git"
|
||||
git.mkdir()
|
||||
(git / "config").write_text("gitconfig")
|
||||
|
||||
results = scan_folder(str(tmp_path), exclude_patterns=["node_modules", ".git"])
|
||||
names = {r["relative_path"] for r in results}
|
||||
|
||||
assert "good.md" in names
|
||||
assert not any("node_modules" in n for n in names)
|
||||
assert not any(".git" in n for n in names)
|
||||
|
||||
def test_s4_nested_dirs(self, tmp_path: Path):
|
||||
"""S4: nested subdirectories produce correct relative paths."""
|
||||
from app.tasks.connector_indexers.local_folder_indexer import scan_folder
|
||||
|
||||
daily = tmp_path / "notes" / "daily"
|
||||
daily.mkdir(parents=True)
|
||||
weekly = tmp_path / "notes" / "weekly"
|
||||
weekly.mkdir(parents=True)
|
||||
(daily / "today.md").write_text("today")
|
||||
(weekly / "review.md").write_text("review")
|
||||
(tmp_path / "root.txt").write_text("root")
|
||||
|
||||
results = scan_folder(str(tmp_path))
|
||||
paths = {r["relative_path"] for r in results}
|
||||
|
||||
assert "notes/daily/today.md" in paths or "notes\\daily\\today.md" in paths
|
||||
assert "notes/weekly/review.md" in paths or "notes\\weekly\\review.md" in paths
|
||||
assert "root.txt" in paths
|
||||
|
|
@ -248,7 +248,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
|
|
@ -298,7 +298,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
|
|
@ -334,7 +334,7 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
return []
|
||||
|
||||
async def fake_build_scoped_filesystem(**kwargs):
|
||||
return {}
|
||||
return {}, {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue