mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-15 18:25:18 +02:00
feat: added ai file sorting
This commit is contained in:
parent
fa0b47dfca
commit
4bee367d4a
51 changed files with 1703 additions and 72 deletions
|
|
@ -4,6 +4,7 @@ import asyncio
|
|||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
|
@ -1551,3 +1552,121 @@ async def _index_uploaded_folder_files_async(
|
|||
heartbeat_task.cancel()
|
||||
if notification_id is not None:
|
||||
_stop_heartbeat(notification_id)
|
||||
|
||||
|
||||
# ===== AI File Sort tasks =====

# TTL on the per-search-space Redis sort lock. Acts as a crash guard: if a
# worker dies mid-sort, the lock expires on its own instead of blocking all
# future sorts of that search space.
AI_SORT_LOCK_TTL_SECONDS = 600  # 10 minutes

# Module-level Redis client shared by the AI-sort tasks.
# Lazily created on first use by _get_ai_sort_redis(); None until then.
_ai_sort_redis = None
|
||||
|
||||
|
||||
def _get_ai_sort_redis():
    """Return the process-wide Redis client for AI-sort locking.

    The client is built once from ``config.REDIS_APP_URL`` (with
    ``decode_responses=True``) and cached in the module-level
    ``_ai_sort_redis`` for every later call.
    """
    global _ai_sort_redis

    if _ai_sort_redis is not None:
        return _ai_sort_redis

    # Imported lazily so workers that never touch AI-sort tasks
    # don't pay for (or require) the redis package at import time.
    import redis

    _ai_sort_redis = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
    return _ai_sort_redis
|
||||
|
||||
|
||||
def _ai_sort_lock_key(search_space_id: int) -> str:
|
||||
return f"ai_sort:search_space:{search_space_id}:lock"
|
||||
|
||||
|
||||
@celery_app.task(name="ai_sort_search_space", bind=True, max_retries=1)
def ai_sort_search_space_task(self, search_space_id: int, user_id: str):
    """Celery entry point: run a full AI sort of every document in a search space.

    Args:
        search_space_id: Search space whose documents get (re)sorted.
        user_id: ID of the user who triggered the sort (passed through).
    """
    # asyncio.run() creates a fresh loop, runs the coroutine, then cancels
    # leftover tasks, shuts down async generators, and closes the loop.
    # The previous new_event_loop()/set_event_loop()/close() sequence skipped
    # that cleanup and left the thread's current event loop pointing at a
    # closed loop.
    asyncio.run(_ai_sort_search_space_async(search_space_id, user_id))
|
||||
|
||||
|
||||
async def _ai_sort_search_space_async(search_space_id: int, user_id: str):
    """Sort all documents of a search space using its configured summary LLM.

    A Redis lock (``SET NX EX``) ensures only one full sort per search space
    runs at a time; if the lock is already held, this invocation logs and
    returns without doing any work. Skips (with a warning) when the search
    space has no document-summary LLM configured.

    Args:
        search_space_id: Search space whose documents get sorted.
        user_id: Triggering user's ID (currently unused here; kept for parity
            with the task signature).
    """
    import uuid

    r = _get_ai_sort_redis()
    lock_key = _ai_sort_lock_key(search_space_id)

    # Store a unique token as the lock value so we only ever release a lock
    # we still own. If this job runs past AI_SORT_LOCK_TTL_SECONDS, Redis
    # expires the lock and a second run may acquire it; an unconditional
    # DELETE on exit would then drop *that* run's lock.
    lock_token = uuid.uuid4().hex
    if not r.set(lock_key, lock_token, nx=True, ex=AI_SORT_LOCK_TTL_SECONDS):
        logger.info(
            "AI sort already running for search_space=%d, skipping",
            search_space_id,
        )
        return

    t_start = time.perf_counter()
    try:
        from app.services.ai_file_sort_service import ai_sort_all_documents
        from app.services.llm_service import get_document_summary_llm

        async with get_celery_session_maker()() as session:
            llm = await get_document_summary_llm(
                session, search_space_id, disable_streaming=True
            )
            if llm is None:
                logger.warning(
                    "No LLM configured for search_space=%d, skipping AI sort",
                    search_space_id,
                )
                return

            sorted_count, failed_count = await ai_sort_all_documents(
                session, search_space_id, llm
            )
            elapsed = time.perf_counter() - t_start
            logger.info(
                "AI sort search_space=%d done in %.1fs: sorted=%d failed=%d",
                search_space_id,
                elapsed,
                sorted_count,
                failed_count,
            )
    finally:
        # Atomic compare-and-delete: release the lock only if it still holds
        # our token (standard single-instance Redis lock release pattern).
        release_script = (
            "if redis.call('get', KEYS[1]) == ARGV[1] "
            "then return redis.call('del', KEYS[1]) else return 0 end"
        )
        r.eval(release_script, 1, lock_key, lock_token)
|
||||
|
||||
|
||||
@celery_app.task(
    name="ai_sort_document", bind=True, max_retries=2, default_retry_delay=10
)
def ai_sort_document_task(self, search_space_id: int, user_id: str, document_id: int):
    """Celery entry point: incrementally AI-sort one document after indexing.

    Args:
        search_space_id: Search space the document belongs to.
        user_id: ID of the user who triggered indexing (passed through).
        document_id: Primary key of the document to sort.
    """
    # asyncio.run() replaces the manual new_event_loop()/close() pair: it
    # also cancels lingering tasks and shuts down async generators, and does
    # not leave the thread's current event loop set to a closed loop.
    asyncio.run(_ai_sort_document_async(search_space_id, user_id, document_id))
|
||||
|
||||
|
||||
async def _ai_sort_document_async(search_space_id: int, user_id: str, document_id: int):
    """Classify a single freshly indexed document via the AI file-sort service.

    Loads the document, resolves the search space's document-summary LLM,
    delegates the actual sorting to ``ai_sort_document``, and commits.
    Logs a warning and returns early when the document no longer exists or
    no LLM is configured for the search space.
    """
    from app.db import Document
    from app.services.ai_file_sort_service import ai_sort_document
    from app.services.llm_service import get_document_summary_llm

    session_maker = get_celery_session_maker()
    async with session_maker() as session:
        doc = await session.get(Document, document_id)
        if doc is None:
            logger.warning("Document %d not found, skipping AI sort", document_id)
            return

        summary_llm = await get_document_summary_llm(
            session, search_space_id, disable_streaming=True
        )
        if summary_llm is None:
            logger.warning(
                "No LLM for search_space=%d, skipping AI sort of doc=%d",
                search_space_id,
                document_id,
            )
            return

        await ai_sort_document(session, doc, summary_llm)
        await session.commit()
        logger.info(
            "AI sorted document=%d into search_space=%d",
            document_id,
            search_space_id,
        )
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ from app.services.new_streaming_service import VercelStreamingService
|
|||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
from app.utils.perf import get_perf_logger, log_system_snapshot, trim_native_heap
|
||||
|
||||
_background_tasks: set[asyncio.Task] = set()
|
||||
_perf_log = get_perf_logger()
|
||||
|
||||
|
||||
|
|
@ -1552,7 +1553,7 @@ async def stream_new_chat(
|
|||
# Shared threads write to team memory; private threads write to user memory.
|
||||
if not stream_result.agent_called_update_memory:
|
||||
if visibility == ChatVisibility.SEARCH_SPACE:
|
||||
asyncio.create_task(
|
||||
task = asyncio.create_task(
|
||||
extract_and_save_team_memory(
|
||||
user_message=user_query,
|
||||
search_space_id=search_space_id,
|
||||
|
|
@ -1560,14 +1561,18 @@ async def stream_new_chat(
|
|||
author_display_name=current_user_display_name,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
elif user_id:
|
||||
asyncio.create_task(
|
||||
task = asyncio.create_task(
|
||||
extract_and_save_memory(
|
||||
user_message=user_query,
|
||||
user_id=user_id,
|
||||
llm=llm,
|
||||
)
|
||||
)
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_finish_step()
|
||||
|
|
|
|||
|
|
@ -961,6 +961,7 @@ async def index_google_drive_files(
|
|||
vision_llm = None
|
||||
if connector_enable_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
drive_client = GoogleDriveClient(
|
||||
session, connector_id, credentials=pre_built_credentials
|
||||
|
|
@ -1168,6 +1169,7 @@ async def index_google_drive_single_file(
|
|||
vision_llm = None
|
||||
if connector_enable_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
drive_client = GoogleDriveClient(
|
||||
session, connector_id, credentials=pre_built_credentials
|
||||
|
|
@ -1306,6 +1308,7 @@ async def index_google_drive_selected_files(
|
|||
vision_llm = None
|
||||
if connector_enable_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
drive_client = GoogleDriveClient(
|
||||
session, connector_id, credentials=pre_built_credentials
|
||||
|
|
|
|||
|
|
@ -1360,7 +1360,9 @@ async def index_uploaded_files(
|
|||
|
||||
try:
|
||||
content, content_hash = await _compute_file_content_hash(
|
||||
temp_path, filename, search_space_id,
|
||||
temp_path,
|
||||
filename,
|
||||
search_space_id,
|
||||
vision_llm=vision_llm_instance,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -656,6 +656,7 @@ async def index_onedrive_files(
|
|||
vision_llm = None
|
||||
if connector_enable_vision_llm:
|
||||
from app.services.llm_service import get_vision_llm
|
||||
|
||||
vision_llm = await get_vision_llm(session, search_space_id)
|
||||
|
||||
onedrive_client = OneDriveClient(session, connector_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue