feat: add AI file sorting

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-04-14 01:43:30 -07:00
parent fa0b47dfca
commit 4bee367d4a
51 changed files with 1703 additions and 72 deletions

View file

@ -4,6 +4,7 @@ import asyncio
import contextlib
import logging
import os
import time
from uuid import UUID
from app.celery_app import celery_app
@ -1551,3 +1552,121 @@ async def _index_uploaded_folder_files_async(
heartbeat_task.cancel()
if notification_id is not None:
_stop_heartbeat(notification_id)
# ===== AI File Sort tasks =====

# Lock TTL is a crash safety net: if a worker dies mid-sort, the Redis lock
# expires on its own after this many seconds instead of blocking forever.
AI_SORT_LOCK_TTL_SECONDS = 600  # 10 minutes

# Module-cached client so every AI-sort task in this worker process shares
# one connection pool.
_ai_sort_redis = None


def _get_ai_sort_redis():
    """Return the lazily created, process-wide Redis client for AI sort."""
    global _ai_sort_redis
    import redis

    if _ai_sort_redis is None:
        _ai_sort_redis = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
    return _ai_sort_redis
def _ai_sort_lock_key(search_space_id: int) -> str:
return f"ai_sort:search_space:{search_space_id}:lock"
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
"""Full AI sort for all documents in a search space."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(_ai_sort_search_space_async(search_space_id, user_id))
finally:
loop.close()
async def _ai_sort_search_space_async(search_space_id: int, user_id: str) -> None:
    """Run the full AI sort for every document in a search space.

    A Redis ``SET NX`` lock (TTL ``AI_SORT_LOCK_TTL_SECONDS``) ensures at most
    one full sort per search space runs at a time; a concurrent invocation
    logs and returns without doing work. The lock is released in ``finally``
    on every exit path, and the TTL covers a worker crash.
    """
    redis_client = _get_ai_sort_redis()
    lock_key = _ai_sort_lock_key(search_space_id)

    # NX set: proceed only if no other worker currently holds the lock.
    acquired = redis_client.set(
        lock_key, "running", nx=True, ex=AI_SORT_LOCK_TTL_SECONDS
    )
    if not acquired:
        logger.info(
            "AI sort already running for search_space=%d, skipping",
            search_space_id,
        )
        return

    started = time.perf_counter()
    try:
        # Deferred imports — presumably to avoid heavy/circular imports at
        # module load time; TODO confirm.
        from app.services.ai_file_sort_service import ai_sort_all_documents
        from app.services.llm_service import get_document_summary_llm

        async with get_celery_session_maker()() as session:
            llm = await get_document_summary_llm(
                session, search_space_id, disable_streaming=True
            )
            if llm is None:
                logger.warning(
                    "No LLM configured for search_space=%d, skipping AI sort",
                    search_space_id,
                )
                return

            sorted_count, failed_count = await ai_sort_all_documents(
                session, search_space_id, llm
            )
            elapsed = time.perf_counter() - started
            logger.info(
                "AI sort search_space=%d done in %.1fs: sorted=%d failed=%d",
                search_space_id,
                elapsed,
                sorted_count,
                failed_count,
            )
    finally:
        redis_client.delete(lock_key)
@celery_app.task(
    name="ai_sort_document", bind=True, max_retries=2, default_retry_delay=10
)
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
    """Incremental AI sort for a single document after indexing.

    Celery entry point bridging into asyncio. Uses ``asyncio.run`` instead of
    a manual ``new_event_loop``/``close`` pair: it additionally shuts down
    async generators before closing the loop, and keeps this task consistent
    with ``ai_sort_search_space_task``.

    Args:
        search_space_id: Search space the document belongs to.
        user_id: Id of the user who triggered the sort (passed through).
        document_id: Primary key of the document to sort.
    """
    asyncio.run(_ai_sort_document_async(search_space_id, user_id, document_id))
async def _ai_sort_document_async(
    search_space_id: int, user_id: str, document_id: int
) -> None:
    """Sort one freshly indexed document into its AI-chosen place.

    Loads the document and the search space's summary LLM; if either is
    missing, logs a warning and exits without error. Commits the session
    after a successful sort.

    NOTE(review): unlike the full-sort path this takes no Redis lock, so it
    may run concurrently with ``ai_sort_search_space_task`` — confirm safe.
    """
    from app.db import Document
    from app.services.ai_file_sort_service import ai_sort_document
    from app.services.llm_service import get_document_summary_llm

    async with get_celery_session_maker()() as session:
        doc = await session.get(Document, document_id)
        if doc is None:
            logger.warning("Document %d not found, skipping AI sort", document_id)
            return

        llm = await get_document_summary_llm(
            session, search_space_id, disable_streaming=True
        )
        if llm is None:
            logger.warning(
                "No LLM for search_space=%d, skipping AI sort of doc=%d",
                search_space_id,
                document_id,
            )
            return

        await ai_sort_document(session, doc, llm)
        await session.commit()
        logger.info(
            "AI sorted document=%d into search_space=%d",
            document_id,
            search_space_id,
        )

View file

@ -61,6 +61,7 @@ from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
# Strong references to fire-and-forget asyncio tasks: the event loop holds
# only weak references to tasks, so without this set a task created with
# asyncio.create_task could be garbage-collected before it finishes. Callers
# add the task here and remove it via add_done_callback(discard).
_background_tasks: set[asyncio.Task] = set()
_perf_log = get_perf_logger()
@ -1552,7 +1553,7 @@ async def stream_new_chat(
# Shared threads write to team memory; private threads write to user memory.
if not stream_result.agent_called_update_memory:
if visibility == ChatVisibility.SEARCH_SPACE:
asyncio.create_task(
task = asyncio.create_task(
extract_and_save_team_memory(
user_message=user_query,
search_space_id=search_space_id,
@ -1560,14 +1561,18 @@ async def stream_new_chat(
author_display_name=current_user_display_name,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
elif user_id:
asyncio.create_task(
task = asyncio.create_task(
extract_and_save_memory(
user_message=user_query,
user_id=user_id,
llm=llm,
)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Finish the step and message
yield streaming_service.format_finish_step()

View file

@ -961,6 +961,7 @@ async def index_google_drive_files(
vision_llm = None
if connector_enable_vision_llm:
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
@ -1168,6 +1169,7 @@ async def index_google_drive_single_file(
vision_llm = None
if connector_enable_vision_llm:
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials
@ -1306,6 +1308,7 @@ async def index_google_drive_selected_files(
vision_llm = None
if connector_enable_vision_llm:
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
drive_client = GoogleDriveClient(
session, connector_id, credentials=pre_built_credentials

View file

@ -1360,7 +1360,9 @@ async def index_uploaded_files(
try:
content, content_hash = await _compute_file_content_hash(
temp_path, filename, search_space_id,
temp_path,
filename,
search_space_id,
vision_llm=vision_llm_instance,
)
except Exception as e:

View file

@ -656,6 +656,7 @@ async def index_onedrive_files(
vision_llm = None
if connector_enable_vision_llm:
from app.services.llm_service import get_vision_llm
vision_llm = await get_vision_llm(session, search_space_id)
onedrive_client = OneDriveClient(session, connector_id)