diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py index 0dd683f7e..af93ddc8f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/create_file.py @@ -58,7 +58,9 @@ def create_create_google_drive_file_tool( - "Create a Google Doc called 'Meeting Notes'" - "Create a spreadsheet named 'Budget 2026' with some sample data" """ - logger.info(f"create_google_drive_file called: name='{name}', type='{file_type}'") + logger.info( + f"create_google_drive_file called: name='{name}', type='{file_type}'" + ) if db_session is None or search_space_id is None or user_id is None: return { @@ -74,7 +76,9 @@ def create_create_google_drive_file_tool( try: metadata_service = GoogleDriveToolMetadataService(db_session) - context = await metadata_service.get_creation_context(search_space_id, user_id) + context = await metadata_service.get_creation_context( + search_space_id, user_id + ) if "error" in context: logger.error(f"Failed to fetch creation context: {context['error']}") @@ -100,8 +104,12 @@ def create_create_google_drive_file_tool( } ) - decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else [] - decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + decisions_raw = ( + approval.get("decisions", []) if isinstance(approval, dict) else [] + ) + decisions = ( + decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + ) decisions = [d for d in decisions if isinstance(d, dict)] if not decisions: logger.warning("No approval decision received") @@ -183,7 +191,9 @@ def create_create_google_drive_file_tool( logger.info( f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}" ) - client = GoogleDriveClient(session=db_session, connector_id=actual_connector_id) + client = GoogleDriveClient( + session=db_session, connector_id=actual_connector_id + ) try: created = await client.create_file( name=final_name, @@ -203,7 +213,9 @@ def create_create_google_drive_file_tool( } raise - logger.info(f"Google Drive file created: id={created.get('id')}, name={created.get('name')}") + logger.info( + f"Google Drive file created: id={created.get('id')}, name={created.get('name')}" + ) return { "status": "success", "file_id": created.get("id"), diff --git a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py index 600aae983..917ba3376 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_drive/trash_file.py @@ -52,7 +52,9 @@ def create_delete_google_drive_file_tool( - "Delete the 'Meeting Notes' file from Google Drive" - "Trash the 'Old Budget' spreadsheet" """ - logger.info(f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}") + logger.info( + f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}" + ) if db_session is None or search_space_id is None or user_id is None: return { @@ -103,8 +105,12 @@ def create_delete_google_drive_file_tool( } ) - decisions_raw = approval.get("decisions", []) if isinstance(approval, dict) else [] - decisions = decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + decisions_raw = ( + approval.get("decisions", []) if isinstance(approval, dict) else [] + ) + decisions = ( + decisions_raw if isinstance(decisions_raw, list) else [decisions_raw] + ) decisions = [d for d in decisions if isinstance(d, dict)] if not decisions: logger.warning("No approval decision received") @@ -130,11 +136,16 @@ def create_delete_google_drive_file_tool( final_params = decision["args"] final_file_id = final_params.get("file_id", file_id) - final_connector_id = final_params.get("connector_id", connector_id_from_context) + final_connector_id = final_params.get( + "connector_id", connector_id_from_context + ) final_delete_from_kb = final_params.get("delete_from_kb", delete_from_kb) if not final_connector_id: - return {"status": "error", "message": "No connector found for this file."} + return { + "status": "error", + "message": "No connector found for this file.", + } from sqlalchemy.future import select @@ -174,7 +185,9 @@ def create_delete_google_drive_file_tool( } raise - logger.info(f"Google Drive file deleted (moved to trash): file_id={final_file_id}") + logger.info( + f"Google Drive file deleted (moved to trash): file_id={final_file_id}" + ) trash_result: dict[str, Any] = { "status": "success", @@ -195,7 +208,9 @@ def create_delete_google_drive_file_tool( await db_session.delete(document) await db_session.commit() deleted_from_kb = True - logger.info(f"Deleted document {document_id} from knowledge base") + logger.info( + f"Deleted document {document_id} from knowledge base" + ) else: logger.warning(f"Document {document_id} not found in KB") except Exception as e: diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 01342e920..dffed5e86 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -47,6 +47,10 @@ from app.db import ChatVisibility from .display_image import create_display_image_tool from .generate_image import create_generate_image_tool +from .google_drive import ( + create_create_google_drive_file_tool, + create_delete_google_drive_file_tool, +) from .knowledge_base import create_search_knowledge_base_tool from .linear import ( create_create_linear_issue_tool, @@ -55,10 +59,6 @@ from .linear import ( ) from .link_preview import create_link_preview_tool from .mcp_tool import load_mcp_tools -from .google_drive import ( - create_create_google_drive_file_tool, - create_delete_google_drive_file_tool, -) from .notion import ( create_create_notion_page_tool, create_delete_notion_page_tool, diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index ec319345b..f6264506f 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -10,13 +10,12 @@ Supports loading LLM configurations from: """ import json +import logging from collections.abc import AsyncGenerator from dataclasses import dataclass from typing import Any from uuid import UUID -import logging - from langchain_core.messages import HumanMessage from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -30,7 +29,13 @@ from app.agents.new_chat.llm_config import ( load_agent_config, load_llm_config_from_yaml, ) -from app.db import ChatVisibility, Document, Report, SurfsenseDocsDocument, async_session_maker +from app.db import ( + ChatVisibility, + Document, + Report, + SurfsenseDocsDocument, + async_session_maker, +) from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.services.chat_session_state_service import ( clear_ai_responding, diff --git a/surfsense_backend/tests/conftest.py b/surfsense_backend/tests/conftest.py index 868664ca8..b6d37f7fd 100644 --- a/surfsense_backend/tests/conftest.py +++ b/surfsense_backend/tests/conftest.py @@ -58,18 +58,14 @@ def backend_url() -> str: @pytest.fixture(scope="session") async def auth_token(backend_url: str) -> str: """Authenticate once per session, registering the user if needed.""" - async with httpx.AsyncClient( - base_url=backend_url, timeout=30.0 - ) as client: + async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as client: return await get_auth_token(client) @pytest.fixture(scope="session") async def search_space_id(backend_url: str, auth_token: str) -> int: """Discover the first search space belonging to the test user.""" - async with httpx.AsyncClient( - base_url=backend_url, timeout=30.0 - ) as client: + async with httpx.AsyncClient(base_url=backend_url, timeout=30.0) as client: return await get_search_space_id(client, auth_token) @@ -86,7 +82,9 @@ async def _purge_test_search_space( """ deleted = await _force_delete_documents_db(search_space_id) if deleted: - print(f"\n[purge] Deleted {deleted} stale document(s) from search space {search_space_id}") + print( + f"\n[purge] Deleted {deleted} stale document(s) from search space {search_space_id}" + ) yield @@ -100,9 +98,7 @@ def headers(auth_token: str) -> dict[str, str]: @pytest.fixture async def client(backend_url: str) -> AsyncGenerator[httpx.AsyncClient]: """Per-test async HTTP client pointing at the running backend.""" - async with httpx.AsyncClient( - base_url=backend_url, timeout=180.0 - ) as c: + async with httpx.AsyncClient(base_url=backend_url, timeout=180.0) as c: yield c diff --git a/surfsense_backend/tests/e2e/test_document_upload.py b/surfsense_backend/tests/e2e/test_document_upload.py index 5177dce01..f217dc460 100644 --- a/surfsense_backend/tests/e2e/test_document_upload.py +++ b/surfsense_backend/tests/e2e/test_document_upload.py @@ -34,6 +34,7 @@ pytestmark = pytest.mark.document # Helpers local to this module # --------------------------------------------------------------------------- + def _assert_document_ready(doc: dict, *, expected_filename: str) -> None: """Common assertions for a successfully processed document.""" assert doc["title"] == expected_filename @@ -59,7 +60,9 @@ class TestTxtFileUpload: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.txt", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.txt", search_space_id=search_space_id + ) assert resp.status_code == 200 body = resp.json() @@ -74,12 +77,16 @@ class TestTxtFileUpload: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.txt", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.txt", search_space_id=search_space_id + ) assert resp.status_code == 200 doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) - statuses = await poll_document_status(client, headers, doc_ids, search_space_id=search_space_id) + statuses = await poll_document_status( + client, headers, doc_ids, search_space_id=search_space_id + ) for did in doc_ids: assert statuses[did]["status"]["state"] == "ready" @@ -90,11 +97,15 @@ class TestTxtFileUpload: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.txt", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.txt", search_space_id=search_space_id + ) doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) - await poll_document_status(client, headers, doc_ids, search_space_id=search_space_id) + await poll_document_status( + client, headers, doc_ids, search_space_id=search_space_id + ) doc = await get_document(client, headers, doc_ids[0]) _assert_document_ready(doc, expected_filename="sample.txt") @@ -116,12 +127,16 @@ class TestMarkdownFileUpload: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.md", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.md", search_space_id=search_space_id + ) assert resp.status_code == 200 doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) - statuses = await poll_document_status(client, headers, doc_ids, search_space_id=search_space_id) + statuses = await poll_document_status( + client, headers, doc_ids, search_space_id=search_space_id + ) for did in doc_ids: assert statuses[did]["status"]["state"] == "ready" @@ -132,11 +147,15 @@ class TestMarkdownFileUpload: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.md", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.md", search_space_id=search_space_id + ) doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) - await poll_document_status(client, headers, doc_ids, search_space_id=search_space_id) + await poll_document_status( + client, headers, doc_ids, search_space_id=search_space_id + ) doc = await get_document(client, headers, doc_ids[0]) _assert_document_ready(doc, expected_filename="sample.md") @@ -158,7 +177,9 @@ class TestPdfFileUpload: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.pdf", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.pdf", search_space_id=search_space_id + ) assert resp.status_code == 200 doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) @@ -176,7 +197,9 @@ class TestPdfFileUpload: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.pdf", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.pdf", search_space_id=search_space_id + ) doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) @@ -209,7 +232,10 @@ class TestMultiFileUpload: cleanup_doc_ids: list[int], ): resp = await upload_multiple_files( - client, headers, ["sample.txt", "sample.md"], search_space_id=search_space_id + client, + headers, + ["sample.txt", "sample.md"], + search_space_id=search_space_id, ) assert resp.status_code == 200 @@ -226,12 +252,17 @@ class TestMultiFileUpload: cleanup_doc_ids: list[int], ): resp = await upload_multiple_files( - client, headers, ["sample.txt", "sample.md"], search_space_id=search_space_id + client, + headers, + ["sample.txt", "sample.md"], + search_space_id=search_space_id, ) doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) - statuses = await poll_document_status(client, headers, doc_ids, search_space_id=search_space_id) + statuses = await poll_document_status( + client, headers, doc_ids, search_space_id=search_space_id + ) for did in doc_ids: assert statuses[did]["status"]["state"] == "ready" @@ -255,15 +286,21 @@ class TestDuplicateFileUpload: cleanup_doc_ids: list[int], ): # First upload - resp1 = await upload_file(client, headers, "sample.txt", search_space_id=search_space_id) + resp1 = await upload_file( + client, headers, "sample.txt", search_space_id=search_space_id + ) assert resp1.status_code == 200 first_ids = resp1.json()["document_ids"] cleanup_doc_ids.extend(first_ids) - await poll_document_status(client, headers, first_ids, search_space_id=search_space_id) + await poll_document_status( + client, headers, first_ids, search_space_id=search_space_id + ) # Second upload of the same file - resp2 = await upload_file(client, headers, "sample.txt", search_space_id=search_space_id) + resp2 = await upload_file( + client, headers, "sample.txt", search_space_id=search_space_id + ) assert resp2.status_code == 200 body2 = resp2.json() @@ -292,11 +329,15 @@ class TestDuplicateContentDetection: tmp_path: Path, ): # First upload - resp1 = await upload_file(client, headers, "sample.txt", search_space_id=search_space_id) + resp1 = await upload_file( + client, headers, "sample.txt", search_space_id=search_space_id + ) assert resp1.status_code == 200 first_ids = resp1.json()["document_ids"] cleanup_doc_ids.extend(first_ids) - await poll_document_status(client, headers, first_ids, search_space_id=search_space_id) + await poll_document_status( + client, headers, first_ids, search_space_id=search_space_id + ) # Copy fixture content to a differently named temp file src = FIXTURES_DIR / "sample.txt" @@ -315,7 +356,9 @@ class TestDuplicateContentDetection: cleanup_doc_ids.extend(second_ids) if second_ids: - statuses = await poll_document_status(client, headers, second_ids, search_space_id=search_space_id) + statuses = await poll_document_status( + client, headers, second_ids, search_space_id=search_space_id + ) for did in second_ids: assert statuses[did]["status"]["state"] == "failed" assert "duplicate" in ( @@ -338,7 +381,9 @@ class TestEmptyFileUpload: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "empty.pdf", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "empty.pdf", search_space_id=search_space_id + ) assert resp.status_code == 200 doc_ids = resp.json()["document_ids"] @@ -414,9 +459,13 @@ class TestDocumentDeletion: headers: dict[str, str], search_space_id: int, ): - resp = await upload_file(client, headers, "sample.txt", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.txt", search_space_id=search_space_id + ) doc_ids = resp.json()["document_ids"] - await poll_document_status(client, headers, doc_ids, search_space_id=search_space_id) + await poll_document_status( + client, headers, doc_ids, search_space_id=search_space_id + ) del_resp = await delete_document(client, headers, doc_ids[0]) assert del_resp.status_code == 200 @@ -443,7 +492,9 @@ class TestDeleteWhileProcessing: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.pdf", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.pdf", search_space_id=search_space_id + ) assert resp.status_code == 200 doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) @@ -473,7 +524,9 @@ class TestStatusPolling: search_space_id: int, cleanup_doc_ids: list[int], ): - resp = await upload_file(client, headers, "sample.txt", search_space_id=search_space_id) + resp = await upload_file( + client, headers, "sample.txt", search_space_id=search_space_id + ) doc_ids = resp.json()["document_ids"] cleanup_doc_ids.extend(doc_ids) @@ -501,4 +554,6 @@ class TestStatusPolling: "failed", } - await poll_document_status(client, headers, doc_ids, search_space_id=search_space_id) + await poll_document_status( + client, headers, doc_ids, search_space_id=search_space_id + )