diff --git a/surfsense_backend/alembic/versions/39_add_rbac_tables.py b/surfsense_backend/alembic/versions/39_add_rbac_tables.py new file mode 100644 index 000000000..ac2df0df2 --- /dev/null +++ b/surfsense_backend/alembic/versions/39_add_rbac_tables.py @@ -0,0 +1,179 @@ +"""Add RBAC tables for search space access control + +Revision ID: 39 +Revises: 38 +Create Date: 2025-11-27 00:00:00.000000 + +This migration adds: +- Permission enum for granular access control +- search_space_roles table for custom roles per search space +- search_space_memberships table for user-searchspace-role relationships +- search_space_invites table for invite links +""" + +from collections.abc import Sequence + +from sqlalchemy import inspect + +from alembic import op + +revision: str = "39" +down_revision: str | None = "38" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema - add RBAC tables for search space access control.""" + + # Create search_space_roles table + op.execute( + """ + CREATE TABLE IF NOT EXISTS search_space_roles ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + name VARCHAR(100) NOT NULL, + description VARCHAR(500), + permissions TEXT[] NOT NULL DEFAULT '{}', + is_default BOOLEAN NOT NULL DEFAULT FALSE, + is_system_role BOOLEAN NOT NULL DEFAULT FALSE, + search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE, + CONSTRAINT uq_searchspace_role_name UNIQUE (search_space_id, name) + ); + """ + ) + + # Create search_space_invites table (needs to be created before memberships due to FK) + op.execute( + """ + CREATE TABLE IF NOT EXISTS search_space_invites ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + invite_code VARCHAR(64) NOT NULL UNIQUE, + search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE, + role_id INTEGER REFERENCES search_space_roles(id) ON DELETE SET NULL, + created_by_id UUID REFERENCES "user"(id) ON DELETE SET NULL, + expires_at TIMESTAMPTZ, + max_uses INTEGER, + uses_count INTEGER NOT NULL DEFAULT 0, + is_active BOOLEAN NOT NULL DEFAULT TRUE, + name VARCHAR(100) + ); + """ + ) + + # Create search_space_memberships table + op.execute( + """ + CREATE TABLE IF NOT EXISTS search_space_memberships ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE, + role_id INTEGER REFERENCES search_space_roles(id) ON DELETE SET NULL, + is_owner BOOLEAN NOT NULL DEFAULT FALSE, + joined_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + invited_by_invite_id INTEGER REFERENCES search_space_invites(id) ON DELETE SET NULL, + CONSTRAINT uq_user_searchspace_membership UNIQUE (user_id, search_space_id) + ); + """ + ) + + # Get connection and inspector for checking existing indexes + conn = op.get_bind() + inspector = inspect(conn) + + # Create indexes for search_space_roles + existing_indexes = [ + idx["name"] for idx in inspector.get_indexes("search_space_roles") + ] + if "ix_search_space_roles_id" not in existing_indexes: + op.create_index("ix_search_space_roles_id", "search_space_roles", ["id"]) + if "ix_search_space_roles_created_at" not in existing_indexes: + op.create_index( + "ix_search_space_roles_created_at", "search_space_roles", ["created_at"] + ) + if "ix_search_space_roles_name" not in existing_indexes: + op.create_index("ix_search_space_roles_name", "search_space_roles", ["name"]) + + # Create indexes for search_space_memberships + existing_indexes = [ + idx["name"] for idx in inspector.get_indexes("search_space_memberships") + ] + if "ix_search_space_memberships_id" not in existing_indexes: + op.create_index( + "ix_search_space_memberships_id", "search_space_memberships", ["id"] + ) + if "ix_search_space_memberships_created_at" not in existing_indexes: + op.create_index( + "ix_search_space_memberships_created_at", + "search_space_memberships", + ["created_at"], + ) + if "ix_search_space_memberships_user_id" not in existing_indexes: + op.create_index( + "ix_search_space_memberships_user_id", + "search_space_memberships", + ["user_id"], + ) + if "ix_search_space_memberships_search_space_id" not in existing_indexes: + op.create_index( + "ix_search_space_memberships_search_space_id", + "search_space_memberships", + ["search_space_id"], + ) + + # Create indexes for search_space_invites + existing_indexes = [ + idx["name"] for idx in inspector.get_indexes("search_space_invites") + ] + if "ix_search_space_invites_id" not in existing_indexes: + op.create_index("ix_search_space_invites_id", "search_space_invites", ["id"]) + if "ix_search_space_invites_created_at" not in existing_indexes: + op.create_index( + "ix_search_space_invites_created_at", "search_space_invites", ["created_at"] + ) + if "ix_search_space_invites_invite_code" not in existing_indexes: + op.create_index( + "ix_search_space_invites_invite_code", + "search_space_invites", + ["invite_code"], + ) + + +def downgrade() -> None: + """Downgrade schema - remove RBAC tables.""" + + # Drop indexes for search_space_memberships + op.drop_index( + "ix_search_space_memberships_search_space_id", + table_name="search_space_memberships", + ) + op.drop_index( + "ix_search_space_memberships_user_id", table_name="search_space_memberships" + ) + op.drop_index( + "ix_search_space_memberships_created_at", table_name="search_space_memberships" + ) + op.drop_index( + "ix_search_space_memberships_id", table_name="search_space_memberships" + ) + + # Drop indexes for search_space_invites + op.drop_index( + "ix_search_space_invites_invite_code", table_name="search_space_invites" + ) + op.drop_index( + "ix_search_space_invites_created_at", table_name="search_space_invites" + ) + op.drop_index("ix_search_space_invites_id", table_name="search_space_invites") + + # Drop indexes for search_space_roles + op.drop_index("ix_search_space_roles_name", table_name="search_space_roles") + op.drop_index("ix_search_space_roles_created_at", table_name="search_space_roles") + op.drop_index("ix_search_space_roles_id", table_name="search_space_roles") + + # Drop tables in correct order (respecting foreign key constraints) + op.drop_table("search_space_memberships") + op.drop_table("search_space_invites") + op.drop_table("search_space_roles") diff --git a/surfsense_backend/alembic/versions/40_move_llm_preferences_to_searchspace.py b/surfsense_backend/alembic/versions/40_move_llm_preferences_to_searchspace.py new file mode 100644 index 000000000..1067cffcc --- /dev/null +++ b/surfsense_backend/alembic/versions/40_move_llm_preferences_to_searchspace.py @@ -0,0 +1,63 @@ +"""Move LLM preferences from user-level to search space level + +Revision ID: 40 +Revises: 39 +Create Date: 2024-11-27 + +This migration moves LLM preferences (long_context_llm_id, fast_llm_id, strategic_llm_id) +from the user_search_space_preferences table to the searchspaces table itself. + +This change supports the RBAC model where LLM preferences are shared by all members +of a search space, rather than being per-user. +""" + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "40" +down_revision = "39" +branch_labels = None +depends_on = None + + +def upgrade(): + # Add LLM preference columns to searchspaces table + op.add_column( + "searchspaces", + sa.Column("long_context_llm_id", sa.Integer(), nullable=True), + ) + op.add_column( + "searchspaces", + sa.Column("fast_llm_id", sa.Integer(), nullable=True), + ) + op.add_column( + "searchspaces", + sa.Column("strategic_llm_id", sa.Integer(), nullable=True), + ) + + # Migrate existing preferences from user_search_space_preferences to searchspaces + # We take the owner's preferences (the user who created the search space) + connection = op.get_bind() + + # Get all search spaces and their owner's preferences + connection.execute( + sa.text(""" + UPDATE searchspaces ss + SET + long_context_llm_id = usp.long_context_llm_id, + fast_llm_id = usp.fast_llm_id, + strategic_llm_id = usp.strategic_llm_id + FROM user_search_space_preferences usp + WHERE ss.id = usp.search_space_id + AND ss.user_id = usp.user_id + """) + ) + + +def downgrade(): + # Remove LLM preference columns from searchspaces table + op.drop_column("searchspaces", "strategic_llm_id") + op.drop_column("searchspaces", "fast_llm_id") + op.drop_column("searchspaces", "long_context_llm_id") diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py index 223d82b67..c53e3348f 100644 --- a/surfsense_backend/app/agents/researcher/nodes.py +++ b/surfsense_backend/app/agents/researcher/nodes.py @@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession # Additional imports for document fetching from sqlalchemy.future import select -from app.db import Document, SearchSpace +from app.db import Document from app.services.connector_service import ConnectorService from app.services.query_service import QueryService @@ -92,19 +92,18 @@ def extract_sources_from_documents( async def fetch_documents_by_ids( - document_ids: list[int], user_id: str, db_session: AsyncSession + document_ids: list[int], search_space_id: int, db_session: AsyncSession ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """ - Fetch documents by their IDs with ownership check using DOCUMENTS mode approach. + Fetch documents by their IDs within a search space. - This function ensures that only documents belonging to the user are fetched, - providing security by checking ownership through SearchSpace association. + This function ensures that only documents belonging to the search space are fetched. Similar to SearchMode.DOCUMENTS, it fetches full documents and concatenates their chunks. Also creates source objects for UI display, grouped by document type. Args: document_ids: List of document IDs to fetch - user_id: The user ID to check ownership + search_space_id: The search space ID to filter by db_session: The database session Returns: @@ -114,11 +113,12 @@ async def fetch_documents_by_ids( return [], [] try: - # Query documents with ownership check + # Query documents filtered by search space result = await db_session.execute( - select(Document) - .join(SearchSpace) - .filter(Document.id.in_(document_ids), SearchSpace.user_id == user_id) + select(Document).filter( + Document.id.in_(document_ids), + Document.search_space_id == search_space_id, + ) ) documents = result.scalars().all() @@ -515,7 +515,6 @@ async def fetch_documents_by_ids( async def fetch_relevant_documents( research_questions: list[str], - user_id: str, search_space_id: int, db_session: AsyncSession, connectors_to_search: list[str], @@ -536,7 +535,6 @@ async def fetch_relevant_documents( Args: research_questions: List of research questions to find documents for - user_id: The user ID search_space_id: The search space ID db_session: The database session connectors_to_search: List of connectors to search @@ -619,7 +617,6 @@ async def fetch_relevant_documents( youtube_chunks, ) = await connector_service.search_youtube( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -646,7 +643,6 @@ async def fetch_relevant_documents( extension_chunks, ) = await connector_service.search_extension( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -673,7 +669,6 @@ async def fetch_relevant_documents( crawled_urls_chunks, ) = await connector_service.search_crawled_urls( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -697,7 +692,6 @@ async def fetch_relevant_documents( elif connector == "FILE": source_object, files_chunks = await connector_service.search_files( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -721,7 +715,6 @@ async def fetch_relevant_documents( elif connector == "SLACK_CONNECTOR": source_object, slack_chunks = await connector_service.search_slack( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -748,7 +741,6 @@ async def fetch_relevant_documents( notion_chunks, ) = await connector_service.search_notion( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -775,7 +767,6 @@ async def fetch_relevant_documents( github_chunks, ) = await connector_service.search_github( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -802,7 +793,6 @@ async def fetch_relevant_documents( linear_chunks, ) = await connector_service.search_linear( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -829,7 +819,6 @@ async def fetch_relevant_documents( tavily_chunks, ) = await connector_service.search_tavily( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, ) @@ -855,7 +844,6 @@ async def fetch_relevant_documents( searx_chunks, ) = await connector_service.search_searxng( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, ) @@ -881,7 +869,6 @@ async def fetch_relevant_documents( linkup_chunks, ) = await connector_service.search_linkup( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, mode=linkup_mode, ) @@ -907,7 +894,6 @@ async def fetch_relevant_documents( baidu_chunks, ) = await connector_service.search_baidu( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, ) @@ -933,7 +919,6 @@ async def fetch_relevant_documents( discord_chunks, ) = await connector_service.search_discord( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -955,7 +940,6 @@ async def fetch_relevant_documents( elif connector == "JIRA_CONNECTOR": source_object, jira_chunks = await connector_service.search_jira( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -981,7 +965,6 @@ async def fetch_relevant_documents( calendar_chunks, ) = await connector_service.search_google_calendar( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -1007,7 +990,6 @@ async def fetch_relevant_documents( airtable_chunks, ) = await connector_service.search_airtable( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -1033,7 +1015,6 @@ async def fetch_relevant_documents( gmail_chunks, ) = await connector_service.search_google_gmail( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -1059,7 +1040,6 @@ async def fetch_relevant_documents( confluence_chunks, ) = await connector_service.search_confluence( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -1085,7 +1065,6 @@ async def fetch_relevant_documents( clickup_chunks, ) = await connector_service.search_clickup( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -1112,7 +1091,6 @@ async def fetch_relevant_documents( luma_chunks, ) = await connector_service.search_luma( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -1139,7 +1117,6 @@ async def fetch_relevant_documents( elasticsearch_chunks, ) = await connector_service.search_elasticsearch( user_query=reformulated_query, - user_id=user_id, search_space_id=search_space_id, top_k=top_k, search_mode=search_mode, @@ -1315,7 +1292,6 @@ async def reformulate_user_query( reformulated_query = await QueryService.reformulate_query_with_chat_history( user_query=user_query, session=state.db_session, - user_id=configuration.user_id, search_space_id=configuration.search_space_id, chat_history_str=chat_history_str, ) @@ -1389,7 +1365,7 @@ async def handle_qna_workflow( user_selected_documents, ) = await fetch_documents_by_ids( document_ids=configuration.document_ids_to_add_in_context, - user_id=configuration.user_id, + search_space_id=configuration.search_space_id, db_session=state.db_session, ) @@ -1404,7 +1380,7 @@ async def handle_qna_workflow( # Create connector service using state db_session connector_service = ConnectorService( - state.db_session, user_id=configuration.user_id + state.db_session, search_space_id=configuration.search_space_id ) await connector_service.initialize_counter() @@ -1413,7 +1389,6 @@ async def handle_qna_workflow( relevant_documents = await fetch_relevant_documents( research_questions=research_questions, - user_id=configuration.user_id, search_space_id=configuration.search_space_id, db_session=state.db_session, connectors_to_search=configuration.connectors_to_search, @@ -1459,7 +1434,6 @@ async def handle_qna_workflow( "user_query": user_query, # Use the reformulated query "reformulated_query": reformulated_query, "relevant_documents": all_documents, # Use combined documents - "user_id": configuration.user_id, "search_space_id": configuration.search_space_id, "language": configuration.language, } @@ -1551,12 +1525,11 @@ async def generate_further_questions( Returns: Dict containing the further questions in the "further_questions" key for state update. """ - from app.services.llm_service import get_user_fast_llm + from app.services.llm_service import get_fast_llm # Get configuration and state data configuration = Configuration.from_runnable_config(config) chat_history = state.chat_history - user_id = configuration.user_id search_space_id = configuration.search_space_id streaming_service = state.streaming_service @@ -1571,10 +1544,10 @@ async def generate_further_questions( } ) - # Get user's fast LLM - llm = await get_user_fast_llm(state.db_session, user_id, search_space_id) + # Get search space's fast LLM + llm = await get_fast_llm(state.db_session, search_space_id) if not llm: - error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}" + error_message = f"No fast LLM configured for search space {search_space_id}" print(error_message) writer({"yield_value": streaming_service.format_error(error_message)}) diff --git a/surfsense_backend/app/agents/researcher/qna_agent/configuration.py b/surfsense_backend/app/agents/researcher/qna_agent/configuration.py index ea107a575..e7dd9175e 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/configuration.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/configuration.py @@ -18,7 +18,6 @@ class Configuration: relevant_documents: list[ Any ] # Documents provided directly to the agent for answering - user_id: str # User identifier search_space_id: int # Search space identifier language: str | None = None # Language for responses diff --git a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py index 3112a581a..37bdbc362 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py @@ -142,13 +142,12 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any Returns: Dict containing the final answer in the "final_answer" key. """ - from app.services.llm_service import get_user_fast_llm + from app.services.llm_service import get_fast_llm # Get configuration and relevant documents from configuration configuration = Configuration.from_runnable_config(config) documents = state.reranked_documents user_query = configuration.user_query - user_id = configuration.user_id search_space_id = configuration.search_space_id language = configuration.language @@ -178,10 +177,10 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any else "" ) - # Get user's fast LLM - llm = await get_user_fast_llm(state.db_session, user_id, search_space_id) + # Get search space's fast LLM + llm = await get_fast_llm(state.db_session, search_space_id) if not llm: - error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}" + error_message = f"No fast LLM configured for search space {search_space_id}" print(error_message) raise RuntimeError(error_message) diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 06abb7a39..6195bec87 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -131,6 +131,169 @@ class LogStatus(str, Enum): FAILED = "FAILED" +class Permission(str, Enum): + """ + Granular permissions for search space resources. + Use '*' (FULL_ACCESS) to grant all permissions. + """ + + # Documents + DOCUMENTS_CREATE = "documents:create" + DOCUMENTS_READ = "documents:read" + DOCUMENTS_UPDATE = "documents:update" + DOCUMENTS_DELETE = "documents:delete" + + # Chats + CHATS_CREATE = "chats:create" + CHATS_READ = "chats:read" + CHATS_UPDATE = "chats:update" + CHATS_DELETE = "chats:delete" + + # LLM Configs + LLM_CONFIGS_CREATE = "llm_configs:create" + LLM_CONFIGS_READ = "llm_configs:read" + LLM_CONFIGS_UPDATE = "llm_configs:update" + LLM_CONFIGS_DELETE = "llm_configs:delete" + + # Podcasts + PODCASTS_CREATE = "podcasts:create" + PODCASTS_READ = "podcasts:read" + PODCASTS_UPDATE = "podcasts:update" + PODCASTS_DELETE = "podcasts:delete" + + # Connectors + CONNECTORS_CREATE = "connectors:create" + CONNECTORS_READ = "connectors:read" + CONNECTORS_UPDATE = "connectors:update" + CONNECTORS_DELETE = "connectors:delete" + + # Logs + LOGS_READ = "logs:read" + LOGS_DELETE = "logs:delete" + + # Members + MEMBERS_INVITE = "members:invite" + MEMBERS_VIEW = "members:view" + MEMBERS_REMOVE = "members:remove" + MEMBERS_MANAGE_ROLES = "members:manage_roles" + + # Roles + ROLES_CREATE = "roles:create" + ROLES_READ = "roles:read" + ROLES_UPDATE = "roles:update" + ROLES_DELETE = "roles:delete" + + # Search Space Settings + SETTINGS_VIEW = "settings:view" + SETTINGS_UPDATE = "settings:update" + SETTINGS_DELETE = "settings:delete" # Delete the entire search space + + # Full access wildcard + FULL_ACCESS = "*" + + +# Predefined role permission sets for convenience +DEFAULT_ROLE_PERMISSIONS = { + "Owner": [Permission.FULL_ACCESS.value], + "Admin": [ + # Documents + Permission.DOCUMENTS_CREATE.value, + Permission.DOCUMENTS_READ.value, + Permission.DOCUMENTS_UPDATE.value, + Permission.DOCUMENTS_DELETE.value, + # Chats + Permission.CHATS_CREATE.value, + Permission.CHATS_READ.value, + Permission.CHATS_UPDATE.value, + Permission.CHATS_DELETE.value, + # LLM Configs + Permission.LLM_CONFIGS_CREATE.value, + Permission.LLM_CONFIGS_READ.value, + Permission.LLM_CONFIGS_UPDATE.value, + Permission.LLM_CONFIGS_DELETE.value, + # Podcasts + Permission.PODCASTS_CREATE.value, + Permission.PODCASTS_READ.value, + Permission.PODCASTS_UPDATE.value, + Permission.PODCASTS_DELETE.value, + # Connectors + Permission.CONNECTORS_CREATE.value, + Permission.CONNECTORS_READ.value, + Permission.CONNECTORS_UPDATE.value, + Permission.CONNECTORS_DELETE.value, + # Logs + Permission.LOGS_READ.value, + Permission.LOGS_DELETE.value, + # Members + Permission.MEMBERS_INVITE.value, + Permission.MEMBERS_VIEW.value, + Permission.MEMBERS_REMOVE.value, + Permission.MEMBERS_MANAGE_ROLES.value, + # Roles + Permission.ROLES_CREATE.value, + Permission.ROLES_READ.value, + Permission.ROLES_UPDATE.value, + Permission.ROLES_DELETE.value, + # Settings (no delete) + Permission.SETTINGS_VIEW.value, + Permission.SETTINGS_UPDATE.value, + ], + "Editor": [ + # Documents + Permission.DOCUMENTS_CREATE.value, + Permission.DOCUMENTS_READ.value, + Permission.DOCUMENTS_UPDATE.value, + Permission.DOCUMENTS_DELETE.value, + # Chats + Permission.CHATS_CREATE.value, + Permission.CHATS_READ.value, + Permission.CHATS_UPDATE.value, + Permission.CHATS_DELETE.value, + # LLM Configs (read only) + Permission.LLM_CONFIGS_READ.value, + Permission.LLM_CONFIGS_CREATE.value, + Permission.LLM_CONFIGS_UPDATE.value, + # Podcasts + Permission.PODCASTS_CREATE.value, + Permission.PODCASTS_READ.value, + Permission.PODCASTS_UPDATE.value, + Permission.PODCASTS_DELETE.value, + # Connectors (full access for editors) + Permission.CONNECTORS_CREATE.value, + Permission.CONNECTORS_READ.value, + Permission.CONNECTORS_UPDATE.value, + # Logs + Permission.LOGS_READ.value, + # Members (view only) + Permission.MEMBERS_VIEW.value, + # Roles (read only) + Permission.ROLES_READ.value, + # Settings (view only) + Permission.SETTINGS_VIEW.value, + ], + "Viewer": [ + # Documents (read only) + Permission.DOCUMENTS_READ.value, + # Chats (read only) + Permission.CHATS_READ.value, + # LLM Configs (read only) + Permission.LLM_CONFIGS_READ.value, + # Podcasts (read only) + Permission.PODCASTS_READ.value, + # Connectors (read only) + Permission.CONNECTORS_READ.value, + # Logs (read only) + Permission.LOGS_READ.value, + # Members (view only) + Permission.MEMBERS_VIEW.value, + # Roles (read only) + Permission.ROLES_READ.value, + # Settings (view only) + Permission.SETTINGS_VIEW.value, + ], +} + + class Base(DeclarativeBase): pass @@ -230,6 +393,13 @@ class SearchSpace(BaseModel, TimestampMixin): qna_custom_instructions = Column( Text, nullable=True, default="" ) # User's custom instructions + + # Search space-level LLM preferences (shared by all members) + # Note: These can be negative IDs for global configs (from YAML) or positive IDs for custom configs (from DB) + long_context_llm_id = Column(Integer, nullable=True) + fast_llm_id = Column(Integer, nullable=True) + strategic_llm_id = Column(Integer, nullable=True) + user_id = Column( UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) @@ -277,6 +447,26 @@ class SearchSpace(BaseModel, TimestampMixin): cascade="all, delete-orphan", ) + # RBAC relationships + roles = relationship( + "SearchSpaceRole", + back_populates="search_space", + order_by="SearchSpaceRole.id", + cascade="all, delete-orphan", + ) + memberships = relationship( + "SearchSpaceMembership", + back_populates="search_space", + order_by="SearchSpaceMembership.id", + cascade="all, delete-orphan", + ) + invites = relationship( + "SearchSpaceInvite", + back_populates="search_space", + order_by="SearchSpaceInvite.id", + cascade="all, delete-orphan", + ) + class SearchSourceConnector(BaseModel, TimestampMixin): __tablename__ = "search_source_connectors" @@ -368,13 +558,6 @@ class UserSearchSpacePreference(BaseModel, TimestampMixin): user = relationship("User", back_populates="search_space_preferences") search_space = relationship("SearchSpace", back_populates="user_preferences") - # Note: Relationships removed because foreign keys no longer exist - # Global configs (negative IDs) don't exist in llm_configs table - # Application code manually fetches configs when needed - # long_context_llm = relationship("LLMConfig", foreign_keys=[long_context_llm_id], post_update=True) - # fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True) - # strategic_llm = relationship("LLMConfig", foreign_keys=[strategic_llm_id], post_update=True) - class Log(BaseModel, TimestampMixin): __tablename__ = "logs" @@ -393,6 +576,140 @@ class Log(BaseModel, TimestampMixin): search_space = relationship("SearchSpace", back_populates="logs") +class SearchSpaceRole(BaseModel, TimestampMixin): + """ + Custom roles that can be defined per search space. + Each search space can have multiple roles with different permission sets. + """ + + __tablename__ = "search_space_roles" + __table_args__ = ( + UniqueConstraint( + "search_space_id", + "name", + name="uq_searchspace_role_name", + ), + ) + + name = Column(String(100), nullable=False, index=True) + description = Column(String(500), nullable=True) + # List of Permission enum values (e.g., ["documents:read", "chats:create"]) + permissions = Column(ARRAY(String), nullable=False, default=[]) + # Whether this role is assigned to new members by default when they join via invite + is_default = Column(Boolean, nullable=False, default=False) + # System roles (Owner, Admin, Editor, Viewer) cannot be deleted + is_system_role = Column(Boolean, nullable=False, default=False) + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship("SearchSpace", back_populates="roles") + + memberships = relationship( + "SearchSpaceMembership", back_populates="role", passive_deletes=True + ) + invites = relationship( + "SearchSpaceInvite", back_populates="role", passive_deletes=True + ) + + +class SearchSpaceMembership(BaseModel, TimestampMixin): + """ + Tracks user membership in search spaces with their assigned role. + Each user can be a member of multiple search spaces with different roles. + """ + + __tablename__ = "search_space_memberships" + __table_args__ = ( + UniqueConstraint( + "user_id", + "search_space_id", + name="uq_user_searchspace_membership", + ), + ) + + user_id = Column( + UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False + ) + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + role_id = Column( + Integer, + ForeignKey("search_space_roles.id", ondelete="SET NULL"), + nullable=True, + ) + # Indicates if this user is the original creator/owner of the search space + is_owner = Column(Boolean, nullable=False, default=False) + # Timestamp when the user joined (via invite or as creator) + joined_at = Column( + TIMESTAMP(timezone=True), + nullable=False, + default=lambda: datetime.now(UTC), + ) + # Reference to the invite used to join (null if owner/creator) + invited_by_invite_id = Column( + Integer, + ForeignKey("search_space_invites.id", ondelete="SET NULL"), + nullable=True, + ) + + user = relationship("User", back_populates="search_space_memberships") + search_space = relationship("SearchSpace", back_populates="memberships") + role = relationship("SearchSpaceRole", back_populates="memberships") + invited_by_invite = relationship( + "SearchSpaceInvite", back_populates="used_by_memberships" + ) + + +class SearchSpaceInvite(BaseModel, TimestampMixin): + """ + Invite links for search spaces. + Users can create invite links with specific roles that others can use to join. + """ + + __tablename__ = "search_space_invites" + + # Unique invite code (used in invite URLs) + invite_code = Column(String(64), nullable=False, unique=True, index=True) + + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + # Role to assign when invite is used (null means use default role) + role_id = Column( + Integer, + ForeignKey("search_space_roles.id", ondelete="SET NULL"), + nullable=True, + ) + # User who created this invite + created_by_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + ) + + # Expiration timestamp (null means never expires) + expires_at = Column(TIMESTAMP(timezone=True), nullable=True) + # Maximum number of times this invite can be used (null means unlimited) + max_uses = Column(Integer, nullable=True) + # Number of times this invite has been used + uses_count = Column(Integer, nullable=False, default=0) + # Whether this invite is currently active + is_active = Column(Boolean, nullable=False, default=True) + # Optional custom name/label for the invite + name = Column(String(100), nullable=True) + + search_space = relationship("SearchSpace", back_populates="invites") + role = relationship("SearchSpaceRole", back_populates="invites") + created_by = relationship("User", back_populates="created_invites") + used_by_memberships = relationship( + "SearchSpaceMembership", + back_populates="invited_by_invite", + passive_deletes=True, + ) + + if config.AUTH_TYPE == "GOOGLE": class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): @@ -409,6 +726,18 @@ if config.AUTH_TYPE == "GOOGLE": cascade="all, delete-orphan", ) + # RBAC relationships + search_space_memberships = relationship( + "SearchSpaceMembership", + back_populates="user", + cascade="all, delete-orphan", + ) + created_invites = relationship( + "SearchSpaceInvite", + back_populates="created_by", + passive_deletes=True, + ) + # Page usage tracking for ETL services pages_limit = Column(Integer, nullable=False, default=500, server_default="500") pages_used = Column(Integer, nullable=False, default=0, server_default="0") @@ -423,6 +752,18 @@ else: cascade="all, delete-orphan", ) + # RBAC relationships + search_space_memberships = relationship( + "SearchSpaceMembership", + back_populates="user", + cascade="all, delete-orphan", + ) + created_invites = relationship( + "SearchSpaceInvite", + back_populates="created_by", + passive_deletes=True, + ) + # Page usage tracking for ETL services pages_limit = Column(Integer, nullable=False, default=500, server_default="500") pages_used = Column(Integer, nullable=False, default=0, server_default="0") @@ -492,3 +833,109 @@ async def get_documents_hybrid_search_retriever( session: AsyncSession = Depends(get_async_session), ): return DocumentHybridSearchRetriever(session) + + +def has_permission(user_permissions: list[str], required_permission: str) -> bool: + """ + Check if the user has the required permission. + Supports wildcard (*) for full access. + + Args: + user_permissions: List of permission strings the user has + required_permission: The permission string to check for + + Returns: + True if user has the permission, False otherwise + """ + if not user_permissions: + return False + + # Full access wildcard grants all permissions + if Permission.FULL_ACCESS.value in user_permissions: + return True + + return required_permission in user_permissions + + +def has_any_permission( + user_permissions: list[str], required_permissions: list[str] +) -> bool: + """ + Check if the user has any of the required permissions. + + Args: + user_permissions: List of permission strings the user has + required_permissions: List of permission strings to check for (any match) + + Returns: + True if user has at least one of the permissions, False otherwise + """ + if not user_permissions: + return False + + if Permission.FULL_ACCESS.value in user_permissions: + return True + + return any(perm in user_permissions for perm in required_permissions) + + +def has_all_permissions( + user_permissions: list[str], required_permissions: list[str] +) -> bool: + """ + Check if the user has all of the required permissions. + + Args: + user_permissions: List of permission strings the user has + required_permissions: List of permission strings to check for (all must match) + + Returns: + True if user has all of the permissions, False otherwise + """ + if not user_permissions: + return False + + if Permission.FULL_ACCESS.value in user_permissions: + return True + + return all(perm in user_permissions for perm in required_permissions) + + +def get_default_roles_config() -> list[dict]: + """ + Get the configuration for default system roles. + These roles are created automatically when a search space is created. + + Returns: + List of role configurations with name, description, permissions, and flags + """ + return [ + { + "name": "Owner", + "description": "Full access to all search space resources and settings", + "permissions": DEFAULT_ROLE_PERMISSIONS["Owner"], + "is_default": False, + "is_system_role": True, + }, + { + "name": "Admin", + "description": "Can manage most resources except deleting the search space", + "permissions": DEFAULT_ROLE_PERMISSIONS["Admin"], + "is_default": False, + "is_system_role": True, + }, + { + "name": "Editor", + "description": "Can create and edit documents, chats, and podcasts", + "permissions": DEFAULT_ROLE_PERMISSIONS["Editor"], + "is_default": True, # Default role for new members via invite + "is_system_role": True, + }, + { + "name": "Viewer", + "description": "Read-only access to search space resources", + "permissions": DEFAULT_ROLE_PERMISSIONS["Viewer"], + "is_default": False, + "is_system_role": True, + }, + ] diff --git a/surfsense_backend/app/retriver/chunks_hybrid_search.py b/surfsense_backend/app/retriver/chunks_hybrid_search.py index cb96ac695..25a121ad7 100644 --- a/surfsense_backend/app/retriver/chunks_hybrid_search.py +++ b/surfsense_backend/app/retriver/chunks_hybrid_search.py @@ -12,8 +12,7 @@ class ChucksHybridSearchRetriever: self, query_text: str, top_k: int, - user_id: str, - search_space_id: int | None = None, + search_space_id: int, ) -> list: """ Perform vector similarity search on chunks. @@ -21,8 +20,7 @@ class ChucksHybridSearchRetriever: Args: query_text: The search query text top_k: Number of results to return - user_id: The ID of the user performing the search - search_space_id: Optional search space ID to filter results + search_space_id: The search space ID to search within Returns: List of chunks sorted by vector similarity @@ -31,25 +29,20 @@ class ChucksHybridSearchRetriever: from sqlalchemy.orm import joinedload from app.config import config - from app.db import Chunk, Document, SearchSpace + from app.db import Chunk, Document # Get embedding for the query embedding_model = config.embedding_model_instance query_embedding = embedding_model.embed(query_text) - # Build the base query with user ownership check + # Build the query filtered by search space query = ( select(Chunk) .options(joinedload(Chunk.document).joinedload(Document.search_space)) .join(Document, Chunk.document_id == Document.id) - .join(SearchSpace, Document.search_space_id == SearchSpace.id) - .where(SearchSpace.user_id == user_id) + .where(Document.search_space_id == search_space_id) ) - # Add search space filter if provided - if search_space_id is not None: - query = query.where(Document.search_space_id == search_space_id) - # Add vector similarity ordering query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k) @@ -63,8 +56,7 @@ class ChucksHybridSearchRetriever: self, query_text: str, top_k: int, - user_id: str, - search_space_id: int | None = None, + search_space_id: int, ) -> list: """ Perform full-text keyword search on chunks. @@ -72,8 +64,7 @@ class ChucksHybridSearchRetriever: Args: query_text: The search query text top_k: Number of results to return - user_id: The ID of the user performing the search - search_space_id: Optional search space ID to filter results + search_space_id: The search space ID to search within Returns: List of chunks sorted by text relevance @@ -81,28 +72,23 @@ class ChucksHybridSearchRetriever: from sqlalchemy import func, select from sqlalchemy.orm import joinedload - from app.db import Chunk, Document, SearchSpace + from app.db import Chunk, Document # Create tsvector and tsquery for PostgreSQL full-text search tsvector = func.to_tsvector("english", Chunk.content) tsquery = func.plainto_tsquery("english", query_text) - # Build the base query with user ownership check + # Build the query filtered by search space query = ( select(Chunk) .options(joinedload(Chunk.document).joinedload(Document.search_space)) .join(Document, Chunk.document_id == Document.id) - .join(SearchSpace, Document.search_space_id == SearchSpace.id) - .where(SearchSpace.user_id == user_id) + .where(Document.search_space_id == search_space_id) .where( tsvector.op("@@")(tsquery) ) # Only include results that match the query ) - # Add search space filter if provided - if search_space_id is not None: - query = query.where(Document.search_space_id == search_space_id) - # Add text search ranking query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k) @@ -116,8 +102,7 @@ class ChucksHybridSearchRetriever: self, query_text: str, top_k: int, - user_id: str, - search_space_id: int | None = None, + search_space_id: int, document_type: str | None = None, ) -> list: """ @@ -126,8 +111,7 @@ class ChucksHybridSearchRetriever: Args: query_text: The search query text top_k: Number of results to return - user_id: The ID of the user performing the search - search_space_id: Optional search space ID to filter results + search_space_id: The search space ID to search within document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") Returns: @@ -137,7 +121,7 @@ class ChucksHybridSearchRetriever: from sqlalchemy.orm import joinedload from app.config import config - from app.db import Chunk, Document, DocumentType, SearchSpace + from app.db import Chunk, Document, DocumentType # Get embedding for the query embedding_model = config.embedding_model_instance @@ -151,12 +135,8 @@ class ChucksHybridSearchRetriever: tsvector = func.to_tsvector("english", Chunk.content) tsquery = func.plainto_tsquery("english", query_text) - # Base conditions for document filtering - base_conditions = [SearchSpace.user_id == user_id] - - # Add search space filter if provided - if search_space_id is not None: - base_conditions.append(Document.search_space_id == search_space_id) + # Base conditions for chunk filtering - search space is required + base_conditions = [Document.search_space_id == search_space_id] # Add document type filter if provided if document_type is not None: @@ -171,7 +151,7 @@ class ChucksHybridSearchRetriever: else: base_conditions.append(Document.document_type == document_type) - # CTE for semantic search with user ownership check + # CTE for semantic search filtered by search space semantic_search_cte = ( select( Chunk.id, @@ -180,7 +160,6 @@ class ChucksHybridSearchRetriever: .label("rank"), ) .join(Document, Chunk.document_id == Document.id) - .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(*base_conditions) ) @@ -190,7 +169,7 @@ class ChucksHybridSearchRetriever: .cte("semantic_search") ) - # CTE for keyword search with user ownership check + # CTE for keyword search filtered by search space keyword_search_cte = ( select( Chunk.id, @@ -199,7 +178,6 @@ class ChucksHybridSearchRetriever: .label("rank"), ) .join(Document, Chunk.document_id == Document.id) - .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(*base_conditions) .where(tsvector.op("@@")(tsquery)) ) diff --git a/surfsense_backend/app/retriver/documents_hybrid_search.py b/surfsense_backend/app/retriver/documents_hybrid_search.py index b4e826189..0c08ecc05 100644 --- a/surfsense_backend/app/retriver/documents_hybrid_search.py +++ b/surfsense_backend/app/retriver/documents_hybrid_search.py @@ -12,8 +12,7 @@ class DocumentHybridSearchRetriever: self, query_text: str, top_k: int, - user_id: str, - search_space_id: int | None = None, + search_space_id: int, ) -> list: """ Perform vector similarity search on documents. @@ -21,8 +20,7 @@ class DocumentHybridSearchRetriever: Args: query_text: The search query text top_k: Number of results to return - user_id: The ID of the user performing the search - search_space_id: Optional search space ID to filter results + search_space_id: The search space ID to search within Returns: List of documents sorted by vector similarity @@ -31,24 +29,19 @@ class DocumentHybridSearchRetriever: from sqlalchemy.orm import joinedload from app.config import config - from app.db import Document, SearchSpace + from app.db import Document # Get embedding for the query embedding_model = config.embedding_model_instance query_embedding = embedding_model.embed(query_text) - # Build the base query with user ownership check + # Build the query filtered by search space query = ( select(Document) .options(joinedload(Document.search_space)) - .join(SearchSpace, Document.search_space_id == SearchSpace.id) - .where(SearchSpace.user_id == user_id) + .where(Document.search_space_id == search_space_id) ) - # Add search space filter if provided - if search_space_id is not None: - query = query.where(Document.search_space_id == search_space_id) - # Add vector similarity ordering query = query.order_by(Document.embedding.op("<=>")(query_embedding)).limit( top_k @@ -64,8 +57,7 @@ class DocumentHybridSearchRetriever: self, query_text: str, top_k: int, - user_id: str, - search_space_id: int | None = None, + search_space_id: int, ) -> list: """ Perform full-text keyword search on documents. @@ -73,8 +65,7 @@ class DocumentHybridSearchRetriever: Args: query_text: The search query text top_k: Number of results to return - user_id: The ID of the user performing the search - search_space_id: Optional search space ID to filter results + search_space_id: The search space ID to search within Returns: List of documents sorted by text relevance @@ -82,27 +73,22 @@ class DocumentHybridSearchRetriever: from sqlalchemy import func, select from sqlalchemy.orm import joinedload - from app.db import Document, SearchSpace + from app.db import Document # Create tsvector and tsquery for PostgreSQL full-text search tsvector = func.to_tsvector("english", Document.content) tsquery = func.plainto_tsquery("english", query_text) - # Build the base query with user ownership check + # Build the query filtered by search space query = ( select(Document) .options(joinedload(Document.search_space)) - .join(SearchSpace, Document.search_space_id == SearchSpace.id) - .where(SearchSpace.user_id == user_id) + .where(Document.search_space_id == search_space_id) .where( tsvector.op("@@")(tsquery) ) # Only include results that match the query ) - # Add search space filter if provided - if search_space_id is not None: - query = query.where(Document.search_space_id == search_space_id) - # Add text search ranking query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k) @@ -116,8 +102,7 @@ class DocumentHybridSearchRetriever: self, query_text: str, top_k: int, - user_id: str, - search_space_id: int | None = None, + search_space_id: int, document_type: str | None = None, ) -> list: """ @@ -126,8 +111,7 @@ class DocumentHybridSearchRetriever: Args: query_text: The search query text top_k: Number of results to return - user_id: The ID of the user performing the search - search_space_id: Optional search space ID to filter results + search_space_id: The search space ID to search within document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL") """ @@ -135,7 +119,7 @@ class DocumentHybridSearchRetriever: from sqlalchemy.orm import joinedload from app.config import config - from app.db import Document, DocumentType, SearchSpace + from app.db import Document, DocumentType # Get embedding for the query embedding_model = config.embedding_model_instance @@ -149,12 +133,8 @@ class DocumentHybridSearchRetriever: tsvector = func.to_tsvector("english", Document.content) tsquery = func.plainto_tsquery("english", query_text) - # Base conditions for document filtering - base_conditions = [SearchSpace.user_id == user_id] - - # Add search space filter if provided - if search_space_id is not None: - base_conditions.append(Document.search_space_id == search_space_id) + # Base conditions for document filtering - search space is required + base_conditions = [Document.search_space_id == search_space_id] # Add document type filter if provided if document_type is not None: @@ -169,17 +149,13 @@ class DocumentHybridSearchRetriever: else: base_conditions.append(Document.document_type == document_type) - # CTE for semantic search with user ownership check - semantic_search_cte = ( - select( - Document.id, - func.rank() - .over(order_by=Document.embedding.op("<=>")(query_embedding)) - .label("rank"), - ) - .join(SearchSpace, Document.search_space_id == SearchSpace.id) - .where(*base_conditions) - ) + # CTE for semantic search filtered by search space + semantic_search_cte = select( + Document.id, + func.rank() + .over(order_by=Document.embedding.op("<=>")(query_embedding)) + .label("rank"), + ).where(*base_conditions) semantic_search_cte = ( semantic_search_cte.order_by(Document.embedding.op("<=>")(query_embedding)) @@ -187,7 +163,7 @@ class DocumentHybridSearchRetriever: .cte("semantic_search") ) - # CTE for keyword search with user ownership check + # CTE for keyword search filtered by search space keyword_search_cte = ( select( Document.id, @@ -195,7 +171,6 @@ class DocumentHybridSearchRetriever: .over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()) .label("rank"), ) - .join(SearchSpace, Document.search_space_id == SearchSpace.id) .where(*base_conditions) .where(tsvector.op("@@")(tsquery)) ) diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 1c7e3505f..127a8d927 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -15,12 +15,14 @@ from .llm_config_routes import router as llm_config_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from .podcasts_routes import router as podcasts_router +from .rbac_routes import router as rbac_router from .search_source_connectors_routes import router as search_source_connectors_router from .search_spaces_routes import router as search_spaces_router router = APIRouter() router.include_router(search_spaces_router) +router.include_router(rbac_router) # RBAC routes for roles, members, invites router.include_router(documents_router) router.include_router(podcasts_router) router.include_router(chats_router) diff --git a/surfsense_backend/app/routes/chats_routes.py b/surfsense_backend/app/routes/chats_routes.py index 05360cee0..d7aff102b 100644 --- a/surfsense_backend/app/routes/chats_routes.py +++ b/surfsense_backend/app/routes/chats_routes.py @@ -6,7 +6,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload -from app.db import Chat, SearchSpace, User, UserSearchSpacePreference, get_async_session +from app.db import ( + Chat, + Permission, + SearchSpace, + SearchSpaceMembership, + User, + get_async_session, +) from app.schemas import ( AISDKChatRequest, ChatCreate, @@ -16,7 +23,7 @@ from app.schemas import ( ) from app.tasks.stream_connector_search_results import stream_connector_search_results from app.users import current_active_user -from app.utils.check_ownership import check_ownership +from app.utils.rbac import check_permission from app.utils.validators import ( validate_connectors, validate_document_ids, @@ -59,45 +66,38 @@ async def handle_chat_data( # print("RESQUEST DATA:", request_data) # print("SELECTED CONNECTORS:", selected_connectors) - # Check if the search space belongs to the current user + # Check if the user has chat access to the search space try: - await check_ownership(session, SearchSpace, search_space_id, user) - language_result = await session.execute( - select(UserSearchSpacePreference) - .options( - selectinload(UserSearchSpacePreference.search_space).selectinload( - SearchSpace.llm_configs - ), - # Note: Removed selectinload for LLM relationships as they no longer exist - # Global configs (negative IDs) don't have foreign keys - # LLM configs are now fetched manually when needed - ) - .filter( - UserSearchSpacePreference.search_space_id == search_space_id, - UserSearchSpacePreference.user_id == user.id, - ) + await check_permission( + session, + user, + search_space_id, + Permission.CHATS_CREATE.value, + "You don't have permission to use chat in this search space", ) - user_preference = language_result.scalars().first() - # print("UserSearchSpacePreference:", user_preference) + + # Get search space with LLM configs (preferences are now stored at search space level) + search_space_result = await session.execute( + select(SearchSpace) + .options(selectinload(SearchSpace.llm_configs)) + .filter(SearchSpace.id == search_space_id) + ) + search_space = search_space_result.scalars().first() language = None llm_configs = [] # Initialize to empty list - if ( - user_preference - and user_preference.search_space - and user_preference.search_space.llm_configs - ): - llm_configs = user_preference.search_space.llm_configs + if search_space and search_space.llm_configs: + llm_configs = search_space.llm_configs - # Manually fetch LLM configs since relationships no longer exist - # Check fast_llm, long_context_llm, and strategic_llm IDs + # Get language from configured LLM preferences + # LLM preferences are now stored on the SearchSpace model from app.config import config as app_config for llm_id in [ - user_preference.fast_llm_id, - user_preference.long_context_llm_id, - user_preference.strategic_llm_id, + search_space.fast_llm_id, + search_space.long_context_llm_id, + search_space.strategic_llm_id, ]: if llm_id is not None: # Check if it's a global config (negative ID) @@ -161,8 +161,18 @@ async def create_chat( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Create a new chat. + Requires CHATS_CREATE permission. + """ try: - await check_ownership(session, SearchSpace, chat.search_space_id, user) + await check_permission( + session, + user, + chat.search_space_id, + Permission.CHATS_CREATE.value, + "You don't have permission to create chats in this search space", + ) db_chat = Chat(**chat.model_dump()) session.add(db_chat) await session.commit() @@ -197,6 +207,10 @@ async def read_chats( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + List chats the user has access to. + Requires CHATS_READ permission for the search space(s). + """ # Validate pagination parameters if skip < 0: raise HTTPException( @@ -212,9 +226,17 @@ async def read_chats( status_code=400, detail="search_space_id must be a positive integer" ) try: - # Select specific fields excluding messages - query = ( - select( + if search_space_id is not None: + # Check permission for specific search space + await check_permission( + session, + user, + search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to read chats in this search space", + ) + # Select specific fields excluding messages + query = select( Chat.id, Chat.type, Chat.title, @@ -222,17 +244,28 @@ async def read_chats( Chat.search_space_id, Chat.created_at, Chat.state_version, + ).filter(Chat.search_space_id == search_space_id) + else: + # Get chats from all search spaces user has membership in + query = ( + select( + Chat.id, + Chat.type, + Chat.title, + Chat.initial_connectors, + Chat.search_space_id, + Chat.created_at, + Chat.state_version, + ) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) ) - .join(SearchSpace) - .filter(SearchSpace.user_id == user.id) - ) - - # Filter by search_space_id if provided - if search_space_id is not None: - query = query.filter(Chat.search_space_id == search_space_id) result = await session.execute(query.offset(skip).limit(limit)) return result.all() + except HTTPException: + raise except OperationalError: raise HTTPException( status_code=503, detail="Database operation failed. Please try again later." @@ -249,19 +282,32 @@ async def read_chat( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Get a specific chat by ID. + Requires CHATS_READ permission for the search space. + """ try: - result = await session.execute( - select(Chat) - .join(SearchSpace) - .filter(Chat.id == chat_id, SearchSpace.user_id == user.id) - ) + result = await session.execute(select(Chat).filter(Chat.id == chat_id)) chat = result.scalars().first() + if not chat: raise HTTPException( status_code=404, - detail="Chat not found or you don't have permission to access it", + detail="Chat not found", ) + + # Check permission for the search space + await check_permission( + session, + user, + chat.search_space_id, + Permission.CHATS_READ.value, + "You don't have permission to read chats in this search space", + ) + return chat + except HTTPException: + raise except OperationalError: raise HTTPException( status_code=503, detail="Database operation failed. Please try again later." @@ -280,8 +326,26 @@ async def update_chat( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Update a chat. + Requires CHATS_UPDATE permission for the search space. + """ try: - db_chat = await read_chat(chat_id, session, user) + result = await session.execute(select(Chat).filter(Chat.id == chat_id)) + db_chat = result.scalars().first() + + if not db_chat: + raise HTTPException(status_code=404, detail="Chat not found") + + # Check permission for the search space + await check_permission( + session, + user, + db_chat.search_space_id, + Permission.CHATS_UPDATE.value, + "You don't have permission to update chats in this search space", + ) + update_data = chat_update.model_dump(exclude_unset=True) for key, value in update_data.items(): if key == "messages": @@ -318,8 +382,26 @@ async def delete_chat( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Delete a chat. + Requires CHATS_DELETE permission for the search space. + """ try: - db_chat = await read_chat(chat_id, session, user) + result = await session.execute(select(Chat).filter(Chat.id == chat_id)) + db_chat = result.scalars().first() + + if not db_chat: + raise HTTPException(status_code=404, detail="Chat not found") + + # Check permission for the search space + await check_permission( + session, + user, + db_chat.search_space_id, + Permission.CHATS_DELETE.value, + "You don't have permission to delete chats in this search space", + ) + await session.delete(db_chat) await session.commit() return {"message": "Chat deleted successfully"} diff --git a/surfsense_backend/app/routes/documents_routes.py b/surfsense_backend/app/routes/documents_routes.py index ae9df0cf4..67015243f 100644 --- a/surfsense_backend/app/routes/documents_routes.py +++ b/surfsense_backend/app/routes/documents_routes.py @@ -10,7 +10,9 @@ from app.db import ( Chunk, Document, DocumentType, + Permission, SearchSpace, + SearchSpaceMembership, User, get_async_session, ) @@ -22,7 +24,7 @@ from app.schemas import ( PaginatedResponse, ) from app.users import current_active_user -from app.utils.check_ownership import check_ownership +from app.utils.rbac import check_permission try: asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) @@ -44,9 +46,19 @@ async def create_documents( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Create new documents. + Requires DOCUMENTS_CREATE permission. + """ try: - # Check if the user owns the search space - await check_ownership(session, SearchSpace, request.search_space_id, user) + # Check permission + await check_permission( + session, + user, + request.search_space_id, + Permission.DOCUMENTS_CREATE.value, + "You don't have permission to create documents in this search space", + ) if request.document_type == DocumentType.EXTENSION: from app.tasks.celery_tasks.document_tasks import ( @@ -93,8 +105,19 @@ async def create_documents_file_upload( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Upload files as documents. + Requires DOCUMENTS_CREATE permission. + """ try: - await check_ownership(session, SearchSpace, search_space_id, user) + # Check permission + await check_permission( + session, + user, + search_space_id, + Permission.DOCUMENTS_CREATE.value, + "You don't have permission to create documents in this search space", + ) if not files: raise HTTPException(status_code=400, detail="No files provided") @@ -151,7 +174,8 @@ async def read_documents( user: User = Depends(current_active_user), ): """ - List documents owned by the current user, with optional filtering and pagination. + List documents the user has access to, with optional filtering and pagination. + Requires DOCUMENTS_READ permission for the search space(s). Args: skip: Absolute number of items to skip from the beginning. If provided, it takes precedence over 'page'. @@ -167,40 +191,49 @@ async def read_documents( Notes: - If both 'skip' and 'page' are provided, 'skip' is used. - - Results are scoped to documents owned by the current user. + - Results are scoped to documents in search spaces the user has membership in. """ try: from sqlalchemy import func - query = ( - select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id) - ) - - # Filter by search_space_id if provided + # If specific search_space_id, check permission if search_space_id is not None: - query = query.filter(Document.search_space_id == search_space_id) + await check_permission( + session, + user, + search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read documents in this search space", + ) + query = select(Document).filter(Document.search_space_id == search_space_id) + count_query = ( + select(func.count()) + .select_from(Document) + .filter(Document.search_space_id == search_space_id) + ) + else: + # Get documents from all search spaces user has membership in + query = ( + select(Document) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + ) + count_query = ( + select(func.count()) + .select_from(Document) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + ) # Filter by document_types if provided if document_types is not None and document_types.strip(): type_list = [t.strip() for t in document_types.split(",") if t.strip()] if type_list: query = query.filter(Document.document_type.in_(type_list)) - - # Get total count - count_query = ( - select(func.count()) - .select_from(Document) - .join(SearchSpace) - .filter(SearchSpace.user_id == user.id) - ) - if search_space_id is not None: - count_query = count_query.filter( - Document.search_space_id == search_space_id - ) - if document_types is not None and document_types.strip(): - type_list = [t.strip() for t in document_types.split(",") if t.strip()] - if type_list: count_query = count_query.filter(Document.document_type.in_(type_list)) + total_result = await session.execute(count_query) total = total_result.scalar() or 0 @@ -235,6 +268,8 @@ async def read_documents( ) return PaginatedResponse(items=api_documents, total=total) + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to fetch documents: {e!s}" @@ -254,6 +289,7 @@ async def search_documents( ): """ Search documents by title substring, optionally filtered by search_space_id and document_types. + Requires DOCUMENTS_READ permission for the search space(s). Args: title: Case-insensitive substring to match against document titles. Required. @@ -275,37 +311,48 @@ async def search_documents( try: from sqlalchemy import func - query = ( - select(Document).join(SearchSpace).filter(SearchSpace.user_id == user.id) - ) + # If specific search_space_id, check permission if search_space_id is not None: - query = query.filter(Document.search_space_id == search_space_id) + await check_permission( + session, + user, + search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read documents in this search space", + ) + query = select(Document).filter(Document.search_space_id == search_space_id) + count_query = ( + select(func.count()) + .select_from(Document) + .filter(Document.search_space_id == search_space_id) + ) + else: + # Get documents from all search spaces user has membership in + query = ( + select(Document) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + ) + count_query = ( + select(func.count()) + .select_from(Document) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + ) # Only search by title (case-insensitive) query = query.filter(Document.title.ilike(f"%{title}%")) + count_query = count_query.filter(Document.title.ilike(f"%{title}%")) # Filter by document_types if provided if document_types is not None and document_types.strip(): type_list = [t.strip() for t in document_types.split(",") if t.strip()] if type_list: query = query.filter(Document.document_type.in_(type_list)) - - # Get total count - count_query = ( - select(func.count()) - .select_from(Document) - .join(SearchSpace) - .filter(SearchSpace.user_id == user.id) - ) - if search_space_id is not None: - count_query = count_query.filter( - Document.search_space_id == search_space_id - ) - count_query = count_query.filter(Document.title.ilike(f"%{title}%")) - if document_types is not None and document_types.strip(): - type_list = [t.strip() for t in document_types.split(",") if t.strip()] - if type_list: count_query = count_query.filter(Document.document_type.in_(type_list)) + total_result = await session.execute(count_query) total = total_result.scalar() or 0 @@ -340,6 +387,8 @@ async def search_documents( ) return PaginatedResponse(items=api_documents, total=total) + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to search documents: {e!s}" @@ -353,7 +402,8 @@ async def get_document_type_counts( user: User = Depends(current_active_user), ): """ - Get counts of documents by type for the current user. + Get counts of documents by type for search spaces the user has access to. + Requires DOCUMENTS_READ permission for the search space(s). Args: search_space_id: If provided, restrict counts to a specific search space. @@ -366,20 +416,36 @@ async def get_document_type_counts( try: from sqlalchemy import func - query = ( - select(Document.document_type, func.count(Document.id)) - .join(SearchSpace) - .filter(SearchSpace.user_id == user.id) - .group_by(Document.document_type) - ) - if search_space_id is not None: - query = query.filter(Document.search_space_id == search_space_id) + # Check permission for specific search space + await check_permission( + session, + user, + search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read documents in this search space", + ) + query = ( + select(Document.document_type, func.count(Document.id)) + .filter(Document.search_space_id == search_space_id) + .group_by(Document.document_type) + ) + else: + # Get counts from all search spaces user has membership in + query = ( + select(Document.document_type, func.count(Document.id)) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + .group_by(Document.document_type) + ) result = await session.execute(query) type_counts = dict(result.all()) return type_counts + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to fetch document type counts: {e!s}" @@ -394,6 +460,7 @@ async def get_document_by_chunk_id( ): """ Retrieves a document based on a chunk ID, including all its chunks ordered by creation time. + Requires DOCUMENTS_READ permission for the search space. The document's embedding and chunk embeddings are excluded from the response. """ try: @@ -406,21 +473,29 @@ async def get_document_by_chunk_id( status_code=404, detail=f"Chunk with id {chunk_id} not found" ) - # Get the associated document and verify ownership + # Get the associated document document_result = await session.execute( select(Document) .options(selectinload(Document.chunks)) - .join(SearchSpace) - .filter(Document.id == chunk.document_id, SearchSpace.user_id == user.id) + .filter(Document.id == chunk.document_id) ) document = document_result.scalars().first() if not document: raise HTTPException( status_code=404, - detail="Document not found or you don't have access to it", + detail="Document not found", ) + # Check permission for the search space + await check_permission( + session, + user, + document.search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read documents in this search space", + ) + # Sort chunks by creation time sorted_chunks = sorted(document.chunks, key=lambda x: x.created_at) @@ -449,11 +524,13 @@ async def read_document( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Get a specific document by ID. + Requires DOCUMENTS_READ permission for the search space. + """ try: result = await session.execute( - select(Document) - .join(SearchSpace) - .filter(Document.id == document_id, SearchSpace.user_id == user.id) + select(Document).filter(Document.id == document_id) ) document = result.scalars().first() @@ -462,6 +539,15 @@ async def read_document( status_code=404, detail=f"Document with id {document_id} not found" ) + # Check permission for the search space + await check_permission( + session, + user, + document.search_space_id, + Permission.DOCUMENTS_READ.value, + "You don't have permission to read documents in this search space", + ) + # Convert database object to API-friendly format return DocumentRead( id=document.id, @@ -472,6 +558,8 @@ async def read_document( created_at=document.created_at, search_space_id=document.search_space_id, ) + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to fetch document: {e!s}" @@ -485,12 +573,13 @@ async def update_document( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Update a document. + Requires DOCUMENTS_UPDATE permission for the search space. + """ try: - # Query the document directly instead of using read_document function result = await session.execute( - select(Document) - .join(SearchSpace) - .filter(Document.id == document_id, SearchSpace.user_id == user.id) + select(Document).filter(Document.id == document_id) ) db_document = result.scalars().first() @@ -499,6 +588,15 @@ async def update_document( status_code=404, detail=f"Document with id {document_id} not found" ) + # Check permission for the search space + await check_permission( + session, + user, + db_document.search_space_id, + Permission.DOCUMENTS_UPDATE.value, + "You don't have permission to update documents in this search space", + ) + update_data = document_update.model_dump(exclude_unset=True) for key, value in update_data.items(): setattr(db_document, key, value) @@ -530,12 +628,13 @@ async def delete_document( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Delete a document. + Requires DOCUMENTS_DELETE permission for the search space. + """ try: - # Query the document directly instead of using read_document function result = await session.execute( - select(Document) - .join(SearchSpace) - .filter(Document.id == document_id, SearchSpace.user_id == user.id) + select(Document).filter(Document.id == document_id) ) document = result.scalars().first() @@ -544,6 +643,15 @@ async def delete_document( status_code=404, detail=f"Document with id {document_id} not found" ) + # Check permission for the search space + await check_permission( + session, + user, + document.search_space_id, + Permission.DOCUMENTS_DELETE.value, + "You don't have permission to delete documents in this search space", + ) + await session.delete(document) await session.commit() return {"message": "Document deleted successfully"} diff --git a/surfsense_backend/app/routes/llm_config_routes.py b/surfsense_backend/app/routes/llm_config_routes.py index 35c3ce574..31c7200f5 100644 --- a/surfsense_backend/app/routes/llm_config_routes.py +++ b/surfsense_backend/app/routes/llm_config_routes.py @@ -8,67 +8,22 @@ from sqlalchemy.future import select from app.config import config from app.db import ( LLMConfig, + Permission, SearchSpace, User, - UserSearchSpacePreference, get_async_session, ) from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate from app.services.llm_service import validate_llm_config from app.users import current_active_user +from app.utils.rbac import check_permission router = APIRouter() logger = logging.getLogger(__name__) -# Helper function to check search space access -async def check_search_space_access( - session: AsyncSession, search_space_id: int, user: User -) -> SearchSpace: - """Verify that the user has access to the search space""" - result = await session.execute( - select(SearchSpace).filter( - SearchSpace.id == search_space_id, SearchSpace.user_id == user.id - ) - ) - search_space = result.scalars().first() - if not search_space: - raise HTTPException( - status_code=404, - detail="Search space not found or you don't have permission to access it", - ) - return search_space - - -# Helper function to get or create user search space preference -async def get_or_create_user_preference( - session: AsyncSession, user_id, search_space_id: int -) -> UserSearchSpacePreference: - """Get or create user preference for a search space""" - result = await session.execute( - select(UserSearchSpacePreference).filter( - UserSearchSpacePreference.user_id == user_id, - UserSearchSpacePreference.search_space_id == search_space_id, - ) - # Removed selectinload options since relationships no longer exist - ) - preference = result.scalars().first() - - if not preference: - # Create new preference entry - preference = UserSearchSpacePreference( - user_id=user_id, - search_space_id=search_space_id, - ) - session.add(preference) - await session.commit() - await session.refresh(preference) - - return preference - - class LLMPreferencesUpdate(BaseModel): - """Schema for updating user LLM preferences""" + """Schema for updating search space LLM preferences""" long_context_llm_id: int | None = None fast_llm_id: int | None = None @@ -76,7 +31,7 @@ class LLMPreferencesUpdate(BaseModel): class LLMPreferencesRead(BaseModel): - """Schema for reading user LLM preferences""" + """Schema for reading search space LLM preferences""" long_context_llm_id: int | None = None fast_llm_id: int | None = None @@ -144,10 +99,19 @@ async def create_llm_config( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Create a new LLM configuration for a search space""" + """ + Create a new LLM configuration for a search space. + Requires LLM_CONFIGS_CREATE permission. + """ try: - # Verify user has access to the search space - await check_search_space_access(session, llm_config.search_space_id, user) + # Verify user has permission to create LLM configs + await check_permission( + session, + user, + llm_config.search_space_id, + Permission.LLM_CONFIGS_CREATE.value, + "You don't have permission to create LLM configurations in this search space", + ) # Validate the LLM configuration by making a test API call is_valid, error_message = await validate_llm_config( @@ -187,10 +151,19 @@ async def read_llm_configs( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get all LLM configurations for a search space""" + """ + Get all LLM configurations for a search space. + Requires LLM_CONFIGS_READ permission. + """ try: - # Verify user has access to the search space - await check_search_space_access(session, search_space_id, user) + # Verify user has permission to read LLM configs + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM configurations in this search space", + ) result = await session.execute( select(LLMConfig) @@ -213,7 +186,10 @@ async def read_llm_config( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get a specific LLM configuration by ID""" + """ + Get a specific LLM configuration by ID. + Requires LLM_CONFIGS_READ permission. + """ try: # Get the LLM config result = await session.execute( @@ -224,8 +200,14 @@ async def read_llm_config( if not llm_config: raise HTTPException(status_code=404, detail="LLM configuration not found") - # Verify user has access to the search space - await check_search_space_access(session, llm_config.search_space_id, user) + # Verify user has permission to read LLM configs + await check_permission( + session, + user, + llm_config.search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM configurations in this search space", + ) return llm_config except HTTPException: @@ -243,7 +225,10 @@ async def update_llm_config( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Update an existing LLM configuration""" + """ + Update an existing LLM configuration. + Requires LLM_CONFIGS_UPDATE permission. + """ try: # Get the LLM config result = await session.execute( @@ -254,8 +239,14 @@ async def update_llm_config( if not db_llm_config: raise HTTPException(status_code=404, detail="LLM configuration not found") - # Verify user has access to the search space - await check_search_space_access(session, db_llm_config.search_space_id, user) + # Verify user has permission to update LLM configs + await check_permission( + session, + user, + db_llm_config.search_space_id, + Permission.LLM_CONFIGS_UPDATE.value, + "You don't have permission to update LLM configurations in this search space", + ) update_data = llm_config_update.model_dump(exclude_unset=True) @@ -311,7 +302,10 @@ async def delete_llm_config( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Delete an LLM configuration""" + """ + Delete an LLM configuration. + Requires LLM_CONFIGS_DELETE permission. + """ try: # Get the LLM config result = await session.execute( @@ -322,8 +316,14 @@ async def delete_llm_config( if not db_llm_config: raise HTTPException(status_code=404, detail="LLM configuration not found") - # Verify user has access to the search space - await check_search_space_access(session, db_llm_config.search_space_id, user) + # Verify user has permission to delete LLM configs + await check_permission( + session, + user, + db_llm_config.search_space_id, + Permission.LLM_CONFIGS_DELETE.value, + "You don't have permission to delete LLM configurations in this search space", + ) await session.delete(db_llm_config) await session.commit() @@ -337,28 +337,42 @@ async def delete_llm_config( ) from e -# User LLM Preferences endpoints +# Search Space LLM Preferences endpoints @router.get( "/search-spaces/{search_space_id}/llm-preferences", response_model=LLMPreferencesRead, ) -async def get_user_llm_preferences( +async def get_llm_preferences( search_space_id: int, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get the current user's LLM preferences for a specific search space""" + """ + Get the LLM preferences for a specific search space. + LLM preferences are shared by all members of the search space. + Requires LLM_CONFIGS_READ permission. + """ try: - # Verify user has access to the search space - await check_search_space_access(session, search_space_id, user) - - # Get or create user preference for this search space - preference = await get_or_create_user_preference( - session, user.id, search_space_id + # Verify user has permission to read LLM configs + await check_permission( + session, + user, + search_space_id, + Permission.LLM_CONFIGS_READ.value, + "You don't have permission to view LLM preferences in this search space", ) + # Get the search space + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + # Helper function to get config (global or custom) async def get_config_for_id(config_id): if config_id is None: @@ -391,14 +405,14 @@ async def get_user_llm_preferences( return result.scalars().first() # Get the configs (from DB for custom, or constructed for global) - long_context_llm = await get_config_for_id(preference.long_context_llm_id) - fast_llm = await get_config_for_id(preference.fast_llm_id) - strategic_llm = await get_config_for_id(preference.strategic_llm_id) + long_context_llm = await get_config_for_id(search_space.long_context_llm_id) + fast_llm = await get_config_for_id(search_space.fast_llm_id) + strategic_llm = await get_config_for_id(search_space.strategic_llm_id) return { - "long_context_llm_id": preference.long_context_llm_id, - "fast_llm_id": preference.fast_llm_id, - "strategic_llm_id": preference.strategic_llm_id, + "long_context_llm_id": search_space.long_context_llm_id, + "fast_llm_id": search_space.fast_llm_id, + "strategic_llm_id": search_space.strategic_llm_id, "long_context_llm": long_context_llm, "fast_llm": fast_llm, "strategic_llm": strategic_llm, @@ -415,22 +429,37 @@ async def get_user_llm_preferences( "/search-spaces/{search_space_id}/llm-preferences", response_model=LLMPreferencesRead, ) -async def update_user_llm_preferences( +async def update_llm_preferences( search_space_id: int, preferences: LLMPreferencesUpdate, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Update the current user's LLM preferences for a specific search space""" + """ + Update the LLM preferences for a specific search space. + LLM preferences are shared by all members of the search space. + Requires SETTINGS_UPDATE permission (only users with settings access can change). + """ try: - # Verify user has access to the search space - await check_search_space_access(session, search_space_id, user) - - # Get or create user preference for this search space - preference = await get_or_create_user_preference( - session, user.id, search_space_id + # Verify user has permission to update settings (not just LLM configs) + # This ensures only users with settings access can change shared LLM preferences + await check_permission( + session, + user, + search_space_id, + Permission.SETTINGS_UPDATE.value, + "You don't have permission to update LLM preferences in this search space", ) + # Get the search space + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + # Validate that all provided LLM config IDs belong to the search space update_data = preferences.model_dump(exclude_unset=True) @@ -485,18 +514,13 @@ async def update_user_llm_preferences( f"Multiple languages detected in LLM selection for search_space {search_space_id}: {languages}. " "This may affect response quality." ) - # Don't raise an exception - allow users to proceed - # raise HTTPException( - # status_code=400, - # detail="All selected LLM configurations must have the same language setting", - # ) - # Update user preferences + # Update search space LLM preferences for key, value in update_data.items(): - setattr(preference, key, value) + setattr(search_space, key, value) await session.commit() - await session.refresh(preference) + await session.refresh(search_space) # Helper function to get config (global or custom) async def get_config_for_id(config_id): @@ -530,15 +554,15 @@ async def update_user_llm_preferences( return result.scalars().first() # Get the configs (from DB for custom, or constructed for global) - long_context_llm = await get_config_for_id(preference.long_context_llm_id) - fast_llm = await get_config_for_id(preference.fast_llm_id) - strategic_llm = await get_config_for_id(preference.strategic_llm_id) + long_context_llm = await get_config_for_id(search_space.long_context_llm_id) + fast_llm = await get_config_for_id(search_space.fast_llm_id) + strategic_llm = await get_config_for_id(search_space.strategic_llm_id) # Return updated preferences return { - "long_context_llm_id": preference.long_context_llm_id, - "fast_llm_id": preference.fast_llm_id, - "strategic_llm_id": preference.strategic_llm_id, + "long_context_llm_id": search_space.long_context_llm_id, + "fast_llm_id": search_space.fast_llm_id, + "strategic_llm_id": search_space.strategic_llm_id, "long_context_llm": long_context_llm, "fast_llm": fast_llm, "strategic_llm": strategic_llm, diff --git a/surfsense_backend/app/routes/logs_routes.py b/surfsense_backend/app/routes/logs_routes.py index d9dd997ce..98fd9141e 100644 --- a/surfsense_backend/app/routes/logs_routes.py +++ b/surfsense_backend/app/routes/logs_routes.py @@ -5,10 +5,19 @@ from sqlalchemy import and_, desc from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import Log, LogLevel, LogStatus, SearchSpace, User, get_async_session +from app.db import ( + Log, + LogLevel, + LogStatus, + Permission, + SearchSpace, + SearchSpaceMembership, + User, + get_async_session, +) from app.schemas import LogCreate, LogRead, LogUpdate from app.users import current_active_user -from app.utils.check_ownership import check_ownership +from app.utils.rbac import check_permission router = APIRouter() @@ -19,10 +28,19 @@ async def create_log( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Create a new log entry.""" + """ + Create a new log entry. + Note: This is typically called internally. Requires LOGS_READ permission (since logs are usually system-generated). + """ try: - # Check if the user owns the search space - await check_ownership(session, SearchSpace, log.search_space_id, user) + # Check if the user has access to the search space + await check_permission( + session, + user, + log.search_space_id, + Permission.LOGS_READ.value, + "You don't have permission to access logs in this search space", + ) db_log = Log(**log.model_dump()) session.add(db_log) @@ -51,22 +69,38 @@ async def read_logs( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get logs with optional filtering.""" + """ + Get logs with optional filtering. + Requires LOGS_READ permission for the search space(s). + """ try: - # Build base query - only logs from user's search spaces - query = ( - select(Log) - .join(SearchSpace) - .filter(SearchSpace.user_id == user.id) - .order_by(desc(Log.created_at)) # Most recent first - ) - # Apply filters filters = [] if search_space_id is not None: - await check_ownership(session, SearchSpace, search_space_id, user) - filters.append(Log.search_space_id == search_space_id) + # Check permission for specific search space + await check_permission( + session, + user, + search_space_id, + Permission.LOGS_READ.value, + "You don't have permission to read logs in this search space", + ) + # Build query for specific search space + query = ( + select(Log) + .filter(Log.search_space_id == search_space_id) + .order_by(desc(Log.created_at)) + ) + else: + # Build base query - logs from search spaces user has membership in + query = ( + select(Log) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + .order_by(desc(Log.created_at)) + ) if level is not None: filters.append(Log.level == level) @@ -104,19 +138,26 @@ async def read_log( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get a specific log by ID.""" + """ + Get a specific log by ID. + Requires LOGS_READ permission for the search space. + """ try: - # Get log and verify user owns the search space - result = await session.execute( - select(Log) - .join(SearchSpace) - .filter(Log.id == log_id, SearchSpace.user_id == user.id) - ) + result = await session.execute(select(Log).filter(Log.id == log_id)) log = result.scalars().first() if not log: raise HTTPException(status_code=404, detail="Log not found") + # Check permission for the search space + await check_permission( + session, + user, + log.search_space_id, + Permission.LOGS_READ.value, + "You don't have permission to read logs in this search space", + ) + return log except HTTPException: raise @@ -133,19 +174,26 @@ async def update_log( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Update a log entry.""" + """ + Update a log entry. + Requires LOGS_READ permission (logs are typically updated by system). + """ try: - # Get log and verify user owns the search space - result = await session.execute( - select(Log) - .join(SearchSpace) - .filter(Log.id == log_id, SearchSpace.user_id == user.id) - ) + result = await session.execute(select(Log).filter(Log.id == log_id)) db_log = result.scalars().first() if not db_log: raise HTTPException(status_code=404, detail="Log not found") + # Check permission for the search space + await check_permission( + session, + user, + db_log.search_space_id, + Permission.LOGS_READ.value, + "You don't have permission to access logs in this search space", + ) + # Update only provided fields update_data = log_update.model_dump(exclude_unset=True) for field, value in update_data.items(): @@ -169,19 +217,26 @@ async def delete_log( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Delete a log entry.""" + """ + Delete a log entry. + Requires LOGS_DELETE permission for the search space. + """ try: - # Get log and verify user owns the search space - result = await session.execute( - select(Log) - .join(SearchSpace) - .filter(Log.id == log_id, SearchSpace.user_id == user.id) - ) + result = await session.execute(select(Log).filter(Log.id == log_id)) db_log = result.scalars().first() if not db_log: raise HTTPException(status_code=404, detail="Log not found") + # Check permission for the search space + await check_permission( + session, + user, + db_log.search_space_id, + Permission.LOGS_DELETE.value, + "You don't have permission to delete logs in this search space", + ) + await session.delete(db_log) await session.commit() return {"message": "Log deleted successfully"} @@ -201,10 +256,19 @@ async def get_logs_summary( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get a summary of logs for a search space in the last X hours.""" + """ + Get a summary of logs for a search space in the last X hours. + Requires LOGS_READ permission for the search space. + """ try: - # Check ownership - await check_ownership(session, SearchSpace, search_space_id, user) + # Check permission + await check_permission( + session, + user, + search_space_id, + Permission.LOGS_READ.value, + "You don't have permission to read logs in this search space", + ) # Calculate time window since = datetime.utcnow().replace(microsecond=0) - timedelta(hours=hours) diff --git a/surfsense_backend/app/routes/podcasts_routes.py b/surfsense_backend/app/routes/podcasts_routes.py index ae1fdaeef..deb9d9744 100644 --- a/surfsense_backend/app/routes/podcasts_routes.py +++ b/surfsense_backend/app/routes/podcasts_routes.py @@ -7,7 +7,15 @@ from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import Chat, Podcast, SearchSpace, User, get_async_session +from app.db import ( + Chat, + Permission, + Podcast, + SearchSpace, + SearchSpaceMembership, + User, + get_async_session, +) from app.schemas import ( PodcastCreate, PodcastGenerateRequest, @@ -16,7 +24,7 @@ from app.schemas import ( ) from app.tasks.podcast_tasks import generate_chat_podcast from app.users import current_active_user -from app.utils.check_ownership import check_ownership +from app.utils.rbac import check_permission router = APIRouter() @@ -27,8 +35,18 @@ async def create_podcast( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Create a new podcast. + Requires PODCASTS_CREATE permission. + """ try: - await check_ownership(session, SearchSpace, podcast.search_space_id, user) + await check_permission( + session, + user, + podcast.search_space_id, + Permission.PODCASTS_CREATE.value, + "You don't have permission to create podcasts in this search space", + ) db_podcast = Podcast(**podcast.model_dump()) session.add(db_podcast) await session.commit() @@ -58,20 +76,45 @@ async def create_podcast( async def read_podcasts( skip: int = 0, limit: int = 100, + search_space_id: int | None = None, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + List podcasts the user has access to. + Requires PODCASTS_READ permission for the search space(s). + """ if skip < 0 or limit < 1: raise HTTPException(status_code=400, detail="Invalid pagination parameters") try: - result = await session.execute( - select(Podcast) - .join(SearchSpace) - .filter(SearchSpace.user_id == user.id) - .offset(skip) - .limit(limit) - ) + if search_space_id is not None: + # Check permission for specific search space + await check_permission( + session, + user, + search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to read podcasts in this search space", + ) + result = await session.execute( + select(Podcast) + .filter(Podcast.search_space_id == search_space_id) + .offset(skip) + .limit(limit) + ) + else: + # Get podcasts from all search spaces user has membership in + result = await session.execute( + select(Podcast) + .join(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + .offset(skip) + .limit(limit) + ) return result.scalars().all() + except HTTPException: + raise except SQLAlchemyError: raise HTTPException( status_code=500, detail="Database error occurred while fetching podcasts" @@ -84,18 +127,29 @@ async def read_podcast( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Get a specific podcast by ID. + Requires PODCASTS_READ permission for the search space. + """ try: - result = await session.execute( - select(Podcast) - .join(SearchSpace) - .filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id) - ) + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) podcast = result.scalars().first() + if not podcast: raise HTTPException( status_code=404, - detail="Podcast not found or you don't have permission to access it", + detail="Podcast not found", ) + + # Check permission for the search space + await check_permission( + session, + user, + podcast.search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to read podcasts in this search space", + ) + return podcast except HTTPException as he: raise he @@ -112,8 +166,26 @@ async def update_podcast( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Update a podcast. + Requires PODCASTS_UPDATE permission for the search space. + """ try: - db_podcast = await read_podcast(podcast_id, session, user) + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + db_podcast = result.scalars().first() + + if not db_podcast: + raise HTTPException(status_code=404, detail="Podcast not found") + + # Check permission for the search space + await check_permission( + session, + user, + db_podcast.search_space_id, + Permission.PODCASTS_UPDATE.value, + "You don't have permission to update podcasts in this search space", + ) + update_data = podcast_update.model_dump(exclude_unset=True) for key, value in update_data.items(): setattr(db_podcast, key, value) @@ -140,8 +212,26 @@ async def delete_podcast( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Delete a podcast. + Requires PODCASTS_DELETE permission for the search space. + """ try: - db_podcast = await read_podcast(podcast_id, session, user) + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + db_podcast = result.scalars().first() + + if not db_podcast: + raise HTTPException(status_code=404, detail="Podcast not found") + + # Check permission for the search space + await check_permission( + session, + user, + db_podcast.search_space_id, + Permission.PODCASTS_DELETE.value, + "You don't have permission to delete podcasts in this search space", + ) + await session.delete(db_podcast) await session.commit() return {"message": "Podcast deleted successfully"} @@ -181,9 +271,19 @@ async def generate_podcast( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Generate a podcast from a chat or document. + Requires PODCASTS_CREATE permission. + """ try: - # Check if the user owns the search space - await check_ownership(session, SearchSpace, request.search_space_id, user) + # Check if the user has permission to create podcasts + await check_permission( + session, + user, + request.search_space_id, + Permission.PODCASTS_CREATE.value, + "You don't have permission to create podcasts in this search space", + ) if request.type == "CHAT": # Verify that all chat IDs belong to this user and search space @@ -251,22 +351,29 @@ async def stream_podcast( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Stream a podcast audio file.""" + """ + Stream a podcast audio file. + Requires PODCASTS_READ permission for the search space. + """ try: - # Get the podcast and check if user has access - result = await session.execute( - select(Podcast) - .join(SearchSpace) - .filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id) - ) + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) podcast = result.scalars().first() if not podcast: raise HTTPException( status_code=404, - detail="Podcast not found or you don't have permission to access it", + detail="Podcast not found", ) + # Check permission for the search space + await check_permission( + session, + user, + podcast.search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to access podcasts in this search space", + ) + # Get the file path file_path = podcast.file_location @@ -303,12 +410,30 @@ async def get_podcast_by_chat_id( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Get a podcast by its associated chat ID. + Requires PODCASTS_READ permission for the search space. + """ try: - # Get the podcast and check if user has access + # First get the chat to find its search space + chat_result = await session.execute(select(Chat).filter(Chat.id == chat_id)) + chat = chat_result.scalars().first() + + if not chat: + return None + + # Check permission for the search space + await check_permission( + session, + user, + chat.search_space_id, + Permission.PODCASTS_READ.value, + "You don't have permission to read podcasts in this search space", + ) + + # Get the podcast result = await session.execute( - select(Podcast) - .join(SearchSpace) - .filter(Podcast.chat_id == chat_id, SearchSpace.user_id == user.id) + select(Podcast).filter(Podcast.chat_id == chat_id) ) podcast = result.scalars().first() diff --git a/surfsense_backend/app/routes/rbac_routes.py b/surfsense_backend/app/routes/rbac_routes.py new file mode 100644 index 000000000..c5392f284 --- /dev/null +++ b/surfsense_backend/app/routes/rbac_routes.py @@ -0,0 +1,1084 @@ +""" +RBAC (Role-Based Access Control) routes for managing roles, memberships, and invites. + +Endpoints: +- /searchspaces/{search_space_id}/roles - CRUD for roles +- /searchspaces/{search_space_id}/members - CRUD for memberships +- /searchspaces/{search_space_id}/invites - CRUD for invites +- /invites/{invite_code}/info - Get invite info (public) +- /invites/accept - Accept an invite +- /permissions - List all available permissions +""" + +import logging +from datetime import UTC, datetime + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload + +from app.db import ( + Permission, + SearchSpace, + SearchSpaceInvite, + SearchSpaceMembership, + SearchSpaceRole, + User, + get_async_session, +) +from app.schemas import ( + InviteAcceptRequest, + InviteAcceptResponse, + InviteCreate, + InviteInfoResponse, + InviteRead, + InviteUpdate, + MembershipRead, + MembershipUpdate, + PermissionInfo, + PermissionsListResponse, + RoleCreate, + RoleRead, + RoleUpdate, + UserSearchSpaceAccess, +) +from app.users import current_active_user +from app.utils.rbac import ( + check_permission, + check_search_space_access, + generate_invite_code, + get_default_role, + get_user_permissions, +) + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# ============ Permissions Endpoints ============ + + +@router.get("/permissions", response_model=PermissionsListResponse) +async def list_all_permissions( + user: User = Depends(current_active_user), +): + """ + List all available permissions that can be assigned to roles. + """ + permissions = [] + for perm in Permission: + # Extract category from permission value (e.g., "documents:read" -> "documents") + category = perm.value.split(":")[0] if ":" in perm.value else "general" + + permissions.append( + PermissionInfo( + value=perm.value, + name=perm.name, + category=category, + ) + ) + + return PermissionsListResponse(permissions=permissions) + + +# ============ Role Endpoints ============ + + +@router.post( + "/searchspaces/{search_space_id}/roles", + response_model=RoleRead, +) +async def create_role( + search_space_id: int, + role_data: RoleCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Create a new custom role in a search space. + Requires ROLES_CREATE permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.ROLES_CREATE.value, + "You don't have permission to create roles", + ) + + # Check if role with same name already exists + result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.search_space_id == search_space_id, + SearchSpaceRole.name == role_data.name, + ) + ) + if result.scalars().first(): + raise HTTPException( + status_code=409, + detail=f"A role with name '{role_data.name}' already exists in this search space", + ) + + # Validate permissions + valid_permissions = {p.value for p in Permission} + for perm in role_data.permissions: + if perm not in valid_permissions: + raise HTTPException( + status_code=400, + detail=f"Invalid permission: {perm}", + ) + + # If setting is_default to True, unset any existing default + if role_data.is_default: + await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.search_space_id == search_space_id, + SearchSpaceRole.is_default == True, # noqa: E712 + ) + ) + existing_defaults = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.search_space_id == search_space_id, + SearchSpaceRole.is_default == True, # noqa: E712 + ) + ) + for existing in existing_defaults.scalars().all(): + existing.is_default = False + + db_role = SearchSpaceRole( + **role_data.model_dump(), + search_space_id=search_space_id, + is_system_role=False, + ) + session.add(db_role) + await session.commit() + await session.refresh(db_role) + return db_role + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to create role: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to create role: {e!s}" + ) from e + + +@router.get( + "/searchspaces/{search_space_id}/roles", + response_model=list[RoleRead], +) +async def list_roles( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List all roles in a search space. + Requires ROLES_READ permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.ROLES_READ.value, + "You don't have permission to view roles", + ) + + result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.search_space_id == search_space_id + ) + ) + return result.scalars().all() + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to fetch roles: {e!s}" + ) from e + + +@router.get( + "/searchspaces/{search_space_id}/roles/{role_id}", + response_model=RoleRead, +) +async def get_role( + search_space_id: int, + role_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a specific role by ID. + Requires ROLES_READ permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.ROLES_READ.value, + "You don't have permission to view roles", + ) + + result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.id == role_id, + SearchSpaceRole.search_space_id == search_space_id, + ) + ) + role = result.scalars().first() + + if not role: + raise HTTPException(status_code=404, detail="Role not found") + + return role + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to fetch role: {e!s}" + ) from e + + +@router.put( + "/searchspaces/{search_space_id}/roles/{role_id}", + response_model=RoleRead, +) +async def update_role( + search_space_id: int, + role_id: int, + role_update: RoleUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update a role. + Requires ROLES_UPDATE permission. + System roles can only have their permissions updated, not name/description. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.ROLES_UPDATE.value, + "You don't have permission to update roles", + ) + + result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.id == role_id, + SearchSpaceRole.search_space_id == search_space_id, + ) + ) + db_role = result.scalars().first() + + if not db_role: + raise HTTPException(status_code=404, detail="Role not found") + + update_data = role_update.model_dump(exclude_unset=True) + + # System roles have restrictions on what can be updated + if db_role.is_system_role: + # Can only update permissions for system roles + restricted_fields = {"name", "description", "is_default"} + if any(field in update_data for field in restricted_fields): + raise HTTPException( + status_code=400, + detail="Cannot modify name, description, or default status of system roles", + ) + + # Check for name conflict if updating name + if "name" in update_data and update_data["name"] != db_role.name: + existing = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.search_space_id == search_space_id, + SearchSpaceRole.name == update_data["name"], + ) + ) + if existing.scalars().first(): + raise HTTPException( + status_code=409, + detail=f"A role with name '{update_data['name']}' already exists", + ) + + # Validate permissions if provided + if "permissions" in update_data: + valid_permissions = {p.value for p in Permission} + for perm in update_data["permissions"]: + if perm not in valid_permissions: + raise HTTPException( + status_code=400, + detail=f"Invalid permission: {perm}", + ) + + # Handle is_default change + if update_data.get("is_default") and not db_role.is_default: + # Unset existing default + existing_defaults = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.search_space_id == search_space_id, + SearchSpaceRole.is_default == True, # noqa: E712 + ) + ) + for existing in existing_defaults.scalars().all(): + existing.is_default = False + + for key, value in update_data.items(): + setattr(db_role, key, value) + + await session.commit() + await session.refresh(db_role) + return db_role + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to update role: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to update role: {e!s}" + ) from e + + +@router.delete("/searchspaces/{search_space_id}/roles/{role_id}") +async def delete_role( + search_space_id: int, + role_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Delete a custom role. + Requires ROLES_DELETE permission. + System roles cannot be deleted. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.ROLES_DELETE.value, + "You don't have permission to delete roles", + ) + + result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.id == role_id, + SearchSpaceRole.search_space_id == search_space_id, + ) + ) + db_role = result.scalars().first() + + if not db_role: + raise HTTPException(status_code=404, detail="Role not found") + + if db_role.is_system_role: + raise HTTPException( + status_code=400, + detail="System roles cannot be deleted", + ) + + await session.delete(db_role) + await session.commit() + return {"message": "Role deleted successfully"} + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to delete role: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to delete role: {e!s}" + ) from e + + +# ============ Membership Endpoints ============ + + +@router.get( + "/searchspaces/{search_space_id}/members", + response_model=list[MembershipRead], +) +async def list_members( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List all members of a search space. + Requires MEMBERS_VIEW permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.MEMBERS_VIEW.value, + "You don't have permission to view members", + ) + + result = await session.execute( + select(SearchSpaceMembership) + .options(selectinload(SearchSpaceMembership.role)) + .filter(SearchSpaceMembership.search_space_id == search_space_id) + ) + memberships = result.scalars().all() + + # Fetch user emails for each membership + response = [] + for membership in memberships: + user_result = await session.execute( + select(User).filter(User.id == membership.user_id) + ) + member_user = user_result.scalars().first() + + membership_dict = { + "id": membership.id, + "user_id": membership.user_id, + "search_space_id": membership.search_space_id, + "role_id": membership.role_id, + "is_owner": membership.is_owner, + "joined_at": membership.joined_at, + "created_at": membership.created_at, + "role": membership.role, + "user_email": member_user.email if member_user else None, + } + response.append(membership_dict) + + return response + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to fetch members: {e!s}" + ) from e + + +@router.put( + "/searchspaces/{search_space_id}/members/{membership_id}", + response_model=MembershipRead, +) +async def update_member_role( + search_space_id: int, + membership_id: int, + membership_update: MembershipUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update a member's role. + Requires MEMBERS_MANAGE_ROLES permission. + Cannot change owner's role. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.MEMBERS_MANAGE_ROLES.value, + "You don't have permission to manage member roles", + ) + + result = await session.execute( + select(SearchSpaceMembership) + .options(selectinload(SearchSpaceMembership.role)) + .filter( + SearchSpaceMembership.id == membership_id, + SearchSpaceMembership.search_space_id == search_space_id, + ) + ) + db_membership = result.scalars().first() + + if not db_membership: + raise HTTPException(status_code=404, detail="Membership not found") + + # Cannot change owner's role + if db_membership.is_owner: + raise HTTPException( + status_code=400, + detail="Cannot change the owner's role", + ) + + # Verify the new role exists in this search space + if membership_update.role_id: + role_result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.id == membership_update.role_id, + SearchSpaceRole.search_space_id == search_space_id, + ) + ) + if not role_result.scalars().first(): + raise HTTPException( + status_code=404, + detail="Role not found in this search space", + ) + + db_membership.role_id = membership_update.role_id + await session.commit() + await session.refresh(db_membership) + + # Fetch user email + user_result = await session.execute( + select(User).filter(User.id == db_membership.user_id) + ) + member_user = user_result.scalars().first() + + return { + "id": db_membership.id, + "user_id": db_membership.user_id, + "search_space_id": db_membership.search_space_id, + "role_id": db_membership.role_id, + "is_owner": db_membership.is_owner, + "joined_at": db_membership.joined_at, + "created_at": db_membership.created_at, + "role": db_membership.role, + "user_email": member_user.email if member_user else None, + } + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to update member role: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to update member role: {e!s}" + ) from e + + +@router.delete("/searchspaces/{search_space_id}/members/{membership_id}") +async def remove_member( + search_space_id: int, + membership_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Remove a member from a search space. + Requires MEMBERS_REMOVE permission. + Cannot remove the owner. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.MEMBERS_REMOVE.value, + "You don't have permission to remove members", + ) + + result = await session.execute( + select(SearchSpaceMembership).filter( + SearchSpaceMembership.id == membership_id, + SearchSpaceMembership.search_space_id == search_space_id, + ) + ) + db_membership = result.scalars().first() + + if not db_membership: + raise HTTPException(status_code=404, detail="Membership not found") + + if db_membership.is_owner: + raise HTTPException( + status_code=400, + detail="Cannot remove the owner from the search space", + ) + + await session.delete(db_membership) + await session.commit() + return {"message": "Member removed successfully"} + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to remove member: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to remove member: {e!s}" + ) from e + + +@router.delete("/searchspaces/{search_space_id}/members/me") +async def leave_search_space( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Leave a search space (remove own membership). + Owners cannot leave their search space. + """ + try: + result = await session.execute( + select(SearchSpaceMembership).filter( + SearchSpaceMembership.user_id == user.id, + SearchSpaceMembership.search_space_id == search_space_id, + ) + ) + db_membership = result.scalars().first() + + if not db_membership: + raise HTTPException( + status_code=404, + detail="You are not a member of this search space", + ) + + if db_membership.is_owner: + raise HTTPException( + status_code=400, + detail="Owners cannot leave their search space. Transfer ownership first or delete the search space.", + ) + + await session.delete(db_membership) + await session.commit() + return {"message": "Successfully left the search space"} + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to leave search space: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to leave search space: {e!s}" + ) from e + + +# ============ Invite Endpoints ============ + + +@router.post( + "/searchspaces/{search_space_id}/invites", + response_model=InviteRead, +) +async def create_invite( + search_space_id: int, + invite_data: InviteCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Create a new invite link for a search space. + Requires MEMBERS_INVITE permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.MEMBERS_INVITE.value, + "You don't have permission to create invites", + ) + + # Verify role exists if specified + if invite_data.role_id: + role_result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.id == invite_data.role_id, + SearchSpaceRole.search_space_id == search_space_id, + ) + ) + if not role_result.scalars().first(): + raise HTTPException( + status_code=404, + detail="Role not found in this search space", + ) + + db_invite = SearchSpaceInvite( + **invite_data.model_dump(), + invite_code=generate_invite_code(), + search_space_id=search_space_id, + created_by_id=user.id, + ) + session.add(db_invite) + await session.commit() + + # Reload with role + result = await session.execute( + select(SearchSpaceInvite) + .options(selectinload(SearchSpaceInvite.role)) + .filter(SearchSpaceInvite.id == db_invite.id) + ) + db_invite = result.scalars().first() + + return db_invite + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to create invite: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to create invite: {e!s}" + ) from e + + +@router.get( + "/searchspaces/{search_space_id}/invites", + response_model=list[InviteRead], +) +async def list_invites( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List all invites for a search space. + Requires MEMBERS_INVITE permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.MEMBERS_INVITE.value, + "You don't have permission to view invites", + ) + + result = await session.execute( + select(SearchSpaceInvite) + .options(selectinload(SearchSpaceInvite.role)) + .filter(SearchSpaceInvite.search_space_id == search_space_id) + ) + return result.scalars().all() + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to fetch invites: {e!s}" + ) from e + + +@router.put( + "/searchspaces/{search_space_id}/invites/{invite_id}", + response_model=InviteRead, +) +async def update_invite( + search_space_id: int, + invite_id: int, + invite_update: InviteUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update an invite. + Requires MEMBERS_INVITE permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.MEMBERS_INVITE.value, + "You don't have permission to update invites", + ) + + result = await session.execute( + select(SearchSpaceInvite) + .options(selectinload(SearchSpaceInvite.role)) + .filter( + SearchSpaceInvite.id == invite_id, + SearchSpaceInvite.search_space_id == search_space_id, + ) + ) + db_invite = result.scalars().first() + + if not db_invite: + raise HTTPException(status_code=404, detail="Invite not found") + + update_data = invite_update.model_dump(exclude_unset=True) + + # Verify role exists if updating role_id + if update_data.get("role_id"): + role_result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.id == update_data["role_id"], + SearchSpaceRole.search_space_id == search_space_id, + ) + ) + if not role_result.scalars().first(): + raise HTTPException( + status_code=404, + detail="Role not found in this search space", + ) + + for key, value in update_data.items(): + setattr(db_invite, key, value) + + await session.commit() + await session.refresh(db_invite) + return db_invite + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to update invite: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to update invite: {e!s}" + ) from e + + +@router.delete("/searchspaces/{search_space_id}/invites/{invite_id}") +async def revoke_invite( + search_space_id: int, + invite_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Revoke (delete) an invite. + Requires MEMBERS_INVITE permission. + """ + try: + await check_permission( + session, + user, + search_space_id, + Permission.MEMBERS_INVITE.value, + "You don't have permission to revoke invites", + ) + + result = await session.execute( + select(SearchSpaceInvite).filter( + SearchSpaceInvite.id == invite_id, + SearchSpaceInvite.search_space_id == search_space_id, + ) + ) + db_invite = result.scalars().first() + + if not db_invite: + raise HTTPException(status_code=404, detail="Invite not found") + + await session.delete(db_invite) + await session.commit() + return {"message": "Invite revoked successfully"} + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to revoke invite: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to revoke invite: {e!s}" + ) from e + + +# ============ Public Invite Endpoints ============ + + +@router.get("/invites/{invite_code}/info", response_model=InviteInfoResponse) +async def get_invite_info( + invite_code: str, + session: AsyncSession = Depends(get_async_session), +): + """ + Get information about an invite (public endpoint, no auth required). + Returns minimal info for displaying on invite acceptance page. + """ + try: + result = await session.execute( + select(SearchSpaceInvite) + .options( + selectinload(SearchSpaceInvite.role), + selectinload(SearchSpaceInvite.search_space), + ) + .filter(SearchSpaceInvite.invite_code == invite_code) + ) + invite = result.scalars().first() + + if not invite: + return InviteInfoResponse( + search_space_name="", + role_name=None, + is_valid=False, + message="Invite not found", + ) + + # Check if invite is still valid + if not invite.is_active: + return InviteInfoResponse( + search_space_name=invite.search_space.name + if invite.search_space + else "", + role_name=invite.role.name if invite.role else None, + is_valid=False, + message="This invite is no longer active", + ) + + if invite.expires_at and invite.expires_at < datetime.now(UTC): + return InviteInfoResponse( + search_space_name=invite.search_space.name + if invite.search_space + else "", + role_name=invite.role.name if invite.role else None, + is_valid=False, + message="This invite has expired", + ) + + if invite.max_uses and invite.uses_count >= invite.max_uses: + return InviteInfoResponse( + search_space_name=invite.search_space.name + if invite.search_space + else "", + role_name=invite.role.name if invite.role else None, + is_valid=False, + message="This invite has reached its maximum uses", + ) + + return InviteInfoResponse( + search_space_name=invite.search_space.name if invite.search_space else "", + role_name=invite.role.name if invite.role else "Default", + is_valid=True, + ) + + except Exception as e: + logger.error(f"Failed to get invite info: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to get invite info: {e!s}" + ) from e + + +@router.post("/invites/accept", response_model=InviteAcceptResponse) +async def accept_invite( + request: InviteAcceptRequest, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Accept an invite and join a search space. + """ + try: + result = await session.execute( + select(SearchSpaceInvite) + .options( + selectinload(SearchSpaceInvite.role), + selectinload(SearchSpaceInvite.search_space), + ) + .filter(SearchSpaceInvite.invite_code == request.invite_code) + ) + invite = result.scalars().first() + + if not invite: + raise HTTPException(status_code=404, detail="Invite not found") + + # Validate invite + if not invite.is_active: + raise HTTPException( + status_code=400, detail="This invite is no longer active" + ) + + if invite.expires_at and invite.expires_at < datetime.now(UTC): + raise HTTPException(status_code=400, detail="This invite has expired") + + if invite.max_uses and invite.uses_count >= invite.max_uses: + raise HTTPException( + status_code=400, detail="This invite has reached its maximum uses" + ) + + # Check if user is already a member + existing_membership = await session.execute( + select(SearchSpaceMembership).filter( + SearchSpaceMembership.user_id == user.id, + SearchSpaceMembership.search_space_id == invite.search_space_id, + ) + ) + if existing_membership.scalars().first(): + raise HTTPException( + status_code=400, + detail="You are already a member of this search space", + ) + + # Determine role to assign + role_id = invite.role_id + if not role_id: + # Use default role + default_role = await get_default_role(session, invite.search_space_id) + role_id = default_role.id if default_role else None + + # Create membership + membership = SearchSpaceMembership( + user_id=user.id, + search_space_id=invite.search_space_id, + role_id=role_id, + is_owner=False, + invited_by_invite_id=invite.id, + ) + session.add(membership) + + # Increment invite usage + invite.uses_count += 1 + + await session.commit() + + role_name = invite.role.name if invite.role else "Default" + search_space_name = invite.search_space.name if invite.search_space else "" + + return InviteAcceptResponse( + message="Successfully joined the search space", + search_space_id=invite.search_space_id, + search_space_name=search_space_name, + role_name=role_name, + ) + + except HTTPException: + raise + except Exception as e: + await session.rollback() + logger.error(f"Failed to accept invite: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to accept invite: {e!s}" + ) from e + + +# ============ User Access Info ============ + + +@router.get( + "/searchspaces/{search_space_id}/my-access", + response_model=UserSearchSpaceAccess, +) +async def get_my_access( + search_space_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get the current user's access info for a search space. + """ + try: + membership = await check_search_space_access(session, user, search_space_id) + + # Get search space name + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + # Get permissions + permissions = await get_user_permissions(session, user.id, search_space_id) + + return UserSearchSpaceAccess( + search_space_id=search_space_id, + search_space_name=search_space.name if search_space else "", + is_owner=membership.is_owner, + role_name=membership.role.name if membership.role else None, + permissions=permissions, + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to get access info: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index bf397a352..624353e19 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -22,9 +22,9 @@ from sqlalchemy.future import select from app.connectors.github_connector import GitHubConnector from app.db import ( + Permission, SearchSourceConnector, SearchSourceConnectorType, - SearchSpace, User, async_session_maker, get_async_session, @@ -52,12 +52,12 @@ from app.tasks.connector_indexers import ( index_slack_messages, ) from app.users import current_active_user -from app.utils.check_ownership import check_ownership from app.utils.periodic_scheduler import ( create_periodic_schedule, delete_periodic_schedule, update_periodic_schedule, ) +from app.utils.rbac import check_permission # Set up logging logger = logging.getLogger(__name__) @@ -108,19 +108,25 @@ async def create_search_source_connector( ): """ Create a new search source connector. + Requires CONNECTORS_CREATE permission. - Each search space can have only one connector of each type per user (based on search_space_id, user_id, and connector_type). + Each search space can have only one connector of each type (based on search_space_id and connector_type). The config must contain the appropriate keys for the connector type. """ try: - # Check if the search space belongs to the user - await check_ownership(session, SearchSpace, search_space_id, user) + # Check if user has permission to create connectors + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_CREATE.value, + "You don't have permission to create connectors in this search space", + ) - # Check if a connector with the same type already exists for this search space and user + # Check if a connector with the same type already exists for this search space result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.user_id == user.id, SearchSourceConnector.connector_type == connector.connector_type, ) ) @@ -128,7 +134,7 @@ async def create_search_source_connector( if existing_connector: raise HTTPException( status_code=409, - detail=f"A connector with type {connector.connector_type} already exists in this search space. Each search space can have only one connector of each type per user.", + detail=f"A connector with type {connector.connector_type} already exists in this search space.", ) # Prepare connector data @@ -198,22 +204,34 @@ async def read_search_source_connectors( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """List all search source connectors for the current user, optionally filtered by search space.""" + """ + List all search source connectors for a search space. + Requires CONNECTORS_READ permission. + """ try: - query = select(SearchSourceConnector).filter( - SearchSourceConnector.user_id == user.id + if search_space_id is None: + raise HTTPException( + status_code=400, + detail="search_space_id is required", + ) + + # Check if user has permission to read connectors + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_READ.value, + "You don't have permission to view connectors in this search space", ) - # Filter by search_space_id if provided - if search_space_id is not None: - # Verify the search space belongs to the user - await check_ownership(session, SearchSpace, search_space_id, user) - query = query.filter( - SearchSourceConnector.search_space_id == search_space_id - ) + query = select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id + ) result = await session.execute(query.offset(skip).limit(limit)) return result.scalars().all() + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, @@ -229,9 +247,32 @@ async def read_search_source_connector( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get a specific search source connector by ID.""" + """ + Get a specific search source connector by ID. + Requires CONNECTORS_READ permission. + """ try: - return await check_ownership(session, SearchSourceConnector, connector_id, user) + # Get the connector first + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise HTTPException(status_code=404, detail="Connector not found") + + # Check permission + await check_permission( + session, + user, + connector.search_space_id, + Permission.CONNECTORS_READ.value, + "You don't have permission to view this connector", + ) + + return connector except HTTPException: raise except Exception as e: @@ -251,10 +292,25 @@ async def update_search_source_connector( ): """ Update a search source connector. + Requires CONNECTORS_UPDATE permission. Handles partial updates, including merging changes into the 'config' field. """ - db_connector = await check_ownership( - session, SearchSourceConnector, connector_id, user + # Get the connector first + result = await session.execute( + select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id) + ) + db_connector = result.scalars().first() + + if not db_connector: + raise HTTPException(status_code=404, detail="Connector not found") + + # Check permission + await check_permission( + session, + user, + db_connector.search_space_id, + Permission.CONNECTORS_UPDATE.value, + "You don't have permission to update this connector", ) # Convert the sparse update data (only fields present in request) to a dict @@ -349,20 +405,19 @@ async def update_search_source_connector( for key, value in update_data.items(): # Prevent changing connector_type if it causes a duplicate (check moved here) if key == "connector_type" and value != db_connector.connector_type: - result = await session.execute( + check_result = await session.execute( select(SearchSourceConnector).filter( SearchSourceConnector.search_space_id == db_connector.search_space_id, - SearchSourceConnector.user_id == user.id, SearchSourceConnector.connector_type == value, SearchSourceConnector.id != connector_id, ) ) - existing_connector = result.scalars().first() + existing_connector = check_result.scalars().first() if existing_connector: raise HTTPException( status_code=409, - detail=f"A connector with type {value} already exists in this search space. Each search space can have only one connector of each type per user.", + detail=f"A connector with type {value} already exists in this search space.", ) setattr(db_connector, key, value) @@ -425,10 +480,29 @@ async def delete_search_source_connector( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Delete a search source connector.""" + """ + Delete a search source connector. + Requires CONNECTORS_DELETE permission. + """ try: - db_connector = await check_ownership( - session, SearchSourceConnector, connector_id, user + # Get the connector first + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id + ) + ) + db_connector = result.scalars().first() + + if not db_connector: + raise HTTPException(status_code=404, detail="Connector not found") + + # Check permission + await check_permission( + session, + user, + db_connector.search_space_id, + Permission.CONNECTORS_DELETE.value, + "You don't have permission to delete this connector", ) # Delete any periodic schedule associated with this connector @@ -473,6 +547,7 @@ async def index_connector_content( ): """ Index content from a connector to a search space. + Requires CONNECTORS_UPDATE permission (to trigger indexing). Currently supports: - SLACK_CONNECTOR: Indexes messages from all accessible Slack channels @@ -488,20 +563,29 @@ async def index_connector_content( Args: connector_id: ID of the connector to use search_space_id: ID of the search space to store indexed content - background_tasks: FastAPI background tasks Returns: Dictionary with indexing status """ try: - # Check if the connector belongs to the user - connector = await check_ownership( - session, SearchSourceConnector, connector_id, user + # Get the connector first + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id + ) ) + connector = result.scalars().first() - # Check if the search space belongs to the user - _search_space = await check_ownership( - session, SearchSpace, search_space_id, user + if not connector: + raise HTTPException(status_code=404, detail="Connector not found") + + # Check if user has permission to update connectors (indexing is an update operation) + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_UPDATE.value, + "You don't have permission to index content in this search space", ) # Handle different connector types diff --git a/surfsense_backend/app/routes/search_spaces_routes.py b/surfsense_backend/app/routes/search_spaces_routes.py index 7a01f2171..d04cf11ce 100644 --- a/surfsense_backend/app/routes/search_spaces_routes.py +++ b/surfsense_backend/app/routes/search_spaces_routes.py @@ -1,18 +1,77 @@ +import logging from pathlib import Path import yaml from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import SearchSpace, User, get_async_session -from app.schemas import SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate +from app.db import ( + Permission, + SearchSpace, + SearchSpaceMembership, + SearchSpaceRole, + User, + get_async_session, + get_default_roles_config, +) +from app.schemas import ( + SearchSpaceCreate, + SearchSpaceRead, + SearchSpaceUpdate, + SearchSpaceWithStats, +) from app.users import current_active_user -from app.utils.check_ownership import check_ownership +from app.utils.rbac import check_permission, check_search_space_access + +logger = logging.getLogger(__name__) router = APIRouter() +async def create_default_roles_and_membership( + session: AsyncSession, + search_space_id: int, + owner_user_id, +) -> None: + """ + Create default system roles for a search space and add the owner as a member. + + Args: + session: Database session + search_space_id: The ID of the newly created search space + owner_user_id: The UUID of the user who created the search space + """ + # Create default roles + default_roles = get_default_roles_config() + owner_role_id = None + + for role_config in default_roles: + db_role = SearchSpaceRole( + name=role_config["name"], + description=role_config["description"], + permissions=role_config["permissions"], + is_default=role_config["is_default"], + is_system_role=role_config["is_system_role"], + search_space_id=search_space_id, + ) + session.add(db_role) + await session.flush() # Get the ID + + if role_config["name"] == "Owner": + owner_role_id = db_role.id + + # Create owner membership + owner_membership = SearchSpaceMembership( + user_id=owner_user_id, + search_space_id=search_space_id, + role_id=owner_role_id, + is_owner=True, + ) + session.add(owner_membership) + + @router.post("/searchspaces", response_model=SearchSpaceRead) async def create_search_space( search_space: SearchSpaceCreate, @@ -27,6 +86,11 @@ async def create_search_space( db_search_space = SearchSpace(**search_space_data, user_id=user.id) session.add(db_search_space) + await session.flush() # Get the search space ID + + # Create default roles and owner membership + await create_default_roles_and_membership(session, db_search_space.id, user.id) + await session.commit() await session.refresh(db_search_space) return db_search_space @@ -34,26 +98,86 @@ async def create_search_space( raise except Exception as e: await session.rollback() + logger.error(f"Failed to create search space: {e!s}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to create search space: {e!s}" ) from e -@router.get("/searchspaces", response_model=list[SearchSpaceRead]) +@router.get("/searchspaces", response_model=list[SearchSpaceWithStats]) async def read_search_spaces( skip: int = 0, limit: int = 200, + owned_only: bool = False, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Get all search spaces the user has access to, with member count and ownership info. + + Args: + skip: Number of items to skip + limit: Maximum number of items to return + owned_only: If True, only return search spaces owned by the user. + If False (default), return all search spaces the user has access to. + """ try: - result = await session.execute( - select(SearchSpace) - .filter(SearchSpace.user_id == user.id) - .offset(skip) - .limit(limit) - ) - return result.scalars().all() + if owned_only: + # Return only search spaces where user is the original creator (user_id) + result = await session.execute( + select(SearchSpace) + .filter(SearchSpace.user_id == user.id) + .offset(skip) + .limit(limit) + ) + else: + # Return all search spaces the user has membership in + result = await session.execute( + select(SearchSpace) + .join(SearchSpaceMembership) + .filter(SearchSpaceMembership.user_id == user.id) + .offset(skip) + .limit(limit) + ) + + search_spaces = result.scalars().all() + + # Get member counts and ownership info for each search space + search_spaces_with_stats = [] + for space in search_spaces: + # Get member count + count_result = await session.execute( + select(func.count(SearchSpaceMembership.id)).filter( + SearchSpaceMembership.search_space_id == space.id + ) + ) + member_count = count_result.scalar() or 1 + + # Check if current user is owner + ownership_result = await session.execute( + select(SearchSpaceMembership).filter( + SearchSpaceMembership.search_space_id == space.id, + SearchSpaceMembership.user_id == user.id, + SearchSpaceMembership.is_owner == True, # noqa: E712 + ) + ) + is_owner = ownership_result.scalars().first() is not None + + search_spaces_with_stats.append( + SearchSpaceWithStats( + id=space.id, + name=space.name, + description=space.description, + created_at=space.created_at, + user_id=space.user_id, + citations_enabled=space.citations_enabled, + qna_custom_instructions=space.qna_custom_instructions, + member_count=member_count, + is_owner=is_owner, + ) + ) + + return search_spaces_with_stats except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to fetch search spaces: {e!s}" @@ -97,10 +221,22 @@ async def read_search_space( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Get a specific search space by ID. + Requires SETTINGS_VIEW permission or membership. + """ try: - search_space = await check_ownership( - session, SearchSpace, search_space_id, user + # Check if user has access (is a member) + await check_search_space_access(session, user, search_space_id) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + return search_space except HTTPException: @@ -118,10 +254,28 @@ async def update_search_space( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Update a search space. + Requires SETTINGS_UPDATE permission. + """ try: - db_search_space = await check_ownership( - session, SearchSpace, search_space_id, user + # Check permission + await check_permission( + session, + user, + search_space_id, + Permission.SETTINGS_UPDATE.value, + "You don't have permission to update this search space", ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + db_search_space = result.scalars().first() + + if not db_search_space: + raise HTTPException(status_code=404, detail="Search space not found") + update_data = search_space_update.model_dump(exclude_unset=True) for key, value in update_data.items(): setattr(db_search_space, key, value) @@ -143,10 +297,28 @@ async def delete_search_space( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): + """ + Delete a search space. + Requires SETTINGS_DELETE permission (only owners have this by default). + """ try: - db_search_space = await check_ownership( - session, SearchSpace, search_space_id, user + # Check permission - only those with SETTINGS_DELETE can delete + await check_permission( + session, + user, + search_space_id, + Permission.SETTINGS_DELETE.value, + "You don't have permission to delete this search space", ) + + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + db_search_space = result.scalars().first() + + if not db_search_space: + raise HTTPException(status_code=404, detail="Search space not found") + await session.delete(db_search_space) await session.commit() return {"message": "Search space deleted successfully"} diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 41b2ce23c..d48d1b7f3 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -27,6 +27,23 @@ from .podcasts import ( PodcastRead, PodcastUpdate, ) +from .rbac_schemas import ( + InviteAcceptRequest, + InviteAcceptResponse, + InviteCreate, + InviteInfoResponse, + InviteRead, + InviteUpdate, + MembershipRead, + MembershipReadWithUser, + MembershipUpdate, + PermissionInfo, + PermissionsListResponse, + RoleCreate, + RoleRead, + RoleUpdate, + UserSearchSpaceAccess, +) from .search_source_connector import ( SearchSourceConnectorBase, SearchSourceConnectorCreate, @@ -38,6 +55,7 @@ from .search_space import ( SearchSpaceCreate, SearchSpaceRead, SearchSpaceUpdate, + SearchSpaceWithStats, ) from .users import UserCreate, UserRead, UserUpdate @@ -60,6 +78,13 @@ __all__ = [ "ExtensionDocumentContent", "ExtensionDocumentMetadata", "IDModel", + # RBAC schemas + "InviteAcceptRequest", + "InviteAcceptResponse", + "InviteCreate", + "InviteInfoResponse", + "InviteRead", + "InviteUpdate", "LLMConfigBase", "LLMConfigCreate", "LLMConfigRead", @@ -69,12 +94,20 @@ __all__ = [ "LogFilter", "LogRead", "LogUpdate", + "MembershipRead", + "MembershipReadWithUser", + "MembershipUpdate", "PaginatedResponse", + "PermissionInfo", + "PermissionsListResponse", "PodcastBase", "PodcastCreate", "PodcastGenerateRequest", "PodcastRead", "PodcastUpdate", + "RoleCreate", + "RoleRead", + "RoleUpdate", "SearchSourceConnectorBase", "SearchSourceConnectorCreate", "SearchSourceConnectorRead", @@ -83,8 +116,10 @@ __all__ = [ "SearchSpaceCreate", "SearchSpaceRead", "SearchSpaceUpdate", + "SearchSpaceWithStats", "TimestampModel", "UserCreate", "UserRead", + "UserSearchSpaceAccess", "UserUpdate", ] diff --git a/surfsense_backend/app/schemas/rbac_schemas.py b/surfsense_backend/app/schemas/rbac_schemas.py new file mode 100644 index 000000000..736d40807 --- /dev/null +++ b/surfsense_backend/app/schemas/rbac_schemas.py @@ -0,0 +1,186 @@ +""" +Pydantic schemas for RBAC (Role-Based Access Control) endpoints. +""" + +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + +# ============ Role Schemas ============ + + +class RoleBase(BaseModel): + """Base schema for roles.""" + + name: str = Field(..., min_length=1, max_length=100) + description: str | None = Field(None, max_length=500) + permissions: list[str] = Field(default_factory=list) + is_default: bool = False + + +class RoleCreate(RoleBase): + """Schema for creating a new role.""" + + pass + + +class RoleUpdate(BaseModel): + """Schema for updating a role (partial update).""" + + name: str | None = Field(None, min_length=1, max_length=100) + description: str | None = Field(None, max_length=500) + permissions: list[str] | None = None + is_default: bool | None = None + + +class RoleRead(RoleBase): + """Schema for reading a role.""" + + id: int + search_space_id: int + is_system_role: bool + created_at: datetime + + class Config: + from_attributes = True + + +# ============ Membership Schemas ============ + + +class MembershipBase(BaseModel): + """Base schema for memberships.""" + + pass + + +class MembershipUpdate(BaseModel): + """Schema for updating a membership (change role).""" + + role_id: int | None = None + + +class MembershipRead(BaseModel): + """Schema for reading a membership.""" + + id: int + user_id: UUID + search_space_id: int + role_id: int | None + is_owner: bool + joined_at: datetime + created_at: datetime + # Nested role info + role: RoleRead | None = None + # User email (populated separately) + user_email: str | None = None + + class Config: + from_attributes = True + + +class MembershipReadWithUser(MembershipRead): + """Schema for reading a membership with user details.""" + + user_email: str | None = None + user_is_active: bool | None = None + + +# ============ Invite Schemas ============ + + +class InviteBase(BaseModel): + """Base schema for invites.""" + + name: str | None = Field(None, max_length=100) + role_id: int | None = None + expires_at: datetime | None = None + max_uses: int | None = Field(None, ge=1) + + +class InviteCreate(InviteBase): + """Schema for creating a new invite.""" + + pass + + +class InviteUpdate(BaseModel): + """Schema for updating an invite (partial update).""" + + name: str | None = Field(None, max_length=100) + role_id: int | None = None + expires_at: datetime | None = None + max_uses: int | None = Field(None, ge=1) + is_active: bool | None = None + + +class InviteRead(InviteBase): + """Schema for reading an invite.""" + + id: int + invite_code: str + search_space_id: int + created_by_id: UUID | None + uses_count: int + is_active: bool + created_at: datetime + # Nested role info + role: RoleRead | None = None + + class Config: + from_attributes = True + + +class InviteAcceptRequest(BaseModel): + """Schema for accepting an invite.""" + + invite_code: str = Field(..., min_length=1) + + +class InviteAcceptResponse(BaseModel): + """Response schema for accepting an invite.""" + + message: str + search_space_id: int + search_space_name: str + role_name: str | None + + +class InviteInfoResponse(BaseModel): + """Response schema for getting invite info (public endpoint).""" + + search_space_name: str + role_name: str | None + is_valid: bool + message: str | None = None + + +# ============ Permission Schemas ============ + + +class PermissionInfo(BaseModel): + """Schema for permission information.""" + + value: str + name: str + category: str + + +class PermissionsListResponse(BaseModel): + """Response schema for listing all available permissions.""" + + permissions: list[PermissionInfo] + + +# ============ User Access Info ============ + + +class UserSearchSpaceAccess(BaseModel): + """Schema for user's access info in a search space.""" + + search_space_id: int + search_space_name: str + is_owner: bool + role_name: str | None + permissions: list[str] diff --git a/surfsense_backend/app/schemas/search_space.py b/surfsense_backend/app/schemas/search_space.py index 49cc0791f..729ff4e7d 100644 --- a/surfsense_backend/app/schemas/search_space.py +++ b/surfsense_backend/app/schemas/search_space.py @@ -34,3 +34,10 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel): qna_custom_instructions: str | None = None model_config = ConfigDict(from_attributes=True) + + +class SearchSpaceWithStats(SearchSpaceRead): + """Extended search space info with member count and ownership status.""" + + member_count: int = 1 + is_owner: bool = False diff --git a/surfsense_backend/app/services/connector_service.py b/surfsense_backend/app/services/connector_service.py index 3445d69f7..20a9ffa32 100644 --- a/surfsense_backend/app/services/connector_service.py +++ b/surfsense_backend/app/services/connector_service.py @@ -15,18 +15,17 @@ from app.db import ( Document, SearchSourceConnector, SearchSourceConnectorType, - SearchSpace, ) from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever from app.retriver.documents_hybrid_search import DocumentHybridSearchRetriever class ConnectorService: - def __init__(self, session: AsyncSession, user_id: str | None = None): + def __init__(self, session: AsyncSession, search_space_id: int | None = None): self.session = session self.chunk_retriever = ChucksHybridSearchRetriever(session) self.document_retriever = DocumentHybridSearchRetriever(session) - self.user_id = user_id + self.search_space_id = search_space_id self.source_id_counter = ( 100000 # High starting value to avoid collisions with existing IDs ) @@ -36,23 +35,22 @@ class ConnectorService: async def initialize_counter(self): """ - Initialize the source_id_counter based on the total number of chunks for the user. + Initialize the source_id_counter based on the total number of chunks for the search space. This ensures unique IDs across different sessions. """ - if self.user_id: + if self.search_space_id: try: - # Count total chunks for documents belonging to this user + # Count total chunks for documents belonging to this search space result = await self.session.execute( select(func.count(Chunk.id)) .join(Document) - .join(SearchSpace) - .filter(SearchSpace.user_id == self.user_id) + .filter(Document.search_space_id == self.search_space_id) ) chunk_count = result.scalar() or 0 self.source_id_counter = chunk_count + 1 print( - f"Initialized source_id_counter to {self.source_id_counter} for user {self.user_id}" + f"Initialized source_id_counter to {self.source_id_counter} for search space {self.search_space_id}" ) except Exception as e: print(f"Error initializing source_id_counter: {e!s}") @@ -62,7 +60,6 @@ class ConnectorService: async def search_crawled_urls( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -72,7 +69,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -84,7 +80,6 @@ class ConnectorService: crawled_urls_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="CRAWLED_URL", ) @@ -92,7 +87,6 @@ class ConnectorService: crawled_urls_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="CRAWLED_URL", ) @@ -171,7 +165,6 @@ class ConnectorService: async def search_files( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -186,7 +179,6 @@ class ConnectorService: files_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="FILE", ) @@ -194,7 +186,6 @@ class ConnectorService: files_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="FILE", ) @@ -274,43 +265,35 @@ class ConnectorService: async def get_connector_by_type( self, - user_id: str, connector_type: SearchSourceConnectorType, - search_space_id: int | None = None, + search_space_id: int, ) -> SearchSourceConnector | None: """ - Get a connector by type for a specific user and optionally a search space + Get a connector by type for a specific search space Args: - user_id: The user's ID connector_type: The connector type to retrieve - search_space_id: Optional search space ID to filter by + search_space_id: The search space ID to filter by Returns: Optional[SearchSourceConnector]: The connector if found, None otherwise """ query = select(SearchSourceConnector).filter( - SearchSourceConnector.user_id == user_id, + SearchSourceConnector.search_space_id == search_space_id, SearchSourceConnector.connector_type == connector_type, ) - if search_space_id is not None: - query = query.filter( - SearchSourceConnector.search_space_id == search_space_id - ) - result = await self.session.execute(query) return result.scalars().first() async def search_tavily( - self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20 + self, user_query: str, search_space_id: int, top_k: int = 20 ) -> tuple: """ Search using Tavily API and return both the source information and documents Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID top_k: Maximum number of results to return @@ -319,7 +302,7 @@ class ConnectorService: """ # Get Tavily connector configuration tavily_connector = await self.get_connector_by_type( - user_id, SearchSourceConnectorType.TAVILY_API, search_space_id + SearchSourceConnectorType.TAVILY_API, search_space_id ) if not tavily_connector: @@ -412,7 +395,6 @@ class ConnectorService: async def search_searxng( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, ) -> tuple: @@ -420,7 +402,7 @@ class ConnectorService: Search using a configured SearxNG instance and return both sources and documents. """ searx_connector = await self.get_connector_by_type( - user_id, SearchSourceConnectorType.SEARXNG_API, search_space_id + SearchSourceConnectorType.SEARXNG_API, search_space_id ) if not searx_connector: @@ -598,7 +580,6 @@ class ConnectorService: async def search_baidu( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, ) -> tuple: @@ -610,7 +591,6 @@ class ConnectorService: Args: user_query: User's search query - user_id: User ID search_space_id: Search space ID top_k: Maximum number of results to return @@ -619,7 +599,7 @@ class ConnectorService: """ # Get Baidu connector configuration baidu_connector = await self.get_connector_by_type( - user_id, SearchSourceConnectorType.BAIDU_SEARCH_API, search_space_id + SearchSourceConnectorType.BAIDU_SEARCH_API, search_space_id ) if not baidu_connector: @@ -824,7 +804,6 @@ class ConnectorService: async def search_slack( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -839,7 +818,6 @@ class ConnectorService: slack_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="SLACK_CONNECTOR", ) @@ -847,7 +825,6 @@ class ConnectorService: slack_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="SLACK_CONNECTOR", ) @@ -912,7 +889,6 @@ class ConnectorService: async def search_notion( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -922,7 +898,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return @@ -933,7 +908,6 @@ class ConnectorService: notion_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="NOTION_CONNECTOR", ) @@ -941,7 +915,6 @@ class ConnectorService: notion_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="NOTION_CONNECTOR", ) @@ -1009,7 +982,6 @@ class ConnectorService: async def search_extension( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1019,7 +991,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return @@ -1030,7 +1001,6 @@ class ConnectorService: extension_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="EXTENSION", ) @@ -1038,7 +1008,6 @@ class ConnectorService: extension_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="EXTENSION", ) @@ -1130,7 +1099,6 @@ class ConnectorService: async def search_youtube( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1140,7 +1108,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return @@ -1151,7 +1118,6 @@ class ConnectorService: youtube_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="YOUTUBE_VIDEO", ) @@ -1159,7 +1125,6 @@ class ConnectorService: youtube_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="YOUTUBE_VIDEO", ) @@ -1227,7 +1192,6 @@ class ConnectorService: async def search_github( self, user_query: str, - user_id: int, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1242,7 +1206,6 @@ class ConnectorService: github_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="GITHUB_CONNECTOR", ) @@ -1250,7 +1213,6 @@ class ConnectorService: github_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="GITHUB_CONNECTOR", ) @@ -1302,7 +1264,6 @@ class ConnectorService: async def search_linear( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1312,7 +1273,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return @@ -1323,7 +1283,6 @@ class ConnectorService: linear_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="LINEAR_CONNECTOR", ) @@ -1331,7 +1290,6 @@ class ConnectorService: linear_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="LINEAR_CONNECTOR", ) @@ -1411,7 +1369,6 @@ class ConnectorService: async def search_jira( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1421,7 +1378,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -1433,7 +1389,6 @@ class ConnectorService: jira_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="JIRA_CONNECTOR", ) @@ -1441,7 +1396,6 @@ class ConnectorService: jira_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="JIRA_CONNECTOR", ) @@ -1532,7 +1486,6 @@ class ConnectorService: async def search_google_calendar( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1542,7 +1495,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -1554,7 +1506,6 @@ class ConnectorService: calendar_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="GOOGLE_CALENDAR_CONNECTOR", ) @@ -1562,7 +1513,6 @@ class ConnectorService: calendar_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="GOOGLE_CALENDAR_CONNECTOR", ) @@ -1665,7 +1615,6 @@ class ConnectorService: async def search_airtable( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1675,7 +1624,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -1687,7 +1635,6 @@ class ConnectorService: airtable_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="AIRTABLE_CONNECTOR", ) @@ -1695,7 +1642,6 @@ class ConnectorService: airtable_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="AIRTABLE_CONNECTOR", ) @@ -1753,7 +1699,6 @@ class ConnectorService: async def search_google_gmail( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1763,7 +1708,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -1775,7 +1719,6 @@ class ConnectorService: gmail_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="GOOGLE_GMAIL_CONNECTOR", ) @@ -1783,7 +1726,6 @@ class ConnectorService: gmail_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="GOOGLE_GMAIL_CONNECTOR", ) @@ -1877,7 +1819,6 @@ class ConnectorService: async def search_confluence( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1887,7 +1828,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -1899,7 +1839,6 @@ class ConnectorService: confluence_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="CONFLUENCE_CONNECTOR", ) @@ -1907,7 +1846,6 @@ class ConnectorService: confluence_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="CONFLUENCE_CONNECTOR", ) @@ -1972,7 +1910,6 @@ class ConnectorService: async def search_clickup( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -1982,7 +1919,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -1994,7 +1930,6 @@ class ConnectorService: clickup_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="CLICKUP_CONNECTOR", ) @@ -2002,7 +1937,6 @@ class ConnectorService: clickup_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="CLICKUP_CONNECTOR", ) @@ -2088,7 +2022,6 @@ class ConnectorService: async def search_linkup( self, user_query: str, - user_id: str, search_space_id: int, mode: str = "standard", ) -> tuple: @@ -2097,7 +2030,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID mode: Search depth mode, can be "standard" or "deep" @@ -2106,7 +2038,7 @@ class ConnectorService: """ # Get Linkup connector configuration linkup_connector = await self.get_connector_by_type( - user_id, SearchSourceConnectorType.LINKUP_API, search_space_id + SearchSourceConnectorType.LINKUP_API, search_space_id ) if not linkup_connector: @@ -2211,7 +2143,6 @@ class ConnectorService: async def search_discord( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -2221,7 +2152,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return @@ -2232,7 +2162,6 @@ class ConnectorService: discord_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="DISCORD_CONNECTOR", ) @@ -2240,7 +2169,6 @@ class ConnectorService: discord_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="DISCORD_CONNECTOR", ) @@ -2308,7 +2236,6 @@ class ConnectorService: async def search_luma( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -2318,7 +2245,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -2330,7 +2256,6 @@ class ConnectorService: luma_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="LUMA_CONNECTOR", ) @@ -2338,7 +2263,6 @@ class ConnectorService: luma_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="LUMA_CONNECTOR", ) @@ -2466,7 +2390,6 @@ class ConnectorService: async def search_elasticsearch( self, user_query: str, - user_id: str, search_space_id: int, top_k: int = 20, search_mode: SearchMode = SearchMode.CHUNKS, @@ -2476,7 +2399,6 @@ class ConnectorService: Args: user_query: The user's query - user_id: The user's ID search_space_id: The search space ID to search in top_k: Maximum number of results to return search_mode: Search mode (CHUNKS or DOCUMENTS) @@ -2488,7 +2410,6 @@ class ConnectorService: elasticsearch_chunks = await self.chunk_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="ELASTICSEARCH_CONNECTOR", ) @@ -2496,7 +2417,6 @@ class ConnectorService: elasticsearch_chunks = await self.document_retriever.hybrid_search( query_text=user_query, top_k=top_k, - user_id=user_id, search_space_id=search_space_id, document_type="ELASTICSEARCH_CONNECTOR", ) diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index ea9140f8e..c3270b59e 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from app.config import config -from app.db import LLMConfig, UserSearchSpacePreference +from app.db import LLMConfig, SearchSpace # Configure litellm to automatically drop unsupported parameters litellm.drop_params = True @@ -144,15 +144,16 @@ async def validate_llm_config( return False, error_msg -async def get_user_llm_instance( - session: AsyncSession, user_id: str, search_space_id: int, role: str +async def get_search_space_llm_instance( + session: AsyncSession, search_space_id: int, role: str ) -> ChatLiteLLM | None: """ - Get a ChatLiteLLM instance for a specific user, search space, and role. + Get a ChatLiteLLM instance for a specific search space and role. + + LLM preferences are stored at the search space level and shared by all members. Args: session: Database session - user_id: User ID search_space_id: Search Space ID role: LLM role ('long_context', 'fast', or 'strategic') @@ -160,37 +161,30 @@ async def get_user_llm_instance( ChatLiteLLM instance or None if not found """ try: - # Get user's LLM preferences for this search space + # Get the search space with its LLM preferences result = await session.execute( - select(UserSearchSpacePreference).where( - UserSearchSpacePreference.user_id == user_id, - UserSearchSpacePreference.search_space_id == search_space_id, - ) + select(SearchSpace).where(SearchSpace.id == search_space_id) ) - preference = result.scalars().first() + search_space = result.scalars().first() - if not preference: - logger.error( - f"No LLM preferences found for user {user_id} in search space {search_space_id}" - ) + if not search_space: + logger.error(f"Search space {search_space_id} not found") return None # Get the appropriate LLM config ID based on role llm_config_id = None if role == LLMRole.LONG_CONTEXT: - llm_config_id = preference.long_context_llm_id + llm_config_id = search_space.long_context_llm_id elif role == LLMRole.FAST: - llm_config_id = preference.fast_llm_id + llm_config_id = search_space.fast_llm_id elif role == LLMRole.STRATEGIC: - llm_config_id = preference.strategic_llm_id + llm_config_id = search_space.strategic_llm_id else: logger.error(f"Invalid LLM role: {role}") return None if not llm_config_id: - logger.error( - f"No {role} LLM configured for user {user_id} in search space {search_space_id}" - ) + logger.error(f"No {role} LLM configured for search space {search_space_id}") return None # Check if this is a global config (negative ID) @@ -331,31 +325,63 @@ async def get_user_llm_instance( except Exception as e: logger.error( - f"Error getting LLM instance for user {user_id}, role {role}: {e!s}" + f"Error getting LLM instance for search space {search_space_id}, role {role}: {e!s}" ) return None +async def get_long_context_llm( + session: AsyncSession, search_space_id: int +) -> ChatLiteLLM | None: + """Get the search space's long context LLM instance.""" + return await get_search_space_llm_instance( + session, search_space_id, LLMRole.LONG_CONTEXT + ) + + +async def get_fast_llm( + session: AsyncSession, search_space_id: int +) -> ChatLiteLLM | None: + """Get the search space's fast LLM instance.""" + return await get_search_space_llm_instance(session, search_space_id, LLMRole.FAST) + + +async def get_strategic_llm( + session: AsyncSession, search_space_id: int +) -> ChatLiteLLM | None: + """Get the search space's strategic LLM instance.""" + return await get_search_space_llm_instance( + session, search_space_id, LLMRole.STRATEGIC + ) + + +# Backward-compatible aliases (deprecated - will be removed in future versions) +async def get_user_llm_instance( + session: AsyncSession, user_id: str, search_space_id: int, role: str +) -> ChatLiteLLM | None: + """ + Deprecated: Use get_search_space_llm_instance instead. + LLM preferences are now stored at the search space level, not per-user. + """ + return await get_search_space_llm_instance(session, search_space_id, role) + + async def get_user_long_context_llm( session: AsyncSession, user_id: str, search_space_id: int ) -> ChatLiteLLM | None: - """Get user's long context LLM instance for a specific search space.""" - return await get_user_llm_instance( - session, user_id, search_space_id, LLMRole.LONG_CONTEXT - ) + """Deprecated: Use get_long_context_llm instead.""" + return await get_long_context_llm(session, search_space_id) async def get_user_fast_llm( session: AsyncSession, user_id: str, search_space_id: int ) -> ChatLiteLLM | None: - """Get user's fast LLM instance for a specific search space.""" - return await get_user_llm_instance(session, user_id, search_space_id, LLMRole.FAST) + """Deprecated: Use get_fast_llm instead.""" + return await get_fast_llm(session, search_space_id) async def get_user_strategic_llm( session: AsyncSession, user_id: str, search_space_id: int ) -> ChatLiteLLM | None: - """Get user's strategic LLM instance for a specific search space.""" - return await get_user_llm_instance( - session, user_id, search_space_id, LLMRole.STRATEGIC - ) + """Deprecated: Use get_strategic_llm instead.""" + return await get_strategic_llm(session, search_space_id) diff --git a/surfsense_backend/app/services/query_service.py b/surfsense_backend/app/services/query_service.py index d2759ab27..0521dc942 100644 --- a/surfsense_backend/app/services/query_service.py +++ b/surfsense_backend/app/services/query_service.py @@ -4,7 +4,7 @@ from typing import Any from langchain.schema import AIMessage, HumanMessage, SystemMessage from sqlalchemy.ext.asyncio import AsyncSession -from app.services.llm_service import get_user_strategic_llm +from app.services.llm_service import get_strategic_llm class QueryService: @@ -16,19 +16,17 @@ class QueryService: async def reformulate_query_with_chat_history( user_query: str, session: AsyncSession, - user_id: str, search_space_id: int, chat_history_str: str | None = None, ) -> str: """ - Reformulate the user query using the user's strategic LLM to make it more + Reformulate the user query using the search space's strategic LLM to make it more effective for information retrieval and research purposes. Args: user_query: The original user query - session: Database session for accessing user LLM configs - user_id: User ID to get their specific LLM configuration - search_space_id: Search Space ID to get user's LLM preferences + session: Database session for accessing LLM configs + search_space_id: Search Space ID to get LLM preferences chat_history_str: Optional chat history string Returns: @@ -38,11 +36,11 @@ class QueryService: return user_query try: - # Get the user's strategic LLM instance - llm = await get_user_strategic_llm(session, user_id, search_space_id) + # Get the search space's strategic LLM instance + llm = await get_strategic_llm(session, search_space_id) if not llm: print( - f"Warning: No strategic LLM configured for user {user_id} in search space {search_space_id}. Using original query." + f"Warning: No strategic LLM configured for search space {search_space_id}. Using original query." ) return user_query diff --git a/surfsense_backend/app/utils/check_ownership.py b/surfsense_backend/app/utils/check_ownership.py deleted file mode 100644 index 0bd290ff3..000000000 --- a/surfsense_backend/app/utils/check_ownership.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select - -from app.db import User - - -# Helper function to check user ownership -async def check_ownership(session: AsyncSession, model, item_id: int, user: User): - item = await session.execute( - select(model).filter(model.id == item_id, model.user_id == user.id) - ) - item = item.scalars().first() - if not item: - raise HTTPException( - status_code=404, - detail="Item not found or you don't have permission to access it", - ) - return item diff --git a/surfsense_backend/app/utils/rbac.py b/surfsense_backend/app/utils/rbac.py new file mode 100644 index 000000000..6cb180d80 --- /dev/null +++ b/surfsense_backend/app/utils/rbac.py @@ -0,0 +1,274 @@ +""" +RBAC (Role-Based Access Control) utility functions. +Provides helpers for checking user permissions in search spaces. +""" + +import secrets +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload + +from app.db import ( + Permission, + SearchSpace, + SearchSpaceMembership, + SearchSpaceRole, + User, + has_permission, +) + + +async def get_user_membership( + session: AsyncSession, + user_id: UUID, + search_space_id: int, +) -> SearchSpaceMembership | None: + """ + Get the user's membership in a search space. + + Args: + session: Database session + user_id: User UUID + search_space_id: Search space ID + + Returns: + SearchSpaceMembership if found, None otherwise + """ + result = await session.execute( + select(SearchSpaceMembership) + .options(selectinload(SearchSpaceMembership.role)) + .filter( + SearchSpaceMembership.user_id == user_id, + SearchSpaceMembership.search_space_id == search_space_id, + ) + ) + return result.scalars().first() + + +async def get_user_permissions( + session: AsyncSession, + user_id: UUID, + search_space_id: int, +) -> list[str]: + """ + Get the user's permissions in a search space. + + Args: + session: Database session + user_id: User UUID + search_space_id: Search space ID + + Returns: + List of permission strings + """ + membership = await get_user_membership(session, user_id, search_space_id) + + if not membership: + return [] + + # Owners always have full access + if membership.is_owner: + return [Permission.FULL_ACCESS.value] + + # Get permissions from role + if membership.role: + return membership.role.permissions or [] + + return [] + + +async def check_permission( + session: AsyncSession, + user: User, + search_space_id: int, + required_permission: str, + error_message: str = "You don't have permission to perform this action", +) -> SearchSpaceMembership: + """ + Check if a user has a specific permission in a search space. + Raises HTTPException if permission is denied. + + Args: + session: Database session + user: User object + search_space_id: Search space ID + required_permission: Permission string to check + error_message: Custom error message for permission denied + + Returns: + SearchSpaceMembership if permission granted + + Raises: + HTTPException: If user doesn't have access or permission + """ + membership = await get_user_membership(session, user.id, search_space_id) + + if not membership: + raise HTTPException( + status_code=403, + detail="You don't have access to this search space", + ) + + # Get user's permissions + if membership.is_owner: + permissions = [Permission.FULL_ACCESS.value] + elif membership.role: + permissions = membership.role.permissions or [] + else: + permissions = [] + + if not has_permission(permissions, required_permission): + raise HTTPException(status_code=403, detail=error_message) + + return membership + + +async def check_search_space_access( + session: AsyncSession, + user: User, + search_space_id: int, +) -> SearchSpaceMembership: + """ + Check if a user has any access to a search space. + This is used for basic access control (user is a member). + + Args: + session: Database session + user: User object + search_space_id: Search space ID + + Returns: + SearchSpaceMembership if user has access + + Raises: + HTTPException: If user doesn't have access + """ + membership = await get_user_membership(session, user.id, search_space_id) + + if not membership: + raise HTTPException( + status_code=403, + detail="You don't have access to this search space", + ) + + return membership + + +async def is_search_space_owner( + session: AsyncSession, + user_id: UUID, + search_space_id: int, +) -> bool: + """ + Check if a user is the owner of a search space. + + Args: + session: Database session + user_id: User UUID + search_space_id: Search space ID + + Returns: + True if user is the owner, False otherwise + """ + membership = await get_user_membership(session, user_id, search_space_id) + return membership is not None and membership.is_owner + + +async def get_search_space_with_access_check( + session: AsyncSession, + user: User, + search_space_id: int, + required_permission: str | None = None, +) -> tuple[SearchSpace, SearchSpaceMembership]: + """ + Get a search space with access and optional permission check. + + Args: + session: Database session + user: User object + search_space_id: Search space ID + required_permission: Optional permission to check + + Returns: + Tuple of (SearchSpace, SearchSpaceMembership) + + Raises: + HTTPException: If search space not found or user lacks access/permission + """ + # Get the search space + result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == search_space_id) + ) + search_space = result.scalars().first() + + if not search_space: + raise HTTPException(status_code=404, detail="Search space not found") + + # Check access + if required_permission: + membership = await check_permission( + session, user, search_space_id, required_permission + ) + else: + membership = await check_search_space_access(session, user, search_space_id) + + return search_space, membership + + +def generate_invite_code() -> str: + """ + Generate a unique invite code for search space invites. + + Returns: + A 32-character URL-safe invite code + """ + return secrets.token_urlsafe(24) + + +async def get_default_role( + session: AsyncSession, + search_space_id: int, +) -> SearchSpaceRole | None: + """ + Get the default role for a search space (used when accepting invites without a specific role). + + Args: + session: Database session + search_space_id: Search space ID + + Returns: + Default SearchSpaceRole or None + """ + result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.search_space_id == search_space_id, + SearchSpaceRole.is_default == True, # noqa: E712 + ) + ) + return result.scalars().first() + + +async def get_owner_role( + session: AsyncSession, + search_space_id: int, +) -> SearchSpaceRole | None: + """ + Get the Owner role for a search space. + + Args: + session: Database session + search_space_id: Search space ID + + Returns: + Owner SearchSpaceRole or None + """ + result = await session.execute( + select(SearchSpaceRole).filter( + SearchSpaceRole.search_space_id == search_space_id, + SearchSpaceRole.name == "Owner", + ) + ) + return result.scalars().first() diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index 4ec8046a4..105c21e26 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -18,6 +18,7 @@ import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/com import { Separator } from "@/components/ui/separator"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { useLLMPreferences } from "@/hooks/use-llm-configs"; +import { useUserAccess } from "@/hooks/use-rbac"; import { cn } from "@/lib/utils"; export function DashboardClientLayout({ @@ -60,11 +61,15 @@ export function DashboardClientLayout({ }, [activeChatId, isChatPannelOpen]); const { loading, error, isOnboardingComplete } = useLLMPreferences(searchSpaceIdNum); + const { access, loading: accessLoading } = useUserAccess(searchSpaceIdNum); const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false); // Skip onboarding check if we're already on the onboarding page const isOnboardingPage = pathname?.includes("/onboard"); + // Only owners should see onboarding - invited members use existing config + const isOwner = access?.is_owner ?? false; + // Translate navigation items const tNavMenu = useTranslations("nav_menu"); const translatedNavMain = useMemo(() => { @@ -102,11 +107,13 @@ export function DashboardClientLayout({ return; } - // Only check once after preferences have loaded - if (!loading && !hasCheckedOnboarding) { + // Wait for both preferences and access data to load + if (!loading && !accessLoading && !hasCheckedOnboarding) { const onboardingComplete = isOnboardingComplete(); - if (!onboardingComplete) { + // Only redirect to onboarding if user is the owner and onboarding is not complete + // Invited members (non-owners) should skip onboarding and use existing config + if (!onboardingComplete && isOwner) { router.push(`/dashboard/${searchSpaceId}/onboard`); } @@ -114,8 +121,10 @@ export function DashboardClientLayout({ } }, [ loading, + accessLoading, isOnboardingComplete, isOnboardingPage, + isOwner, router, searchSpaceId, hasCheckedOnboarding, @@ -145,7 +154,7 @@ export function DashboardClientLayout({ }, [chat_id, search_space_id]); // Show loading screen while checking onboarding status (only on first load) - if (!hasCheckedOnboarding && loading && !isOnboardingPage) { + if (!hasCheckedOnboarding && (loading || accessLoading) && !isOnboardingPage) { return (
Loading team data...
++ Manage members, roles, and invite links for your search space +
++ {members.filter((m) => m.is_owner).length} owner + {members.filter((m) => m.is_owner).length !== 1 ? "s" : ""} +
++ {roles.filter((r) => r.is_system_role).length} system roles +
++ {invites.reduce((acc, i) => acc + i.uses_count, 0)} total uses +
++ Create an invite link to allow others to join your search space with specific roles. +
+{invite.name || "Unnamed Invite"}
+ {isExpired && ( +{space.description}
Loading invite details...
+{acceptedData.search_space_name}
+Search Space
+{acceptedData.role_name}
+Your Role
++ The invite may have expired, reached its maximum uses, or been revoked by the + owner. +
+{inviteInfo?.search_space_name}
+Search Space
+{inviteInfo.role_name}
+Role you'll receive
+{inviteInfo?.search_space_name}
+Search Space
+{inviteInfo.role_name}
+Role you'll receive
+