chore: ran linting

This commit is contained in:
Anish Sarkar 2026-04-03 13:14:40 +05:30
parent 6ace8850bb
commit 746c730b2e
31 changed files with 801 additions and 660 deletions

View file

@ -166,5 +166,3 @@ def make_connector_document(db_connector, db_user):
return ConnectorDocument(**defaults)
return _make

View file

@ -21,7 +21,9 @@ from app.db import (
pytestmark = pytest.mark.integration
UNIFIED_FIXTURES = (
"patched_summarize", "patched_embed_texts", "patched_chunk_text",
"patched_summarize",
"patched_embed_texts",
"patched_chunk_text",
)
@ -37,6 +39,7 @@ class _FakeSessionMaker:
@asynccontextmanager
async def _ctx():
yield self._session
return _ctx()
@ -59,7 +62,6 @@ def patched_batch_sessions(monkeypatch, db_session):
class TestFullIndexer:
@pytest.mark.usefixtures(*UNIFIED_FIXTURES)
async def test_i1_new_file_indexed(
self,
@ -73,7 +75,7 @@ class TestFullIndexer:
(tmp_path / "note.md").write_text("# Hello World\n\nContent here.")
count, skipped, root_folder_id, err = await index_local_folder(
count, _skipped, _root_folder_id, err = await index_local_folder(
session=db_session,
search_space_id=db_search_space.id,
user_id=str(db_user.id),
@ -85,13 +87,17 @@ class TestFullIndexer:
assert count == 1
docs = (
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
(
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
)
)
).scalars().all()
.scalars()
.all()
)
assert len(docs) == 1
assert docs[0].document_type == DocumentType.LOCAL_FOLDER_FILE
assert DocumentStatus.is_state(docs[0].status, DocumentStatus.READY)
@ -130,7 +136,9 @@ class TestFullIndexer:
total = (
await db_session.execute(
select(func.count()).select_from(Document).where(
select(func.count())
.select_from(Document)
.where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
@ -174,13 +182,19 @@ class TestFullIndexer:
assert count == 1
versions = (
await db_session.execute(
select(DocumentVersion).join(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
(
await db_session.execute(
select(DocumentVersion)
.join(Document)
.where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
)
)
).scalars().all()
.scalars()
.all()
)
assert len(versions) >= 1
@pytest.mark.usefixtures(*UNIFIED_FIXTURES)
@ -207,7 +221,9 @@ class TestFullIndexer:
docs_before = (
await db_session.execute(
select(func.count()).select_from(Document).where(
select(func.count())
.select_from(Document)
.where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
@ -228,7 +244,9 @@ class TestFullIndexer:
docs_after = (
await db_session.execute(
select(func.count()).select_from(Document).where(
select(func.count())
.select_from(Document)
.where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
@ -262,13 +280,17 @@ class TestFullIndexer:
assert count == 1
docs = (
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
(
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
)
)
).scalars().all()
.scalars()
.all()
)
assert len(docs) == 1
assert docs[0].title == "b.md"
@ -279,7 +301,6 @@ class TestFullIndexer:
class TestFolderMirroring:
@pytest.mark.usefixtures(*UNIFIED_FIXTURES)
async def test_f1_root_folder_created(
self,
@ -335,10 +356,14 @@ class TestFolderMirroring:
)
folders = (
await db_session.execute(
select(Folder).where(Folder.search_space_id == db_search_space.id)
(
await db_session.execute(
select(Folder).where(Folder.search_space_id == db_search_space.id)
)
)
).scalars().all()
.scalars()
.all()
)
folder_names = {f.name for f in folders}
assert "notes" in folder_names
@ -376,10 +401,14 @@ class TestFolderMirroring:
)
folders_before = (
await db_session.execute(
select(Folder).where(Folder.search_space_id == db_search_space.id)
(
await db_session.execute(
select(Folder).where(Folder.search_space_id == db_search_space.id)
)
)
).scalars().all()
.scalars()
.all()
)
ids_before = {f.id for f in folders_before}
await index_local_folder(
@ -392,10 +421,14 @@ class TestFolderMirroring:
)
folders_after = (
await db_session.execute(
select(Folder).where(Folder.search_space_id == db_search_space.id)
(
await db_session.execute(
select(Folder).where(Folder.search_space_id == db_search_space.id)
)
)
).scalars().all()
.scalars()
.all()
)
ids_after = {f.id for f in folders_after}
assert ids_before == ids_after
@ -425,21 +458,23 @@ class TestFolderMirroring:
)
docs = (
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
(
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
)
)
).scalars().all()
.scalars()
.all()
)
today_doc = next(d for d in docs if d.title == "today.md")
root_doc = next(d for d in docs if d.title == "root.md")
daily_folder = (
await db_session.execute(
select(Folder).where(Folder.name == "daily")
)
await db_session.execute(select(Folder).where(Folder.name == "daily"))
).scalar_one()
assert today_doc.folder_id == daily_folder.id
@ -455,9 +490,10 @@ class TestFolderMirroring:
tmp_path: Path,
):
"""F5: Deleted dir's empty Folder row is cleaned up on re-sync."""
from app.tasks.connector_indexers.local_folder_indexer import index_local_folder
import shutil
from app.tasks.connector_indexers.local_folder_indexer import index_local_folder
daily = tmp_path / "notes" / "daily"
daily.mkdir(parents=True)
weekly = tmp_path / "notes" / "weekly"
@ -474,9 +510,7 @@ class TestFolderMirroring:
)
weekly_folder = (
await db_session.execute(
select(Folder).where(Folder.name == "weekly")
)
await db_session.execute(select(Folder).where(Folder.name == "weekly"))
).scalar_one_or_none()
assert weekly_folder is not None
@ -492,16 +526,12 @@ class TestFolderMirroring:
)
weekly_after = (
await db_session.execute(
select(Folder).where(Folder.name == "weekly")
)
await db_session.execute(select(Folder).where(Folder.name == "weekly"))
).scalar_one_or_none()
assert weekly_after is None
daily_after = (
await db_session.execute(
select(Folder).where(Folder.name == "daily")
)
await db_session.execute(select(Folder).where(Folder.name == "daily"))
).scalar_one_or_none()
assert daily_after is not None
@ -551,18 +581,14 @@ class TestFolderMirroring:
).scalar_one()
daily_folder = (
await db_session.execute(
select(Folder).where(Folder.name == "daily")
)
await db_session.execute(select(Folder).where(Folder.name == "daily"))
).scalar_one()
assert doc.folder_id == daily_folder.id
assert daily_folder.parent_id is not None
notes_folder = (
await db_session.execute(
select(Folder).where(Folder.name == "notes")
)
await db_session.execute(select(Folder).where(Folder.name == "notes"))
).scalar_one()
assert daily_folder.parent_id == notes_folder.id
assert notes_folder.parent_id == root_folder_id
@ -592,9 +618,7 @@ class TestFolderMirroring:
)
eph_folder = (
await db_session.execute(
select(Folder).where(Folder.name == "ephemeral")
)
await db_session.execute(select(Folder).where(Folder.name == "ephemeral"))
).scalar_one_or_none()
assert eph_folder is not None
@ -612,16 +636,12 @@ class TestFolderMirroring:
)
eph_after = (
await db_session.execute(
select(Folder).where(Folder.name == "ephemeral")
)
await db_session.execute(select(Folder).where(Folder.name == "ephemeral"))
).scalar_one_or_none()
assert eph_after is None
notes_after = (
await db_session.execute(
select(Folder).where(Folder.name == "notes")
)
await db_session.execute(select(Folder).where(Folder.name == "notes"))
).scalar_one_or_none()
assert notes_after is None
@ -632,7 +652,6 @@ class TestFolderMirroring:
class TestBatchMode:
@pytest.mark.usefixtures(*UNIFIED_FIXTURES)
async def test_b1_batch_indexes_multiple_files(
self,
@ -649,7 +668,7 @@ class TestBatchMode:
(tmp_path / "b.md").write_text("File B content")
(tmp_path / "c.md").write_text("File C content")
count, failed, root_folder_id, err = await index_local_folder(
count, failed, _root_folder_id, err = await index_local_folder(
session=db_session,
search_space_id=db_search_space.id,
user_id=str(db_user.id),
@ -667,13 +686,17 @@ class TestBatchMode:
assert err is None
docs = (
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
(
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
)
)
).scalars().all()
.scalars()
.all()
)
assert len(docs) == 3
assert {d.title for d in docs} == {"a.md", "b.md", "c.md"}
assert all(
@ -714,13 +737,17 @@ class TestBatchMode:
assert err is not None
docs = (
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
(
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
)
)
).scalars().all()
.scalars()
.all()
)
assert len(docs) == 2
assert {d.title for d in docs} == {"good1.md", "good2.md"}
@ -731,7 +758,6 @@ class TestBatchMode:
class TestPipelineIntegration:
@pytest.mark.usefixtures(*UNIFIED_FIXTURES)
async def test_p1_local_folder_file_through_pipeline(
self,
@ -742,7 +768,9 @@ class TestPipelineIntegration:
):
"""P1: LOCAL_FOLDER_FILE ConnectorDocument through prepare+index to READY."""
from app.indexing_pipeline.connector_document import ConnectorDocument
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
from app.indexing_pipeline.indexing_pipeline_service import (
IndexingPipelineService,
)
doc = ConnectorDocument(
title="Test Local File",
@ -763,12 +791,16 @@ class TestPipelineIntegration:
assert result is not None
docs = (
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
(
await db_session.execute(
select(Document).where(
Document.document_type == DocumentType.LOCAL_FOLDER_FILE,
Document.search_space_id == db_search_space.id,
)
)
)
).scalars().all()
.scalars()
.all()
)
assert len(docs) == 1
assert DocumentStatus.is_state(docs[0].status, DocumentStatus.READY)

