Mirror of https://github.com/MODSetter/SurfSense.git, synced 2026-04-25 08:46:22 +02:00.
Commit 4bee367d4a (parent fa0b47dfca): "feat: added ai file sorting".
51 changed files with 1703 additions and 72 deletions.
|
|
@ -431,7 +431,9 @@ async def test_llamacloud_heif_accepted_only_with_azure_di(tmp_path, mocker):
|
|||
mocker.patch("app.config.config.AZURE_DI_ENDPOINT", None, create=True)
|
||||
mocker.patch("app.config.config.AZURE_DI_KEY", None, create=True)
|
||||
|
||||
with pytest.raises(EtlUnsupportedFileError, match="document parser does not support this format"):
|
||||
with pytest.raises(
|
||||
EtlUnsupportedFileError, match="document parser does not support this format"
|
||||
):
|
||||
await EtlPipelineService().extract(
|
||||
EtlRequest(file_path=str(heif_file), filename="photo.heif")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.agents.new_chat.middleware.knowledge_search import (
|
||||
KBSearchPlan,
|
||||
KnowledgeBaseSearchMiddleware,
|
||||
_build_document_xml,
|
||||
_normalize_optional_date_range,
|
||||
|
|
@ -366,3 +367,146 @@ class TestKnowledgeBaseSearchMiddlewarePlanner:
|
|||
assert captured["query"] == "deel founders guide summary"
|
||||
assert captured["start_date"] is None
|
||||
assert captured["end_date"] is None
|
||||
|
||||
async def test_middleware_routes_to_recency_browse_when_flagged(
    self,
    monkeypatch,
):
    """Planner output with is_recency_query=true must route the middleware
    to browse_recent_documents and leave search_knowledge_base untouched."""
    browse_kwargs: dict = {}
    hybrid_invoked = False

    async def stub_browse(**kwargs):
        browse_kwargs.update(kwargs)
        return []

    async def stub_search(**kwargs):
        nonlocal hybrid_invoked
        hybrid_invoked = True
        return []

    async def stub_filesystem(**kwargs):
        return {}, {}

    monkeypatch.setattr(
        "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
        stub_browse,
    )
    monkeypatch.setattr(
        "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
        stub_search,
    )
    monkeypatch.setattr(
        "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
        stub_filesystem,
    )

    # The fake planner reply flags the query as a recency question.
    plan = {
        "optimized_query": "latest uploaded file",
        "start_date": None,
        "end_date": None,
        "is_recency_query": True,
    }
    middleware = KnowledgeBaseSearchMiddleware(
        llm=FakeLLM(json.dumps(plan)), search_space_id=42
    )

    result = await middleware.abefore_agent(
        {"messages": [HumanMessage(content="what's my latest file?")]},
        runtime=None,
    )

    assert result is not None
    assert browse_kwargs["search_space_id"] == 42
    assert not hybrid_invoked
|
||||
|
||||
async def test_middleware_uses_hybrid_search_when_not_recency(
    self,
    monkeypatch,
):
    """With is_recency_query false (the default), the middleware performs a
    hybrid search and never calls browse_recent_documents."""
    search_kwargs: dict = {}
    recency_invoked = False

    async def stub_browse(**kwargs):
        nonlocal recency_invoked
        recency_invoked = True
        return []

    async def stub_search(**kwargs):
        search_kwargs.update(kwargs)
        return []

    async def stub_filesystem(**kwargs):
        return {}, {}

    monkeypatch.setattr(
        "app.agents.new_chat.middleware.knowledge_search.browse_recent_documents",
        stub_browse,
    )
    monkeypatch.setattr(
        "app.agents.new_chat.middleware.knowledge_search.search_knowledge_base",
        stub_search,
    )
    monkeypatch.setattr(
        "app.agents.new_chat.middleware.knowledge_search.build_scoped_filesystem",
        stub_filesystem,
    )

    # The fake planner reply explicitly marks the query as non-recency.
    plan = {
        "optimized_query": "quarterly revenue report analysis",
        "start_date": None,
        "end_date": None,
        "is_recency_query": False,
    }
    middleware = KnowledgeBaseSearchMiddleware(
        llm=FakeLLM(json.dumps(plan)), search_space_id=42
    )

    await middleware.abefore_agent(
        {"messages": [HumanMessage(content="find the quarterly revenue report")]},
        runtime=None,
    )

    assert search_kwargs["query"] == "quarterly revenue report analysis"
    assert not recency_invoked
|
||||
|
||||
|
||||
# ── KBSearchPlan schema ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestKBSearchPlanSchema:
    """Schema-level behavior of KBSearchPlan's is_recency_query flag."""

    def test_is_recency_query_defaults_to_false(self):
        assert KBSearchPlan(optimized_query="test query").is_recency_query is False

    def test_is_recency_query_parses_true(self):
        payload = {
            "optimized_query": "latest uploaded file",
            "start_date": None,
            "end_date": None,
            "is_recency_query": True,
        }
        plan = _parse_kb_search_plan_response(json.dumps(payload))
        assert plan.is_recency_query is True
        assert plan.optimized_query == "latest uploaded file"

    def test_missing_is_recency_query_defaults_to_false(self):
        # The key is absent from the payload entirely; parsing must default it.
        payload = {
            "optimized_query": "meeting notes",
            "start_date": None,
            "end_date": None,
        }
        plan = _parse_kb_search_plan_response(json.dumps(payload))
        assert plan.is_recency_query is False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,275 @@
|
|||
"""Unit tests for AI file sort service: folder label resolution, date extraction, category sanitization."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── resolve_root_folder_label ──
|
||||
|
||||
|
||||
def _make_document(document_type: str, connector_id=None):
|
||||
doc = MagicMock()
|
||||
doc.document_type = document_type
|
||||
doc.connector_id = connector_id
|
||||
return doc
|
||||
|
||||
|
||||
def _make_connector(connector_type: str):
|
||||
conn = MagicMock()
|
||||
conn.connector_type = connector_type
|
||||
return conn
|
||||
|
||||
|
||||
def test_root_label_uses_connector_type_when_available():
    """When a connector is attached, its type wins over document_type."""
    from app.services.ai_file_sort_service import resolve_root_folder_label

    drive_doc = _make_document("FILE", connector_id=1)
    drive_conn = _make_connector("GOOGLE_DRIVE_CONNECTOR")
    label = resolve_root_folder_label(drive_doc, drive_conn)
    assert label == "Google Drive"
|
||||
|
||||
|
||||
def test_root_label_falls_back_to_document_type():
    """Without a connector, the label is derived from the document_type."""
    from app.services.ai_file_sort_service import resolve_root_folder_label

    slack_doc = _make_document("SLACK_CONNECTOR")
    assert resolve_root_folder_label(slack_doc, None) == "Slack"
|
||||
|
||||
|
||||
def test_root_label_unknown_doctype_returns_raw_value():
    """An unmapped document_type is passed through verbatim."""
    from app.services.ai_file_sort_service import resolve_root_folder_label

    mystery_doc = _make_document("UNKNOWN_TYPE")
    assert resolve_root_folder_label(mystery_doc, None) == "UNKNOWN_TYPE"
|
||||
|
||||
|
||||
# ── resolve_date_folder ──
|
||||
|
||||
|
||||
def test_date_folder_from_updated_at():
    """updated_at takes precedence over created_at for the date bucket."""
    from app.services.ai_file_sort_service import resolve_date_folder

    document = MagicMock()
    document.updated_at = datetime(2025, 3, 15, 10, 30, 0, tzinfo=UTC)
    document.created_at = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC)
    assert resolve_date_folder(document) == "2025-03-15"
|
||||
|
||||
|
||||
def test_date_folder_falls_back_to_created_at():
    """created_at supplies the date bucket when updated_at is missing."""
    from app.services.ai_file_sort_service import resolve_date_folder

    document = MagicMock()
    document.updated_at = None
    document.created_at = datetime(2024, 12, 25, 23, 59, 0, tzinfo=UTC)
    assert resolve_date_folder(document) == "2024-12-25"
|
||||
|
||||
|
||||
def test_date_folder_both_none_uses_today():
    """With no timestamps at all, the folder falls back to today's date.

    The "today" reference is sampled both before and after the call so the
    assertion cannot flake if the test happens to straddle a UTC midnight
    (the original sampled it only after the call).
    """
    from app.services.ai_file_sort_service import resolve_date_folder

    document = MagicMock()
    document.updated_at = None
    document.created_at = None

    before = datetime.now(UTC).strftime("%Y-%m-%d")
    result = resolve_date_folder(document)
    after = datetime.now(UTC).strftime("%Y-%m-%d")

    assert result in {before, after}
|
||||
|
||||
|
||||
# ── sanitize_category_folder_name ──
|
||||
|
||||
|
||||
def test_sanitize_normal_value():
    """A clean human-readable name passes through unchanged."""
    from app.services.ai_file_sort_service import sanitize_category_folder_name

    result = sanitize_category_folder_name("Machine Learning")
    assert result == "Machine Learning"
|
||||
|
||||
|
||||
def test_sanitize_strips_special_chars():
    """Path separators and punctuation are stripped from the folder name."""
    from app.services.ai_file_sort_service import sanitize_category_folder_name

    result = sanitize_category_folder_name("Tax/Reports!")
    assert result == "TaxReports"
|
||||
|
||||
|
||||
def test_sanitize_empty_returns_fallback():
    """Empty or missing names collapse to the Uncategorized fallback."""
    from app.services.ai_file_sort_service import sanitize_category_folder_name

    for value in ("", None):
        assert sanitize_category_folder_name(value) == "Uncategorized"
|
||||
|
||||
|
||||
def test_sanitize_truncates_long_names():
    """Over-long names are capped at 50 characters."""
    from app.services.ai_file_sort_service import sanitize_category_folder_name

    oversized = "A" * 100
    assert len(sanitize_category_folder_name(oversized)) <= 50
|
||||
|
||||
|
||||
# ── generate_ai_taxonomy ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_parses_json():
    """A plain JSON reply is parsed into (category, subcategory)."""
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    reply = MagicMock()
    reply.content = '{"category": "Science", "subcategory": "Physics"}'
    llm = AsyncMock()
    llm.ainvoke.return_value = reply

    category, subcategory = await generate_ai_taxonomy(
        "Physics Paper", "Some science document about physics", llm
    )
    assert category == "Science"
    assert subcategory == "Physics"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_handles_markdown_code_block():
    """A reply wrapped in a markdown ```json fence is still parsed."""
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    reply = MagicMock()
    reply.content = (
        '```json\n{"category": "Finance", "subcategory": "Tax Reports"}\n```'
    )
    llm = AsyncMock()
    llm.ainvoke.return_value = reply

    category, subcategory = await generate_ai_taxonomy(
        "Tax Doc", "A tax report document", llm
    )
    assert category == "Finance"
    assert subcategory == "Tax Reports"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_includes_title_in_prompt():
    """Both the title and the content must appear in the prompt sent to the LLM."""
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    reply = MagicMock()
    reply.content = '{"category": "Engineering", "subcategory": "Backend"}'
    llm = AsyncMock()
    llm.ainvoke.return_value = reply

    await generate_ai_taxonomy("API Design Guide", "content about REST APIs", llm)

    # First positional arg of ainvoke is the message list; inspect message 0.
    sent_messages = llm.ainvoke.call_args.args[0]
    prompt_text = sent_messages[0].content
    assert "API Design Guide" in prompt_text
    assert "content about REST APIs" in prompt_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_error():
    """An LLM failure degrades to the Uncategorized/General fallback pair."""
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    llm = AsyncMock()
    llm.ainvoke.side_effect = RuntimeError("LLM down")

    category, subcategory = await generate_ai_taxonomy("Title", "some content", llm)
    assert category == "Uncategorized"
    assert subcategory == "General"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_empty_content():
    """Empty content short-circuits to the fallback without any LLM call."""
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    llm = AsyncMock()
    category, subcategory = await generate_ai_taxonomy("Title", "", llm)
    assert category == "Uncategorized"
    assert subcategory == "General"
    llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_ai_taxonomy_fallback_on_invalid_json():
    """Unparseable LLM output degrades to the fallback pair."""
    from app.services.ai_file_sort_service import generate_ai_taxonomy

    reply = MagicMock()
    reply.content = "not valid json at all"
    llm = AsyncMock()
    llm.ainvoke.return_value = reply

    category, subcategory = await generate_ai_taxonomy("Title", "some content", llm)
    assert category == "Uncategorized"
    assert subcategory == "General"
|
||||
|
||||
|
||||
# ── taxonomy caching ──
|
||||
|
||||
|
||||
def test_get_cached_taxonomy_returns_none_when_no_metadata():
    """A document without metadata has no cached taxonomy."""
    from app.services.ai_file_sort_service import _get_cached_taxonomy

    document = MagicMock()
    document.document_metadata = None
    assert _get_cached_taxonomy(document) is None
|
||||
|
||||
|
||||
def test_get_cached_taxonomy_returns_none_when_keys_missing():
    """Metadata lacking the ai_sort keys yields no cached taxonomy."""
    from app.services.ai_file_sort_service import _get_cached_taxonomy

    document = MagicMock()
    document.document_metadata = {"some_other_key": "value"}
    assert _get_cached_taxonomy(document) is None
|
||||
|
||||
|
||||
def test_get_cached_taxonomy_returns_cached_values():
    """With both cache keys present, the (category, subcategory) pair is returned."""
    from app.services.ai_file_sort_service import _get_cached_taxonomy

    document = MagicMock()
    document.document_metadata = {
        "ai_sort_category": "Finance",
        "ai_sort_subcategory": "Tax Reports",
    }
    assert _get_cached_taxonomy(document) == ("Finance", "Tax Reports")
|
||||
|
||||
|
||||
def test_set_cached_taxonomy_persists_on_metadata():
    """Writing the cache adds the two ai_sort keys and keeps existing entries."""
    from app.services.ai_file_sort_service import _set_cached_taxonomy

    document = MagicMock()
    document.document_metadata = {"existing_key": "keep_me"}
    _set_cached_taxonomy(document, "Science", "Physics")

    metadata = document.document_metadata
    assert metadata["ai_sort_category"] == "Science"
    assert metadata["ai_sort_subcategory"] == "Physics"
    assert metadata["existing_key"] == "keep_me"
|
||||
|
||||
|
||||
def test_set_cached_taxonomy_creates_metadata_when_none():
    """Writing the cache initializes metadata when the document had none."""
    from app.services.ai_file_sort_service import _set_cached_taxonomy

    document = MagicMock()
    document.document_metadata = None
    _set_cached_taxonomy(document, "Engineering", "Backend")
    assert document.document_metadata == {
        "ai_sort_category": "Engineering",
        "ai_sort_subcategory": "Backend",
    }
|
||||
|
||||
|
||||
# ── _build_path_segments ──
|
||||
|
||||
|
||||
def test_build_path_segments_structure():
    """Segments come back in root→leaf order, each tagged with its 1-based level."""
    from app.services.ai_file_sort_service import _build_path_segments

    segments = _build_path_segments("Google Drive", "2025-03-15", "Science", "Physics")

    expected_names = ["Google Drive", "2025-03-15", "Science", "Physics"]
    assert len(segments) == 4
    for level, (segment, name) in enumerate(zip(segments, expected_names), start=1):
        assert segment == {
            "name": name,
            "metadata": {"ai_sort": True, "ai_sort_level": level},
        }
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
"""Unit tests for AI sort task Redis deduplication lock."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
def test_lock_key_format():
    """The dedup lock key embeds the search-space id in a fixed naming scheme."""
    from app.tasks.celery_tasks.document_tasks import _ai_sort_lock_key

    assert _ai_sort_lock_key(42) == "ai_sort:search_space:42:lock"
|
||||
|
||||
|
||||
def test_lock_prevents_duplicate_run():
    """When the Redis lock already exists, the task should skip execution."""

    redis_stub = MagicMock()
    redis_stub.set.return_value = False  # Lock already held by another worker

    with (
        patch(
            "app.tasks.celery_tasks.document_tasks._get_ai_sort_redis",
            return_value=redis_stub,
        ),
        patch(
            "app.tasks.celery_tasks.document_tasks.get_celery_session_maker"
        ) as session_maker,
    ):
        import asyncio

        from app.tasks.celery_tasks.document_tasks import _ai_sort_search_space_async

        # asyncio.run replaces the manual new_event_loop/run_until_complete/
        # close dance: it closes the loop even if the coroutine raises, and
        # never leaks a loop into the thread's event-loop policy.
        asyncio.run(_ai_sort_search_space_async(1, "user-123"))

        # Session maker should never be called since lock was not acquired
        session_maker.assert_not_called()
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
"""Unit tests for ensure_folder_hierarchy_with_depth_validation."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_creates_missing_folders_in_chain():
    """When no folder in the chain exists yet, every segment is created."""
    from app.services.folder_service import (
        ensure_folder_hierarchy_with_depth_validation,
    )

    session = AsyncMock()
    # Every lookup misses: scalar_one_or_none() yields None for each segment.
    lookup = MagicMock()
    lookup.scalar_one_or_none.return_value = None
    session.execute.return_value = lookup

    created: list = []
    session.add = created.append

    with (
        patch(
            "app.services.folder_service.validate_folder_depth", new_callable=AsyncMock
        ),
        patch(
            "app.services.folder_service.generate_folder_position",
            new_callable=AsyncMock,
            return_value="a0",
        ),
    ):
        # flush() stands in for the DB assigning primary keys: each call
        # stamps an incrementing id onto the most recently added folder.
        flush_calls = 0

        async def fake_flush():
            nonlocal flush_calls
            flush_calls += 1
            if created:
                created[-1].id = flush_calls

        session.flush = fake_flush

        segments = [
            {"name": "Slack", "metadata": {"ai_sort": True, "ai_sort_level": 1}},
            {"name": "2025-03-15", "metadata": {"ai_sort": True, "ai_sort_level": 2}},
        ]

        leaf = await ensure_folder_hierarchy_with_depth_validation(
            session, 1, segments
        )

        assert len(created) == 2
        assert [folder.name for folder in created] == ["Slack", "2025-03-15"]
        assert leaf is created[-1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reuses_existing_folder():
    """An already-present folder is returned as-is; nothing new is added."""
    from app.services.folder_service import (
        ensure_folder_hierarchy_with_depth_validation,
    )

    session = AsyncMock()

    present = MagicMock()
    present.id = 42

    lookup = MagicMock()
    lookup.scalar_one_or_none.return_value = present
    session.execute.return_value = lookup

    result = await ensure_folder_hierarchy_with_depth_validation(
        session, 1, [{"name": "Existing", "metadata": None}]
    )

    assert result is present
    session.add.assert_not_called()
|
||||
Loading…
Add table
Add a link
Reference in a new issue