feat: add AI file sorting

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-14 01:43:30 -07:00
parent fa0b47dfca
commit 4bee367d4a
51 changed files with 1703 additions and 72 deletions

View file

@ -431,7 +431,9 @@ async def test_llamacloud_heif_accepted_only_with_azure_di(tmp_path, mocker):
mocker.patch("app.config.config.AZURE_DI_ENDPOINT", None, create=True)
mocker.patch("app.config.config.AZURE_DI_KEY", None, create=True)
with pytest.raises(EtlUnsupportedFileError, match="document parser does not support this format"):
with pytest.raises(
EtlUnsupportedFileError, match="document parser does not support this format"
):
await EtlPipelineService().extract(
EtlRequest(file_path=str(heif_file), filename="photo.heif")
)

View file

@ -6,6 +6,7 @@ import pytest
from langchain_core.messages import AIMessage, HumanMessage
from app.agents.new_chat.middleware.knowledge_search import (
KBSearchPlan,
KnowledgeBaseSearchMiddleware,
_build_document_xml,
_normalize_optional_date_range,
@ -366,3 +367,146 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
assert captured["query"] == "deel founders guide summary"
assert captured["start_date"] is None
assert captured["end_date"] is None
async def test_middleware_routes_to_recency_browse_when_flagged(
    self,
    monkeypatch,
):
    """Planner output with is_recency_query=true must dispatch to
    browse_recent_documents rather than search_knowledge_base."""
    recorded_browse_kwargs: dict = {}
    hybrid_search_invoked = False

    async def stub_browse(**kwargs):
        recorded_browse_kwargs.update(kwargs)
        return []

    async def stub_search(**kwargs):
        nonlocal hybrid_search_invoked
        hybrid_search_invoked = True
        return []

    async def stub_filesystem(**kwargs):
        return {}, {}

    module_path = "app.agents.new_chat.middleware.knowledge_search."
    for attr_name, stub in (
        ("browse_recent_documents", stub_browse),
        ("search_knowledge_base", stub_search),
        ("build_scoped_filesystem", stub_filesystem),
    ):
        monkeypatch.setattr(module_path + attr_name, stub)

    plan_payload = {
        "optimized_query": "latest uploaded file",
        "start_date": None,
        "end_date": None,
        "is_recency_query": True,
    }
    llm = FakeLLM(json.dumps(plan_payload))
    middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42)

    result = await middleware.abefore_agent(
        {"messages": [HumanMessage(content="what's my latest file?")]},
        runtime=None,
    )

    assert result is not None
    assert recorded_browse_kwargs["search_space_id"] == 42
    assert not hybrid_search_invoked
async def test_middleware_uses_hybrid_search_when_not_recency(
    self,
    monkeypatch,
):
    """With is_recency_query=false (the default path), the middleware must
    run the hybrid search and never the recency browse."""
    recorded_search_kwargs: dict = {}
    recency_browse_invoked = False

    async def stub_browse(**kwargs):
        nonlocal recency_browse_invoked
        recency_browse_invoked = True
        return []

    async def stub_search(**kwargs):
        recorded_search_kwargs.update(kwargs)
        return []

    async def stub_filesystem(**kwargs):
        return {}, {}

    module_path = "app.agents.new_chat.middleware.knowledge_search."
    for attr_name, stub in (
        ("browse_recent_documents", stub_browse),
        ("search_knowledge_base", stub_search),
        ("build_scoped_filesystem", stub_filesystem),
    ):
        monkeypatch.setattr(module_path + attr_name, stub)

    plan_payload = {
        "optimized_query": "quarterly revenue report analysis",
        "start_date": None,
        "end_date": None,
        "is_recency_query": False,
    }
    llm = FakeLLM(json.dumps(plan_payload))
    middleware = KnowledgeBaseSearchMiddleware(llm=llm, search_space_id=42)

    await middleware.abefore_agent(
        {"messages": [HumanMessage(content="find the quarterly revenue report")]},
        runtime=None,
    )

    assert recorded_search_kwargs["query"] == "quarterly revenue report analysis"
    assert not recency_browse_invoked