View file

@ -34,14 +34,16 @@ async def db_document(
async def _version_count(session: AsyncSession, document_id: int) -> int:
result = await session.execute(
select(func.count()).select_from(DocumentVersion).where(
DocumentVersion.document_id == document_id
)
select(func.count())
.select_from(DocumentVersion)
.where(DocumentVersion.document_id == document_id)
)
return result.scalar_one()
async def _get_versions(session: AsyncSession, document_id: int) -> list[DocumentVersion]:
async def _get_versions(
session: AsyncSession, document_id: int
) -> list[DocumentVersion]:
result = await session.execute(
select(DocumentVersion)
.where(DocumentVersion.document_id == document_id)
@ -74,18 +76,14 @@ class TestCreateVersionSnapshot:
from app.utils.document_versioning import create_version_snapshot
t0 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC)
monkeypatch.setattr(
"app.utils.document_versioning._now", lambda: t0
)
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t0)
await create_version_snapshot(db_session, db_document)
# Simulate content change and time passing
db_document.source_markdown = "# Test\n\nUpdated content."
db_document.content_hash = "def456"
t1 = t0 + timedelta(minutes=31)
monkeypatch.setattr(
"app.utils.document_versioning._now", lambda: t1
)
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t1)
await create_version_snapshot(db_session, db_document)
versions = await _get_versions(db_session, db_document.id)
@ -101,9 +99,7 @@ class TestCreateVersionSnapshot:
from app.utils.document_versioning import create_version_snapshot
t0 = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC)
monkeypatch.setattr(
"app.utils.document_versioning._now", lambda: t0
)
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t0)
await create_version_snapshot(db_session, db_document)
count_after_first = await _version_count(db_session, db_document.id)
assert count_after_first == 1
@ -112,9 +108,7 @@ class TestCreateVersionSnapshot:
db_document.source_markdown = "# Test\n\nQuick edit."
db_document.content_hash = "quick123"
t1 = t0 + timedelta(minutes=10)
monkeypatch.setattr(
"app.utils.document_versioning._now", lambda: t1
)
monkeypatch.setattr("app.utils.document_versioning._now", lambda: t1)
await create_version_snapshot(db_session, db_document)
count_after_second = await _version_count(db_session, db_document.id)
@ -134,22 +128,15 @@ class TestCreateVersionSnapshot:
# Create 5 versions spread across time: 3 older than 90 days, 2 recent
for i in range(5):
db_document.source_markdown = f"Content v{i+1}"
db_document.content_hash = f"hash_{i+1}"
if i < 3:
t = base + timedelta(days=i) # old
else:
t = base + timedelta(days=100 + i) # recent
monkeypatch.setattr(
"app.utils.document_versioning._now", lambda _t=t: _t
)
db_document.source_markdown = f"Content v{i + 1}"
db_document.content_hash = f"hash_{i + 1}"
t = base + timedelta(days=i) if i < 3 else base + timedelta(days=100 + i)
monkeypatch.setattr("app.utils.document_versioning._now", lambda _t=t: _t)
await create_version_snapshot(db_session, db_document)
# Now trigger cleanup from a "current" time that makes the first 3 versions > 90 days old
now = base + timedelta(days=200)
monkeypatch.setattr(
"app.utils.document_versioning._now", lambda: now
)
monkeypatch.setattr("app.utils.document_versioning._now", lambda: now)
db_document.source_markdown = "Content v6"
db_document.content_hash = "hash_6"
await create_version_snapshot(db_session, db_document)
@ -160,9 +147,7 @@ class TestCreateVersionSnapshot:
age = now - v.created_at.replace(tzinfo=UTC)
assert age <= timedelta(days=90), f"Version {v.version_number} is too old"
async def test_v5_cap_at_20_versions(
self, db_session, db_document, monkeypatch
):
async def test_v5_cap_at_20_versions(self, db_session, db_document, monkeypatch):
"""V5: More than 20 versions triggers cap — oldest gets deleted."""
from app.utils.document_versioning import create_version_snapshot
@ -170,12 +155,10 @@ class TestCreateVersionSnapshot:
# Create 21 versions (all within 90 days, each 31 min apart)
for i in range(21):
db_document.source_markdown = f"Content v{i+1}"
db_document.content_hash = f"hash_{i+1}"
db_document.source_markdown = f"Content v{i + 1}"
db_document.content_hash = f"hash_{i + 1}"
t = base + timedelta(minutes=31 * i)
monkeypatch.setattr(
"app.utils.document_versioning._now", lambda _t=t: _t
)
monkeypatch.setattr("app.utils.document_versioning._now", lambda _t=t: _t)
await create_version_snapshot(db_session, db_document)
versions = await _get_versions(db_session, db_document.id)