# ── KBSearchPlan schema ────────────────────────────────────────────────
class TestKBSearchPlanSchema:
    """Schema-level checks for the is_recency_query flag on KBSearchPlan."""

    def test_is_recency_query_defaults_to_false(self):
        # Flag is opt-in: omitting it at construction yields False.
        assert KBSearchPlan(optimized_query="test query").is_recency_query is False

    def test_is_recency_query_parses_true(self):
        payload = {
            "optimized_query": "latest uploaded file",
            "start_date": None,
            "end_date": None,
            "is_recency_query": True,
        }
        plan = _parse_kb_search_plan_response(json.dumps(payload))
        assert plan.is_recency_query is True
        assert plan.optimized_query == "latest uploaded file"

    def test_missing_is_recency_query_defaults_to_false(self):
        # A response without the key must fall back to the schema default.
        payload = {
            "optimized_query": "meeting notes",
            "start_date": None,
            "end_date": None,
        }
        plan = _parse_kb_search_plan_response(json.dumps(payload))
        assert plan.is_recency_query is False

View file

@ -0,0 +1,275 @@
"""Unit tests for AI file sort service: folder label resolution, date extraction, category sanitization."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
pytestmark = pytest.mark.unit
# ── resolve_root_folder_label ──
def _make_document(document_type: str, connector_id=None):
    """Build a minimal document stand-in carrying only the attributes the
    root-label resolver reads: document_type and connector_id."""
    document = MagicMock()
    document.document_type = document_type
    document.connector_id = connector_id
    return document


def _make_connector(connector_type: str):
    """Build a minimal connector stand-in exposing just connector_type."""
    connector = MagicMock()
    connector.connector_type = connector_type
    return connector
def test_root_label_uses_connector_type_when_available():
    from app.services.ai_file_sort_service import resolve_root_folder_label

    document = _make_document("FILE", connector_id=1)
    connector = _make_connector("GOOGLE_DRIVE_CONNECTOR")
    assert resolve_root_folder_label(document, connector) == "Google Drive"


def test_root_label_falls_back_to_document_type():
    from app.services.ai_file_sort_service import resolve_root_folder_label

    # No connector supplied: the document's own type drives the label.
    label = resolve_root_folder_label(_make_document("SLACK_CONNECTOR"), None)
    assert label == "Slack"


def test_root_label_unknown_doctype_returns_raw_value():
    from app.services.ai_file_sort_service import resolve_root_folder_label

    # Unmapped document types pass through unchanged.
    label = resolve_root_folder_label(_make_document("UNKNOWN_TYPE"), None)
    assert label == "UNKNOWN_TYPE"
# ── resolve_date_folder ──
def test_date_folder_from_updated_at():
    from app.services.ai_file_sort_service import resolve_date_folder

    document = MagicMock()
    document.updated_at = datetime(2025, 3, 15, 10, 30, 0, tzinfo=UTC)
    document.created_at = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC)
    # updated_at wins over created_at when both are present.
    assert resolve_date_folder(document) == "2025-03-15"


def test_date_folder_falls_back_to_created_at():
    from app.services.ai_file_sort_service import resolve_date_folder

    document = MagicMock()
    document.updated_at = None
    document.created_at = datetime(2024, 12, 25, 23, 59, 0, tzinfo=UTC)
    assert resolve_date_folder(document) == "2024-12-25"


def test_date_folder_both_none_uses_today():
    from app.services.ai_file_sort_service import resolve_date_folder

    document = MagicMock()
    document.updated_at = None
    document.created_at = None
    produced = resolve_date_folder(document)
    # Compare against "today" computed after the call to avoid a midnight race.
    assert produced == datetime.now(UTC).strftime("%Y-%m-%d")
# ── sanitize_category_folder_name ──
def test_sanitize_normal_value():
    from app.services.ai_file_sort_service import sanitize_category_folder_name

    # Plain alphanumeric-with-spaces input passes through untouched.
    assert sanitize_category_folder_name("Machine Learning") == "Machine Learning"


def test_sanitize_strips_special_chars():
    from app.services.ai_file_sort_service import sanitize_category_folder_name

    assert sanitize_category_folder_name("Tax/Reports!") == "TaxReports"


def test_sanitize_empty_returns_fallback():
    from app.services.ai_file_sort_service import sanitize_category_folder_name

    for empty_value in ("", None):
        assert sanitize_category_folder_name(empty_value) == "Uncategorized"


def test_sanitize_truncates_long_names():
    from app.services.ai_file_sort_service import sanitize_category_folder_name

    assert len(sanitize_category_folder_name("A" * 100)) <= 50
# ── generate_ai_taxonomy ──
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_parses_json():
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    llm_stub = AsyncMock()
    reply = MagicMock()
    reply.content = '{"category": "Science", "subcategory": "Physics"}'
    llm_stub.ainvoke.return_value = reply

    category, subcategory = await generate_ai_taxonomy(
        "Physics Paper", "Some science document about physics", llm_stub
    )

    assert (category, subcategory) == ("Science", "Physics")


@pytest.mark.asyncio
async def test_generate_ai_taxonomy_handles_markdown_code_block():
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    llm_stub = AsyncMock()
    reply = MagicMock()
    # LLMs often wrap JSON in a fenced code block; it must still parse.
    reply.content = (
        '```json\n{"category": "Finance", "subcategory": "Tax Reports"}\n```'
    )
    llm_stub.ainvoke.return_value = reply

    category, subcategory = await generate_ai_taxonomy(
        "Tax Doc", "A tax report document", llm_stub
    )

    assert (category, subcategory) == ("Finance", "Tax Reports")


@pytest.mark.asyncio
async def test_generate_ai_taxonomy_includes_title_in_prompt():
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    llm_stub = AsyncMock()
    reply = MagicMock()
    reply.content = '{"category": "Engineering", "subcategory": "Backend"}'
    llm_stub.ainvoke.return_value = reply

    await generate_ai_taxonomy("API Design Guide", "content about REST APIs", llm_stub)

    # First positional arg of ainvoke is the message list; inspect message 0.
    prompt_text = llm_stub.ainvoke.call_args[0][0][0].content
    assert "API Design Guide" in prompt_text
    assert "content about REST APIs" in prompt_text
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_error():
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    llm_stub = AsyncMock()
    llm_stub.ainvoke.side_effect = RuntimeError("LLM down")

    # An LLM failure degrades to the generic fallback taxonomy.
    assert await generate_ai_taxonomy("Title", "some content", llm_stub) == (
        "Uncategorized",
        "General",
    )


@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_empty_content():
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    llm_stub = AsyncMock()
    category, subcategory = await generate_ai_taxonomy("Title", "", llm_stub)
    assert category == "Uncategorized"
    assert subcategory == "General"
    # Empty content short-circuits before any LLM call is made.
    llm_stub.ainvoke.assert_not_called()


@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_invalid_json():
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    llm_stub = AsyncMock()
    reply = MagicMock()
    reply.content = "not valid json at all"
    llm_stub.ainvoke.return_value = reply

    assert await generate_ai_taxonomy("Title", "some content", llm_stub) == (
        "Uncategorized",
        "General",
    )
# ── taxonomy caching ──
def test_get_cached_taxonomy_returns_none_when_no_metadata():
    from app.services.ai_file_sort_service import _get_cached_taxonomy

    document = MagicMock()
    document.document_metadata = None
    assert _get_cached_taxonomy(document) is None


def test_get_cached_taxonomy_returns_none_when_keys_missing():
    from app.services.ai_file_sort_service import _get_cached_taxonomy

    document = MagicMock()
    document.document_metadata = {"some_other_key": "value"}
    assert _get_cached_taxonomy(document) is None


def test_get_cached_taxonomy_returns_cached_values():
    from app.services.ai_file_sort_service import _get_cached_taxonomy

    document = MagicMock()
    document.document_metadata = {
        "ai_sort_category": "Finance",
        "ai_sort_subcategory": "Tax Reports",
    }
    assert _get_cached_taxonomy(document) == ("Finance", "Tax Reports")


def test_set_cached_taxonomy_persists_on_metadata():
    from app.services.ai_file_sort_service import _set_cached_taxonomy

    document = MagicMock()
    document.document_metadata = {"existing_key": "keep_me"}
    _set_cached_taxonomy(document, "Science", "Physics")

    # Writing the taxonomy must not clobber unrelated metadata keys.
    metadata = document.document_metadata
    assert metadata["ai_sort_category"] == "Science"
    assert metadata["ai_sort_subcategory"] == "Physics"
    assert metadata["existing_key"] == "keep_me"


def test_set_cached_taxonomy_creates_metadata_when_none():
    from app.services.ai_file_sort_service import _set_cached_taxonomy

    document = MagicMock()
    document.document_metadata = None
    _set_cached_taxonomy(document, "Engineering", "Backend")
    assert document.document_metadata == {
        "ai_sort_category": "Engineering",
        "ai_sort_subcategory": "Backend",
    }
# ── _build_path_segments ──
def test_build_path_segments_structure():
    from app.services.ai_file_sort_service import _build_path_segments

    segments = _build_path_segments("Google Drive", "2025-03-15", "Science", "Physics")

    # One segment per hierarchy level: root / date / category / subcategory,
    # each tagged with ai_sort metadata carrying its 1-based level.
    expected_names = ["Google Drive", "2025-03-15", "Science", "Physics"]
    assert len(segments) == 4
    for level, (segment, name) in enumerate(zip(segments, expected_names), start=1):
        assert segment == {
            "name": name,
            "metadata": {"ai_sort": True, "ai_sort_level": level},
        }

View file

@ -0,0 +1,43 @@
"""Unit tests for AI sort task Redis deduplication lock."""
from unittest.mock import MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
def test_lock_key_format():
    """The dedup lock key embeds the search-space id in a fixed template."""
    from app.tasks.celery_tasks.document_tasks import _ai_sort_lock_key

    assert _ai_sort_lock_key(42) == "ai_sort:search_space:42:lock"
def test_lock_prevents_duplicate_run():
    """When the Redis lock already exists, the task should skip execution.

    The Redis client's ``set`` returns False to simulate a lock already held
    by another worker; the session maker must then never be touched, proving
    the task bailed out before doing any DB work.
    """
    import asyncio

    mock_redis = MagicMock()
    mock_redis.set.return_value = False  # Lock already held

    with (
        patch(
            "app.tasks.celery_tasks.document_tasks._get_ai_sort_redis",
            return_value=mock_redis,
        ),
        patch(
            "app.tasks.celery_tasks.document_tasks.get_celery_session_maker"
        ) as mock_session_maker,
    ):
        from app.tasks.celery_tasks.document_tasks import _ai_sort_search_space_async

        # asyncio.run replaces the manual new_event_loop/run_until_complete/
        # close dance and guarantees loop teardown even if the coroutine raises.
        asyncio.run(_ai_sort_search_space_async(1, "user-123"))

        # Session maker should never be called since lock was not acquired
        mock_session_maker.assert_not_called()

View file

@ -0,0 +1,87 @@
"""Unit tests for ensure_folder_hierarchy_with_depth_validation."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
pytestmark = pytest.mark.unit
@pytest.mark.asyncio
async def test_creates_missing_folders_in_chain():
    """Should create all folders when none exist."""
    from app.services.folder_service import (
        ensure_folder_hierarchy_with_depth_validation,
    )

    session = AsyncMock()
    # Every folder lookup misses, so the whole chain must be created.
    lookup_result = MagicMock()
    lookup_result.scalar_one_or_none.return_value = None
    session.execute.return_value = lookup_result

    created_folders = []
    session.add = created_folders.append

    with (
        patch(
            "app.services.folder_service.validate_folder_depth", new_callable=AsyncMock
        ),
        patch(
            "app.services.folder_service.generate_folder_position",
            new_callable=AsyncMock,
            return_value="a0",
        ),
    ):
        # Each flush stamps an id onto the most recently added folder,
        # mimicking the database assigning primary keys.
        flush_counter = 0

        async def fake_flush():
            nonlocal flush_counter
            flush_counter += 1
            if created_folders:
                created_folders[-1].id = flush_counter

        session.flush = fake_flush

        segments = [
            {"name": "Slack", "metadata": {"ai_sort": True, "ai_sort_level": 1}},
            {"name": "2025-03-15", "metadata": {"ai_sort": True, "ai_sort_level": 2}},
        ]
        result = await ensure_folder_hierarchy_with_depth_validation(
            session, 1, segments
        )

    assert len(created_folders) == 2
    assert [folder.name for folder in created_folders] == ["Slack", "2025-03-15"]
    # The deepest folder in the chain is returned.
    assert result is created_folders[-1]
@pytest.mark.asyncio
async def test_reuses_existing_folder():
    """When a folder already exists, it should be reused, not created."""
    from app.services.folder_service import (
        ensure_folder_hierarchy_with_depth_validation,
    )

    session = AsyncMock()
    preexisting = MagicMock()
    preexisting.id = 42
    lookup_result = MagicMock()
    lookup_result.scalar_one_or_none.return_value = preexisting
    session.execute.return_value = lookup_result

    resolved = await ensure_folder_hierarchy_with_depth_validation(
        session, 1, [{"name": "Existing", "metadata": None}]
    )

    assert resolved is preexisting
    session.add.assert_not_called()