"""Initial schema setup

Revision ID: 0
Revises: None

Creates all tables from SQLAlchemy models. Idempotent - safe to run on existing databases.
"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "0"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Create the pgvector extension, the model tables and the search indexes.

    Steps, in order:
      1. Enable the ``vector`` extension.
      2. Create every table registered on ``Base.metadata``.
      3. Create the HNSW (vector) and GIN (full-text) indexes on
         ``documents`` and ``chunks``.

    The raw DDL statements all use ``IF NOT EXISTS``, so re-running them
    against a database that already has these objects is harmless.
    """
    # Imported inside the function rather than at module level
    # (presumably so loading this revision file does not import the app).
    from app.db import Base

    bind = op.get_bind()

    # The extension must exist before tables with vector columns are created.
    op.execute(sa.text("CREATE EXTENSION IF NOT EXISTS vector"))
    Base.metadata.create_all(bind=bind)

    # NOTE(review): the "chucks_*" spelling is kept verbatim; the index names
    # are part of the schema and must not change here.
    index_statements = (
        "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)",
        "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))",
        "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)",
        "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))",
    )
    for statement in index_statements:
        op.execute(sa.text(statement))


def downgrade() -> None:
    """No-op: this revision defines no downgrade steps."""
    return None
def enum_exists(enum_name: str) -> bool:
    """Return whether an enum type named *enum_name* exists in the database.

    The query runs on the connection bound to the current Alembic
    operation context (``op.get_bind()``).

    Args:
        enum_name: Type name compared against ``pg_type.typname``.

    Returns:
        ``True`` if a matching enum type is found, ``False`` otherwise.
    """
    conn = op.get_bind()
    result = conn.execute(
        sa.text(
            # typtype = 'e' limits the match to enum types, so a non-enum
            # type that happens to share the name does not count.
            "SELECT EXISTS ("
            "SELECT 1 FROM pg_type "
            "WHERE typname = :enum_name AND typtype = 'e'"
            ")"
        ),
        {"enum_name": enum_name},
    )
    # scalar() is not statically a bool; coerce to honour the annotation.
    return bool(result.scalar())


def table_exists(table_name: str) -> bool:
    """Return whether *table_name* exists in a schema on the search path.

    The query runs on the connection bound to the current Alembic
    operation context (``op.get_bind()``).

    Args:
        table_name: Unqualified table name compared against
            ``information_schema.tables.table_name``.

    Returns:
        ``True`` if ``information_schema.tables`` lists the name in one of
        the schemas returned by ``current_schemas(false)``, else ``False``.
    """
    conn = op.get_bind()
    result = conn.execute(
        sa.text(
            # Restrict to search-path schemas: the migrations that use this
            # guard go on to reference the table unqualified, so a same-named
            # table in an unrelated schema must not satisfy the check.
            "SELECT EXISTS ("
            "SELECT 1 FROM information_schema.tables "
            "WHERE table_name = :table_name "
            "AND table_schema = ANY (current_schemas(false))"
            ")"
        ),
        {"table_name": table_name},
    )
    # scalar() is not statically a bool; coerce to honour the annotation.
    return bool(result.scalar())
def enum_exists(enum_name: str) -> bool:
    """Return whether an enum type named *enum_name* exists in the database.

    The query runs on the connection bound to the current Alembic
    operation context (``op.get_bind()``).

    Args:
        enum_name: Type name compared against ``pg_type.typname``.

    Returns:
        ``True`` if a matching enum type is found, ``False`` otherwise.
    """
    conn = op.get_bind()
    result = conn.execute(
        sa.text(
            # typtype = 'e' limits the match to enum types, so a non-enum
            # type that happens to share the name does not count.
            "SELECT EXISTS ("
            "SELECT 1 FROM pg_type "
            "WHERE typname = :enum_name AND typtype = 'e'"
            ")"
        ),
        {"enum_name": enum_name},
    )
    # scalar() is not statically a bool; coerce to honour the annotation.
    return bool(result.scalar())
def upgrade() -> None:
    """Add the staleness-tracking columns, tolerating re-runs and missing tables."""
    # chats.state_version (default 1); the table is absent on fresh databases.
    if table_exists("chats"):
        add_state_version = """
            ALTER TABLE chats
            ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL
        """
        op.execute(add_state_version)

    # Both remaining columns live on podcasts, so one guard covers them.
    if not table_exists("podcasts"):
        return

    podcast_statements = (
        # podcasts.chat_state_version
        """
            ALTER TABLE podcasts
            ADD COLUMN IF NOT EXISTS chat_state_version BIGINT
        """,
        # podcasts.chat_id
        """
            ALTER TABLE podcasts
            ADD COLUMN IF NOT EXISTS chat_id INTEGER
        """,
    )
    for statement in podcast_statements:
        op.execute(statement)
def downgrade() -> None:
    """Bring back an empty legacy ``chats`` table; its former rows are not restored."""
    # Nothing to do when the table is already present.
    if table_exists("chats"):
        print("[Migration 49 Downgrade] Chats table already exists, skipping")
        return

    # The enum type is only created when it is missing.
    chattype_present = enum_exists("chattype")
    if not chattype_present:
        op.execute(
            sa.text("""
            CREATE TYPE chattype AS ENUM ('QNA')
        """)
        )

    # Raw SQL is used for the table so SQLAlchemy does not try to create the enum.
    ddl_statements = (
        """
            CREATE TABLE chats (
                id SERIAL PRIMARY KEY,
                type chattype NOT NULL,
                title VARCHAR NOT NULL,
                initial_connectors VARCHAR[],
                messages JSON NOT NULL,
                state_version BIGINT NOT NULL DEFAULT 1,
                search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
            )
        """,
        "CREATE INDEX ix_chats_id ON chats (id)",
        "CREATE INDEX ix_chats_title ON chats (title)",
    )
    for ddl in ddl_statements:
        op.execute(sa.text(ddl))

    print("[Migration 49 Downgrade] Chats table recreated (data not restored)")
exists and target doesn't already exist) op.execute( """ DO $$ @@ -108,6 +114,9 @@ def downgrade(): IF EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id' + ) AND NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id' ) THEN ALTER TABLE searchspaces RENAME COLUMN agent_llm_id TO fast_llm_id; END IF; @@ -122,6 +131,9 @@ def downgrade(): IF EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id' + ) AND NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id' ) THEN ALTER TABLE searchspaces RENAME COLUMN document_summary_llm_id TO long_context_llm_id; END IF; diff --git a/surfsense_backend/alembic/versions/55_rename_google_drive_connector_to_file.py b/surfsense_backend/alembic/versions/55_rename_google_drive_connector_to_file.py index 9ce57d95f..baaf1991f 100644 --- a/surfsense_backend/alembic/versions/55_rename_google_drive_connector_to_file.py +++ b/surfsense_backend/alembic/versions/55_rename_google_drive_connector_to_file.py @@ -60,14 +60,28 @@ def downgrade() -> None: connection = op.get_bind() - connection.execute( + # Only update if the target enum value exists (it won't on fresh databases) + result = connection.execute( text( """ - UPDATE documents - SET document_type = 'GOOGLE_DRIVE_CONNECTOR' - WHERE document_type = 'GOOGLE_DRIVE_FILE'; + SELECT EXISTS ( + SELECT 1 FROM pg_type t + JOIN pg_enum e ON t.oid = e.enumtypid + WHERE t.typname = 'documenttype' AND e.enumlabel = 'GOOGLE_DRIVE_CONNECTOR' + ); """ ) ) + enum_exists = result.scalar() - connection.commit() + if enum_exists: + connection.execute( + text( + """ + UPDATE documents + SET document_type = 'GOOGLE_DRIVE_CONNECTOR' + WHERE document_type = 'GOOGLE_DRIVE_FILE'; + """ + ) + ) + connection.commit() diff --git 
def upgrade() -> None:
    """Drop the 200-character limit on ``title`` for each table that exists."""
    # chats may be missing on fresh databases (it is removed in migration 49);
    # documents and podcasts get the same existence guard.
    for table in ("chats", "documents", "podcasts"):
        if not table_exists(table):
            continue
        op.alter_column(
            table,
            "title",
            existing_type=sa.String(200),
            type_=sa.String(),
            existing_nullable=False,
        )
"chats", + "title", + existing_type=sa.String(), + type_=sa.String(200), + existing_nullable=False, + ) # Revert Document table - op.alter_column( - "documents", - "title", - existing_type=sa.String(), - type_=sa.String(200), - existing_nullable=False, - ) + if table_exists("documents"): + op.alter_column( + "documents", + "title", + existing_type=sa.String(), + type_=sa.String(200), + existing_nullable=False, + ) # Revert Podcast table - op.alter_column( - "podcasts", - "title", - existing_type=sa.String(), - type_=sa.String(200), - existing_nullable=False, - ) + if table_exists("podcasts"): + op.alter_column( + "podcasts", + "title", + existing_type=sa.String(), + type_=sa.String(200), + existing_nullable=False, + ) diff --git a/surfsense_backend/alembic/versions/62_add_mcp_connector_type.py b/surfsense_backend/alembic/versions/62_add_mcp_connector_type.py new file mode 100644 index 000000000..5c5ccf106 --- /dev/null +++ b/surfsense_backend/alembic/versions/62_add_mcp_connector_type.py @@ -0,0 +1,38 @@ +"""Add MCP connector type + +Revision ID: 62 +Revises: 61 +Create Date: 2026-01-09 15:19:51.827647 + +""" + +from collections.abc import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "62" +down_revision: str | None = "61" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Add MCP_CONNECTOR to SearchSourceConnectorType enum.""" + # Add new enum value using raw SQL + op.execute( + """ + ALTER TYPE searchsourceconnectortype ADD VALUE IF NOT EXISTS 'MCP_CONNECTOR'; + """ + ) + + +def downgrade() -> None: + """Remove MCP_CONNECTOR from SearchSourceConnectorType enum.""" + # Note: PostgreSQL does not support removing enum values directly. + # To downgrade, you would need to: + # 1. Create a new enum without MCP_CONNECTOR + # 2. Alter the column to use the new enum + # 3. Drop the old enum + # This is left as a manual operation if needed. 
def upgrade() -> None:
    """Replace the per-type unique constraint with one that also includes ``name``."""
    bind = op.get_bind()
    connectors_table = "search_source_connectors"

    # Drop the old constraint only when it is actually there.
    legacy_lookup = text("""
        SELECT 1 FROM information_schema.table_constraints
        WHERE table_name='search_source_connectors'
        AND constraint_type='UNIQUE'
        AND constraint_name='uq_searchspace_user_connector_type'
    """)
    if bind.execute(legacy_lookup).scalar():
        op.drop_constraint(
            "uq_searchspace_user_connector_type",
            connectors_table,
            type_="unique",
        )

    # Create the new constraint unless a previous run already did.
    replacement_lookup = text("""
        SELECT 1 FROM information_schema.table_constraints
        WHERE table_name='search_source_connectors'
        AND constraint_type='UNIQUE'
        AND constraint_name='uq_searchspace_user_connector_type_name'
    """)
    if bind.execute(replacement_lookup).scalar():
        return

    op.create_unique_constraint(
        "uq_searchspace_user_connector_type_name",
        connectors_table,
        ["search_space_id", "user_id", "connector_type", "name"],
    )
drop it + new_constraint_exists = connection.execute( + text(""" + SELECT 1 FROM information_schema.table_constraints + WHERE table_name='search_source_connectors' + AND constraint_type='UNIQUE' + AND constraint_name='uq_searchspace_user_connector_type_name' + """) + ).scalar() + + if new_constraint_exists: + op.drop_constraint( + "uq_searchspace_user_connector_type_name", + "search_source_connectors", + type_="unique", + ) + + # Check if old constraint already exists before creating it + old_constraint_exists = connection.execute( + text(""" + SELECT 1 FROM information_schema.table_constraints + WHERE table_name='search_source_connectors' + AND constraint_type='UNIQUE' + AND constraint_name='uq_searchspace_user_connector_type' + """) + ).scalar() + + if not old_constraint_exists: + op.create_unique_constraint( + "uq_searchspace_user_connector_type", + "search_source_connectors", + ["search_space_id", "user_id", "connector_type"], + ) diff --git a/surfsense_backend/alembic/versions/64_add_user_profile_columns.py b/surfsense_backend/alembic/versions/64_add_user_profile_columns.py new file mode 100644 index 000000000..db45982d8 --- /dev/null +++ b/surfsense_backend/alembic/versions/64_add_user_profile_columns.py @@ -0,0 +1,72 @@ +"""Add display_name and avatar_url columns to user table + +This migration adds: +- display_name column for user's full name from OAuth +- avatar_url column for user's profile picture URL from OAuth + +Revision ID: 64 +Revises: 63 +""" + +from collections.abc import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. 
def upgrade() -> None:
    """Add display_name and avatar_url columns to user table.

    Both columns are nullable VARCHARs, so existing rows need no backfill.

    ``ADD COLUMN IF NOT EXISTS`` replaces the earlier
    ``information_schema.columns`` pre-check. That check matched a ``user``
    table in any schema, so a same-named table elsewhere could make the
    migration skip the column silently. Letting PostgreSQL test the column on
    the exact table being altered removes that mismatch and mirrors the
    ``DROP COLUMN IF EXISTS`` form used by the downgrade.
    """
    # "user" is quoted because it is a reserved word in PostgreSQL.
    op.execute(
        """
        ALTER TABLE "user"
        ADD COLUMN IF NOT EXISTS display_name VARCHAR;
        """
    )
    op.execute(
        """
        ALTER TABLE "user"
        ADD COLUMN IF NOT EXISTS avatar_url VARCHAR;
        """
    )
def downgrade() -> None:
    """Remove author_id column from new_chat_messages table.

    The index is dropped first, then the column. Each command is sent in its
    own ``op.execute`` call: a single string holding several commands is
    rejected by drivers that prepare statements before running them
    (NOTE(review): asyncpg behaves this way - confirm which driver the
    Alembic environment uses).
    """
    op.execute(
        """
        DROP INDEX IF EXISTS ix_new_chat_messages_author_id;
        """
    )
    op.execute(
        """
        ALTER TABLE new_chat_messages
        DROP COLUMN IF EXISTS author_id;
        """
    )
b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -0,0 +1,203 @@ +"""MCP Client Wrapper. + +This module provides a client for communicating with MCP servers via stdio transport. +It handles server lifecycle management, tool discovery, and tool execution. +""" + +import logging +import os +from contextlib import asynccontextmanager +from typing import Any + +from mcp import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + +logger = logging.getLogger(__name__) + + +class MCPClient: + """Client for communicating with an MCP server.""" + + def __init__( + self, command: str, args: list[str], env: dict[str, str] | None = None + ): + """Initialize MCP client. + + Args: + command: Command to spawn the MCP server (e.g., "uvx", "node") + args: Arguments for the command (e.g., ["mcp-server-git"]) + env: Optional environment variables for the server process + + """ + self.command = command + self.args = args + self.env = env or {} + self.session: ClientSession | None = None + + @asynccontextmanager + async def connect(self): + """Connect to the MCP server and manage its lifecycle. 
+ + Yields: + ClientSession: Active MCP session for making requests + + """ + try: + # Merge env vars with current environment + server_env = os.environ.copy() + server_env.update(self.env) + + # Create server parameters with env + server_params = StdioServerParameters( + command=self.command, args=self.args, env=server_env + ) + + # Spawn server process and create session + # Note: Cannot combine these context managers because ClientSession + # needs the read/write streams from stdio_client + async with stdio_client(server=server_params) as (read, write): # noqa: SIM117 + async with ClientSession(read, write) as session: + # Initialize the connection + await session.initialize() + self.session = session + logger.info( + "Connected to MCP server: %s %s", + self.command, + " ".join(self.args), + ) + yield session + + except Exception as e: + logger.error("Failed to connect to MCP server: %s", e, exc_info=True) + raise + finally: + self.session = None + logger.info("Disconnected from MCP server: %s", self.command) + + async def list_tools(self) -> list[dict[str, Any]]: + """List all tools available from the MCP server. + + Returns: + List of tool definitions with name, description, and input schema + + Raises: + RuntimeError: If not connected to server + + """ + if not self.session: + raise RuntimeError( + "Not connected to MCP server. 
Use 'async with client.connect():'" + ) + + try: + # Call tools/list RPC method + response = await self.session.list_tools() + + tools = [] + for tool in response.tools: + tools.append( + { + "name": tool.name, + "description": tool.description or "", + "input_schema": tool.inputSchema + if hasattr(tool, "inputSchema") + else {}, + } + ) + + logger.info("Listed %d tools from MCP server", len(tools)) + return tools + + except Exception as e: + logger.error("Failed to list tools from MCP server: %s", e, exc_info=True) + raise + + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + """Call a tool on the MCP server. + + Args: + tool_name: Name of the tool to call + arguments: Arguments to pass to the tool + + Returns: + Tool execution result + + Raises: + RuntimeError: If not connected to server + + """ + if not self.session: + raise RuntimeError( + "Not connected to MCP server. Use 'async with client.connect():'" + ) + + try: + logger.info( + "Calling MCP tool '%s' with arguments: %s", tool_name, arguments + ) + + # Call tools/call RPC method + response = await self.session.call_tool(tool_name, arguments=arguments) + + # Extract content from response + result = [] + for content in response.content: + if hasattr(content, "text"): + result.append(content.text) + elif hasattr(content, "data"): + result.append(str(content.data)) + else: + result.append(str(content)) + + result_str = "\n".join(result) if result else "" + logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200]) + return result_str + + except RuntimeError as e: + # Handle validation errors from MCP server responses + # Some MCP servers (like server-memory) return extra fields not in their schema + if "Invalid structured content" in str(e): + logger.warning( + "MCP server returned data not matching its schema, but continuing: %s", + e, + ) + # Try to extract result from error message or return a success message + return "Operation completed (server returned unexpected 
format)" + raise + except (ValueError, TypeError, AttributeError, KeyError) as e: + logger.error( + "Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True + ) + return f"Error calling tool: {e!s}" + + +async def test_mcp_connection( + command: str, args: list[str], env: dict[str, str] | None = None +) -> dict[str, Any]: + """Test connection to an MCP server and fetch available tools. + + Args: + command: Command to spawn the MCP server + args: Arguments for the command + env: Optional environment variables + + Returns: + Dict with connection status and available tools + + """ + client = MCPClient(command, args, env) + + try: + async with client.connect(): + tools = await client.list_tools() + return { + "status": "success", + "message": f"Connected successfully. Found {len(tools)} tools.", + "tools": tools, + } + except (RuntimeError, ConnectionError, TimeoutError, OSError) as e: + return { + "status": "error", + "message": f"Failed to connect: {e!s}", + "tools": [], + } diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py new file mode 100644 index 000000000..0e5f1b993 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -0,0 +1,198 @@ +"""MCP Tool Factory. + +This module creates LangChain tools from MCP servers using the Model Context Protocol. +Tools are dynamically discovered from MCP servers - no manual configuration needed. + +This implements real MCP protocol support similar to Cursor's implementation. 
def _create_dynamic_input_model_from_schema(
    tool_name: str,
    input_schema: dict[str, Any],
) -> type[BaseModel]:
    """Create a Pydantic model from MCP tool's JSON schema.

    Every parameter is typed as ``Any`` so the structure of complex schemas is
    passed through untouched and the MCP server can do its own validation.
    Required parameters become required fields; all others default to None.

    Args:
        tool_name: Name of the tool (used for model class name)
        input_schema: JSON schema from MCP server. A missing or null
            ``properties`` / ``required`` entry is treated as empty.

    Returns:
        Pydantic model class for tool input validation

    """
    # Imported once per call; previously this ran on every loop iteration,
    # together with a redundant re-import of ``Any`` under an alias.
    from pydantic import Field

    properties = input_schema.get("properties") or {}
    required_fields = set(input_schema.get("required") or [])

    # Build Pydantic field definitions
    field_definitions: dict[str, Any] = {}
    for param_name, param_schema in properties.items():
        # JSON Schema permits boolean sub-schemas (``true``/``false``); only
        # object schemas can carry a description.
        if isinstance(param_schema, dict):
            param_description = param_schema.get("description", "")
        else:
            param_description = ""

        if param_name in required_fields:
            field_definitions[param_name] = (
                Any,
                Field(..., description=param_description),
            )
        else:
            field_definitions[param_name] = (
                Any | None,
                Field(None, description=param_description),
            )

    # Create dynamic model
    model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
    return create_model(model_name, **field_definitions)
async def load_mcp_tools(
    session: AsyncSession,
    search_space_id: int,
) -> list[StructuredTool]:
    """Build LangChain tools for every MCP connector of a search space.

    Each connector's server is started once to discover its tools. A failure
    on one connector or one tool definition is logged and skipped; a failure
    outside the per-connector handling yields an empty list.

    Args:
        session: Database session
        search_space_id: User's search space ID

    Returns:
        List of LangChain StructuredTool instances

    """
    try:
        # All MCP connectors registered for this search space
        connector_query = select(SearchSourceConnector).filter(
            SearchSourceConnector.connector_type
            == SearchSourceConnectorType.MCP_CONNECTOR,
            SearchSourceConnector.search_space_id == search_space_id,
        )
        rows = await session.execute(connector_query)

        loaded: list[StructuredTool] = []
        for connector in rows.scalars():
            try:
                # Server launch settings live under config["server_config"]
                server_config = (connector.config or {}).get("server_config", {})

                command = server_config.get("command")
                args = server_config.get("args", [])
                env = server_config.get("env", {})

                if not command:
                    logger.warning(
                        f"MCP connector {connector.id} missing command, skipping"
                    )
                    continue

                client = MCPClient(command, args, env)

                # Discovery only needs the connection for the listing call
                async with client.connect():
                    definitions = await client.list_tools()

                logger.info(
                    f"Discovered {len(definitions)} tools from MCP server "
                    f"'{command}' (connector {connector.id})"
                )

                # Wrap each discovered definition as a LangChain tool
                for tool_def in definitions:
                    try:
                        loaded.append(
                            await _create_mcp_tool_from_definition(tool_def, client)
                        )
                    except Exception as e:
                        logger.exception(
                            f"Failed to create tool '{tool_def.get('name')}' "
                            f"from connector {connector.id}: {e!s}",
                        )

            except Exception as e:
                logger.exception(
                    f"Failed to load tools from MCP connector {connector.id}: {e!s}",
                )

        logger.info(f"Loaded {len(loaded)} MCP tools for search space {search_space_id}")
        return loaded

    except Exception as e:
        logger.exception(f"Failed to load MCP tools: {e!s}")
        return []
a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -1,5 +1,4 @@ -""" -Tools registry for SurfSense deep agent. +"""Tools registry for SurfSense deep agent. This module provides a registry pattern for managing tools in the SurfSense agent. It makes it easy for OSS contributors to add new tools by: @@ -37,6 +36,7 @@ Example of adding a new tool: ), """ +import logging from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -46,6 +46,7 @@ from langchain_core.tools import BaseTool from .display_image import create_display_image_tool from .knowledge_base import create_search_knowledge_base_tool from .link_preview import create_link_preview_tool +from .mcp_tool import load_mcp_tools from .podcast import create_generate_podcast_tool from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs import create_search_surfsense_docs_tool @@ -57,8 +58,7 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool @dataclass class ToolDefinition: - """ - Definition of a tool that can be added to the agent. + """Definition of a tool that can be added to the agent. Attributes: name: Unique identifier for the tool @@ -66,6 +66,7 @@ class ToolDefinition: factory: Callable that creates the tool. Receives a dict of dependencies. requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session") enabled_by_default: Whether the tool is enabled when no explicit config is provided + """ name: str @@ -178,8 +179,7 @@ def build_tools( disabled_tools: list[str] | None = None, additional_tools: list[BaseTool] | None = None, ) -> list[BaseTool]: - """ - Build the list of tools for the agent. + """Build the list of tools for the agent. 
Args: dependencies: Dict containing all possible dependencies: @@ -206,6 +206,7 @@ def build_tools( # Add custom tools tools = build_tools(deps, additional_tools=[my_custom_tool]) + """ # Determine which tools to enable if enabled_tools is not None: @@ -226,8 +227,9 @@ def build_tools( # Check that all required dependencies are provided missing_deps = [dep for dep in tool_def.requires if dep not in dependencies] if missing_deps: + msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}" raise ValueError( - f"Tool '{tool_def.name}' requires dependencies: {missing_deps}" + msg, ) # Create the tool @@ -239,3 +241,62 @@ def build_tools( tools.extend(additional_tools) return tools + + +async def build_tools_async( + dependencies: dict[str, Any], + enabled_tools: list[str] | None = None, + disabled_tools: list[str] | None = None, + additional_tools: list[BaseTool] | None = None, + include_mcp_tools: bool = True, +) -> list[BaseTool]: + """Async version of build_tools that also loads MCP tools from database. + + Design Note: + This function exists because MCP tools require database queries to load user configs, + while built-in tools are created synchronously from static code. + + Alternative: We could make build_tools() itself async and always query the database, + but that would force async everywhere even when only using built-in tools. The current + design keeps the simple case (static tools only) synchronous while supporting dynamic + database-loaded tools through this async wrapper. + + Args: + dependencies: Dict containing all possible dependencies + enabled_tools: Explicit list of tool names to enable. If None, uses defaults. + disabled_tools: List of tool names to disable (applied after enabled_tools). + additional_tools: Extra tools to add (e.g., custom tools not in registry). + include_mcp_tools: Whether to load user's MCP tools from database. + + Returns: + List of configured tool instances ready for the agent, including MCP tools. 
+ + """ + # Build standard tools + tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) + + # Load MCP tools if requested and dependencies are available + if ( + include_mcp_tools + and "db_session" in dependencies + and "search_space_id" in dependencies + ): + try: + mcp_tools = await load_mcp_tools( + dependencies["db_session"], + dependencies["search_space_id"], + ) + tools.extend(mcp_tools) + logging.info( + f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}", + ) + except Exception as e: + # Log error but don't fail - just continue without MCP tools + logging.exception(f"Failed to load MCP tools: {e!s}") + + # Log all tools being returned to agent + logging.info( + f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}", + ) + + return tools diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 73727a9ef..fd2100400 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -80,6 +80,7 @@ class SearchSourceConnectorType(str, Enum): WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR" BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR" CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR" + MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools class LiteLLMProvider(str, Enum): @@ -412,8 +413,17 @@ class NewChatMessage(BaseModel, TimestampMixin): index=True, ) - # Relationship + # Track who sent this message (for shared chats) + author_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + + # Relationships thread = relationship("NewChatThread", back_populates="messages") + author = relationship("User") class Document(BaseModel, TimestampMixin): @@ -605,7 +615,8 @@ class SearchSourceConnector(BaseModel, TimestampMixin): "search_space_id", "user_id", "connector_type", - name="uq_searchspace_user_connector_type", + "name", + name="uq_searchspace_user_connector_type_name", ), ) @@ -874,6 +885,10 @@ 
if config.AUTH_TYPE == "GOOGLE": ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") + # User profile from OAuth + display_name = Column(String, nullable=True) + avatar_url = Column(String, nullable=True) + else: class User(SQLAlchemyBaseUserTableUUID, Base): @@ -907,6 +922,10 @@ else: ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") + # User profile (can be set manually for non-OAuth users) + display_name = Column(String, nullable=True) + avatar_url = Column(String, nullable=True) + engine = create_async_engine(DATABASE_URL) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index fb5808307..e4dc5714a 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -411,11 +411,9 @@ async def get_thread_messages( Requires CHATS_READ permission. """ try: - # Get thread with messages + # Get thread first result = await session.execute( - select(NewChatThread) - .options(selectinload(NewChatThread.messages)) - .filter(NewChatThread.id == thread_id) + select(NewChatThread).filter(NewChatThread.id == thread_id) ) thread = result.scalars().first() @@ -434,6 +432,15 @@ async def get_thread_messages( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + # Get messages with their authors loaded + messages_result = await session.execute( + select(NewChatMessage) + .options(selectinload(NewChatMessage.author)) + .filter(NewChatMessage.thread_id == thread_id) + .order_by(NewChatMessage.created_at) + ) + db_messages = messages_result.scalars().all() + # Return messages in the format expected by assistant-ui messages = [ NewChatMessageRead( @@ -442,8 +449,11 @@ async def get_thread_messages( role=msg.role, content=msg.content, created_at=msg.created_at, + author_id=msg.author_id, + 
author_display_name=msg.author.display_name if msg.author else None, + author_avatar_url=msg.author.avatar_url if msg.author else None, ) - for msg in thread.messages + for msg in db_messages ] return ThreadHistoryLoadResponse(messages=messages) @@ -782,6 +792,7 @@ async def append_message( thread_id=thread_id, role=message_role, content=message.content, + author_id=user.id, ) session.add(db_message) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 8e8ebb72d..a7c577bba 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -7,6 +7,13 @@ PUT /search-source-connectors/{connector_id} - Update a specific connector DELETE /search-source-connectors/{connector_id} - Delete a specific connector POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space +MCP (Model Context Protocol) Connector routes: +POST /connectors/mcp - Create a new MCP connector with custom API tools +GET /connectors/mcp - List all MCP connectors for the current user's search space +GET /connectors/mcp/{connector_id} - Get a specific MCP connector with tools config +PUT /connectors/mcp/{connector_id} - Update an MCP connector's tools config +DELETE /connectors/mcp/{connector_id} - Delete an MCP connector + Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space. Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space. 
""" @@ -32,6 +39,9 @@ from app.db import ( ) from app.schemas import ( GoogleDriveIndexRequest, + MCPConnectorCreate, + MCPConnectorRead, + MCPConnectorUpdate, SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorRead, @@ -127,18 +137,20 @@ async def create_search_source_connector( # Check if a connector with the same type already exists for this search space # (for non-OAuth connectors that don't support multiple accounts) - result = await session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.connector_type == connector.connector_type, - ) - ) - existing_connector = result.scalars().first() - if existing_connector: - raise HTTPException( - status_code=409, - detail=f"A connector with type {connector.connector_type} already exists in this search space.", + # Exception: MCP_CONNECTOR can have multiple instances with different names + if connector.connector_type != SearchSourceConnectorType.MCP_CONNECTOR: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.connector_type == connector.connector_type, + ) ) + existing_connector = result.scalars().first() + if existing_connector: + raise HTTPException( + status_code=409, + detail=f"A connector with type {connector.connector_type} already exists in this search space.", + ) # Prepare connector data connector_data = connector.model_dump() @@ -1964,3 +1976,348 @@ async def run_bookstack_indexing( f"Critical error in run_bookstack_indexing for connector {connector_id}: {e}", exc_info=True, ) + + +# ============================================================================= +# MCP Connector Routes +# ============================================================================= + + +@router.post("/connectors/mcp", response_model=MCPConnectorRead, status_code=201) +async def create_mcp_connector( + 
connector_data: MCPConnectorCreate, + search_space_id: int = Query(..., description="Search space ID"), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Create a new MCP (Model Context Protocol) connector. + + MCP connectors allow users to connect to MCP servers (like in Cursor). + Tools are auto-discovered from the server - no manual configuration needed. + + Args: + connector_data: MCP server configuration (command, args, env) + search_space_id: ID of the search space to attach the connector to + session: Database session + user: Current authenticated user + + Returns: + Created MCP connector with server configuration + + Raises: + HTTPException: If search space not found or permission denied + """ + try: + # Check user has permission to create connectors + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_CREATE.value, + "You don't have permission to create connectors in this search space", + ) + + # Create the connector with server config + db_connector = SearchSourceConnector( + name=connector_data.name, + connector_type=SearchSourceConnectorType.MCP_CONNECTOR, + is_indexable=False, # MCP connectors are not indexable + config={"server_config": connector_data.server_config.model_dump()}, + periodic_indexing_enabled=False, + indexing_frequency_minutes=None, + search_space_id=search_space_id, + user_id=user.id, + ) + + session.add(db_connector) + await session.commit() + await session.refresh(db_connector) + + logger.info( + f"Created MCP connector {db_connector.id} for server '{connector_data.server_config.command}' " + f"for user {user.id} in search space {search_space_id}" + ) + + # Convert to read schema + connector_read = SearchSourceConnectorRead.model_validate(db_connector) + return MCPConnectorRead.from_connector(connector_read) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to create MCP connector: {e!s}", exc_info=True) + 
await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to create MCP connector: {e!s}" + ) from e + + +@router.get("/connectors/mcp", response_model=list[MCPConnectorRead]) +async def list_mcp_connectors( + search_space_id: int = Query(..., description="Search space ID"), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List all MCP connectors for a search space. + + Args: + search_space_id: ID of the search space + session: Database session + user: Current authenticated user + + Returns: + List of MCP connectors with their tool configurations + """ + try: + # Check user has permission to read connectors + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_READ.value, + "You don't have permission to view connectors in this search space", + ) + + # Fetch MCP connectors + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.search_space_id == search_space_id, + ) + ) + + connectors = result.scalars().all() + return [ + MCPConnectorRead.from_connector(SearchSourceConnectorRead.model_validate(c)) + for c in connectors + ] + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to list MCP connectors: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to list MCP connectors: {e!s}" + ) from e + + +@router.get("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead) +async def get_mcp_connector( + connector_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a specific MCP connector by ID. 
+ + Args: + connector_id: ID of the connector + session: Database session + user: Current authenticated user + + Returns: + MCP connector with tool configurations + """ + try: + # Fetch connector + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + # Check user has permission to read connectors + await check_permission( + session, + user, + connector.search_space_id, + Permission.CONNECTORS_READ.value, + "You don't have permission to view this connector", + ) + + connector_read = SearchSourceConnectorRead.model_validate(connector) + return MCPConnectorRead.from_connector(connector_read) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get MCP connector: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to get MCP connector: {e!s}" + ) from e + + +@router.put("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead) +async def update_mcp_connector( + connector_id: int, + connector_update: MCPConnectorUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update an MCP connector. 
+ + Args: + connector_id: ID of the connector to update + connector_update: Updated connector data + session: Database session + user: Current authenticated user + + Returns: + Updated MCP connector + """ + try: + # Fetch connector + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + # Check user has permission to update connectors + await check_permission( + session, + user, + connector.search_space_id, + Permission.CONNECTORS_UPDATE.value, + "You don't have permission to update this connector", + ) + + # Update fields + if connector_update.name is not None: + connector.name = connector_update.name + + if connector_update.server_config is not None: + connector.config = { + "server_config": connector_update.server_config.model_dump() + } + + connector.updated_at = datetime.now(UTC) + + await session.commit() + await session.refresh(connector) + + logger.info(f"Updated MCP connector {connector_id}") + + connector_read = SearchSourceConnectorRead.model_validate(connector) + return MCPConnectorRead.from_connector(connector_read) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to update MCP connector: {e!s}", exc_info=True) + await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to update MCP connector: {e!s}" + ) from e + + +@router.delete("/connectors/mcp/{connector_id}", status_code=204) +async def delete_mcp_connector( + connector_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Delete an MCP connector. 
+ + Args: + connector_id: ID of the connector to delete + session: Database session + user: Current authenticated user + """ + try: + # Fetch connector + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + # Check user has permission to delete connectors + await check_permission( + session, + user, + connector.search_space_id, + Permission.CONNECTORS_DELETE.value, + "You don't have permission to delete this connector", + ) + + await session.delete(connector) + await session.commit() + + logger.info(f"Deleted MCP connector {connector_id}") + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to delete MCP connector: {e!s}", exc_info=True) + await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to delete MCP connector: {e!s}" + ) from e + + +@router.post("/connectors/mcp/test") +async def test_mcp_server_connection( + server_config: dict = Body(...), + user: User = Depends(current_active_user), +): + """ + Test connection to an MCP server and fetch available tools. + + This endpoint allows users to test their MCP server configuration + before saving it, similar to Cursor's flow. 
+ + Args: + server_config: Server configuration with command, args, env + user: Current authenticated user + + Returns: + Connection status and list of available tools + """ + try: + from app.agents.new_chat.tools.mcp_client import test_mcp_connection + + command = server_config.get("command") + args = server_config.get("args", []) + env = server_config.get("env", {}) + + if not command: + raise HTTPException(status_code=400, detail="Server command is required") + + # Test the connection + result = await test_mcp_connection(command, args, env) + + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to test MCP connection: {e!s}", exc_info=True) + return { + "status": "error", + "message": f"Failed to test connection: {e!s}", + "tools": [], + } diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index a8bde7ed9..076ac5915 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -55,6 +55,10 @@ from .rbac_schemas import ( UserSearchSpaceAccess, ) from .search_source_connector import ( + MCPConnectorCreate, + MCPConnectorRead, + MCPConnectorUpdate, + MCPServerConfig, SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorRead, @@ -108,6 +112,11 @@ __all__ = [ "LogFilter", "LogRead", "LogUpdate", + # Search source connector schemas + "MCPConnectorCreate", + "MCPConnectorRead", + "MCPConnectorUpdate", + "MCPServerConfig", "MembershipRead", "MembershipReadWithUser", "MembershipUpdate", @@ -135,7 +144,6 @@ __all__ = [ "RoleCreate", "RoleRead", "RoleUpdate", - # Search source connector schemas "SearchSourceConnectorBase", "SearchSourceConnectorCreate", "SearchSourceConnectorRead", diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index e6dbcd920..3734b0470 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ 
-38,6 +38,9 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel): """Schema for reading a message.""" thread_id: int + author_id: UUID | None = None + author_display_name: str | None = None + author_avatar_url: str | None = None model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/search_source_connector.py b/surfsense_backend/app/schemas/search_source_connector.py index dbe4dce1f..5fd7a5aab 100644 --- a/surfsense_backend/app/schemas/search_source_connector.py +++ b/surfsense_backend/app/schemas/search_source_connector.py @@ -23,7 +23,9 @@ class SearchSourceConnectorBase(BaseModel): @field_validator("config") @classmethod def validate_config_for_connector_type( - cls, config: dict[str, Any], values: dict[str, Any] + cls, + config: dict[str, Any], + values: dict[str, Any], ) -> dict[str, Any]: connector_type = values.data.get("connector_type") return validate_connector_config(connector_type, config) @@ -38,15 +40,18 @@ class SearchSourceConnectorBase(BaseModel): """ if self.periodic_indexing_enabled: if not self.is_indexable: + msg = "periodic_indexing_enabled can only be True for indexable connectors" raise ValueError( - "periodic_indexing_enabled can only be True for indexable connectors" + msg, ) if self.indexing_frequency_minutes is None: + msg = "indexing_frequency_minutes is required when periodic_indexing_enabled is True" raise ValueError( - "indexing_frequency_minutes is required when periodic_indexing_enabled is True" + msg, ) if self.indexing_frequency_minutes <= 0: - raise ValueError("indexing_frequency_minutes must be greater than 0") + msg = "indexing_frequency_minutes must be greater than 0" + raise ValueError(msg) return self @@ -70,3 +75,63 @@ class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampMod user_id: uuid.UUID model_config = ConfigDict(from_attributes=True) + + +# ============================================================================= +# MCP-specific 
schemas +# ============================================================================= + + +class MCPServerConfig(BaseModel): + """Configuration for an MCP server connection (similar to Cursor's config).""" + + command: str # e.g., "uvx", "node", "python" + args: list[str] = [] # e.g., ["mcp-server-git", "--repository", "/path"] + env: dict[str, str] = {} # Environment variables for the server process + transport: str = "stdio" # "stdio" | "sse" | "http" (stdio is most common) + + +class MCPConnectorCreate(BaseModel): + """Schema for creating an MCP connector.""" + + name: str + server_config: MCPServerConfig + + +class MCPConnectorUpdate(BaseModel): + """Schema for updating an MCP connector.""" + + name: str | None = None + server_config: MCPServerConfig | None = None + + +class MCPConnectorRead(BaseModel): + """Schema for reading an MCP connector with server config.""" + + id: int + name: str + connector_type: SearchSourceConnectorType + server_config: MCPServerConfig + search_space_id: int + user_id: uuid.UUID + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) + + @classmethod + def from_connector(cls, connector: SearchSourceConnectorRead) -> "MCPConnectorRead": + """Convert from base SearchSourceConnectorRead.""" + config = connector.config or {} + server_config = MCPServerConfig(**config.get("server_config", {})) + + return cls( + id=connector.id, + name=connector.name, + connector_type=connector.connector_type, + server_config=server_config, + search_space_id=connector.search_space_id, + user_id=connector.user_id, + created_at=connector.created_at, + updated_at=connector.updated_at, + ) diff --git a/surfsense_backend/app/schemas/users.py b/surfsense_backend/app/schemas/users.py index a8e0cfac8..88d0a4f37 100644 --- a/surfsense_backend/app/schemas/users.py +++ b/surfsense_backend/app/schemas/users.py @@ -6,6 +6,8 @@ from fastapi_users import schemas class UserRead(schemas.BaseUser[uuid.UUID]): pages_limit: int 
pages_used: int + display_name: str | None = None + avatar_url: str | None = None class UserCreate(schemas.BaseUserCreate): @@ -13,4 +15,5 @@ class UserCreate(schemas.BaseUserCreate): class UserUpdate(schemas.BaseUserUpdate): - pass + display_name: str | None = None + avatar_url: str | None = None diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index a74f134dc..5f8cd638b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -237,7 +237,7 @@ async def stream_new_chat( checkpointer = await get_checkpointer() # Create the deep agent with checkpointer and configurable prompts - agent = create_surfsense_deep_agent( + agent = await create_surfsense_deep_agent( llm=llm, search_space_id=search_space_id, db_session=session, diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index 596cd9830..f3b5cba9d 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -2,11 +2,14 @@ File document processors for different ETL services (Unstructured, LlamaCloud, Docling). 
""" +import asyncio import contextlib import logging +import ssl import warnings from logging import ERROR, getLogger +import httpx from fastapi import HTTPException from langchain_core.documents import Document as LangChainDocument from litellm import atranscription @@ -31,6 +34,122 @@ from .base import ( ) from .markdown_processor import add_received_markdown_file_document +# Constants for LlamaCloud retry configuration +LLAMACLOUD_MAX_RETRIES = 3 +LLAMACLOUD_BASE_DELAY = 5 # Base delay in seconds for exponential backoff +LLAMACLOUD_RETRYABLE_EXCEPTIONS = ( + ssl.SSLError, + httpx.ConnectError, + httpx.ConnectTimeout, + httpx.ReadTimeout, + httpx.WriteTimeout, + ConnectionError, + TimeoutError, +) + + +async def parse_with_llamacloud_retry( + file_path: str, + estimated_pages: int, + task_logger: TaskLoggingService | None = None, + log_entry: Log | None = None, +): + """ + Parse a file with LlamaCloud with retry logic for transient SSL/connection errors. + + Args: + file_path: Path to the file to parse + estimated_pages: Estimated number of pages for timeout calculation + task_logger: Optional task logger for progress updates + log_entry: Optional log entry for progress updates + + Returns: + LlamaParse result object + + Raises: + Exception: If all retries fail + """ + from llama_cloud_services import LlamaParse + from llama_cloud_services.parse.utils import ResultType + + # Calculate timeouts based on estimated pages + # Base timeout of 300 seconds + 30 seconds per page for large documents + base_timeout = 300 + per_page_timeout = 30 + job_timeout = base_timeout + (estimated_pages * per_page_timeout) + + # Create custom httpx client with larger timeouts for file uploads + # The SSL error often occurs during large file uploads, so we need generous timeouts + custom_timeout = httpx.Timeout( + connect=60.0, # 60 seconds to establish connection + read=300.0, # 5 minutes to read response + write=300.0, # 5 minutes to write/upload (important for large files) + 
pool=60.0, # 60 seconds to acquire connection from pool + ) + + last_exception = None + + for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1): + try: + # Create a fresh httpx client for each attempt + async with httpx.AsyncClient(timeout=custom_timeout) as custom_client: + # Create LlamaParse parser instance with optimized settings + parser = LlamaParse( + api_key=app_config.LLAMA_CLOUD_API_KEY, + num_workers=1, # Use single worker for file processing + verbose=True, + language="en", + result_type=ResultType.MD, + # Timeout settings for large files + max_timeout=max(2000, job_timeout), # Overall max timeout + job_timeout_in_seconds=job_timeout, + job_timeout_extra_time_per_page_in_seconds=per_page_timeout, + # Use our custom client with larger timeouts + custom_client=custom_client, + ) + + # Parse the file asynchronously + result = await parser.aparse(file_path) + return result + + except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e: + last_exception = e + error_type = type(e).__name__ + + if attempt < LLAMACLOUD_MAX_RETRIES: + # Calculate exponential backoff delay + delay = LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1)) + + if task_logger and log_entry: + await task_logger.log_task_progress( + log_entry, + f"LlamaCloud upload failed (attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}), retrying in {delay}s", + { + "error_type": error_type, + "error_message": str(e)[:200], + "attempt": attempt, + "retry_delay": delay, + }, + ) + else: + logging.warning( + f"LlamaCloud upload failed (attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): {error_type}. " + f"Retrying in {delay}s..." 
+ ) + + await asyncio.sleep(delay) + else: + logging.error( + f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} attempts: {error_type} - {e}" + ) + + except Exception: + # Non-retryable exception, raise immediately + raise + + # All retries exhausted + raise last_exception or RuntimeError("LlamaCloud parsing failed after all retries") + async def add_received_file_document_using_unstructured( session: AsyncSession, @@ -819,24 +938,18 @@ async def process_file_in_background( "file_type": "document", "etl_service": "LLAMACLOUD", "processing_stage": "parsing", + "estimated_pages": estimated_pages_before, }, ) - from llama_cloud_services import LlamaParse - from llama_cloud_services.parse.utils import ResultType - - # Create LlamaParse parser instance - parser = LlamaParse( - api_key=app_config.LLAMA_CLOUD_API_KEY, - num_workers=1, # Use single worker for file processing - verbose=True, - language="en", - result_type=ResultType.MD, + # Parse file with retry logic for SSL/connection errors (common with large files) + result = await parse_with_llamacloud_retry( + file_path=file_path, + estimated_pages=estimated_pages_before, + task_logger=task_logger, + log_entry=log_entry, ) - # Parse the file asynchronously - result = await parser.aparse(file_path) - # Clean up the temp file import os diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index dd284307f..e86eb752b 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -1,6 +1,7 @@ import logging import uuid +import httpx from fastapi import Depends, Request, Response from fastapi.responses import JSONResponse, RedirectResponse from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models @@ -46,6 +47,71 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = SECRET verification_token_secret = SECRET + async def oauth_callback( + self, + oauth_name: str, + access_token: str, + account_id: str, + 
account_email: str, + expires_at: int | None = None, + refresh_token: str | None = None, + request: Request | None = None, + *, + associate_by_email: bool = False, + is_verified_by_default: bool = False, + ) -> User: + """ + Override OAuth callback to capture Google profile data (name, avatar). + """ + # Call parent implementation to create/get user + user = await super().oauth_callback( + oauth_name, + access_token, + account_id, + account_email, + expires_at, + refresh_token, + request, + associate_by_email=associate_by_email, + is_verified_by_default=is_verified_by_default, + ) + + # Fetch and store Google profile data if not already set + if oauth_name == "google" and (not user.display_name or not user.avatar_url): + try: + async with httpx.AsyncClient() as client: + response = await client.get( + "https://people.googleapis.com/v1/people/me", + params={"personFields": "names,photos"}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + response.raise_for_status() + profile = response.json() + + update_dict = {} + + # Extract name from names array + names = profile.get("names", []) + if not user.display_name and names: + display_name = names[0].get("displayName") + if display_name: + update_dict["display_name"] = display_name + + # Extract photo URL from photos array + photos = profile.get("photos", []) + if not user.avatar_url and photos: + photo_url = photos[0].get("url") + if photo_url: + update_dict["avatar_url"] = photo_url + + if update_dict: + user = await self.user_db.update(user, update_dict) + + except Exception as e: + logger.warning(f"Failed to fetch Google profile: {e}") + + return user + async def on_after_register(self, user: User, request: Request | None = None): """ Called after a user registers. 
Creates a default search space for the user diff --git a/surfsense_backend/app/utils/connector_naming.py b/surfsense_backend/app/utils/connector_naming.py index 731f419d6..a2b748a3a 100644 --- a/surfsense_backend/app/utils/connector_naming.py +++ b/surfsense_backend/app/utils/connector_naming.py @@ -8,9 +8,9 @@ from typing import Any from urllib.parse import urlparse from uuid import UUID -from sqlalchemy import func +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select +from sqlalchemy.sql import func from app.db import SearchSourceConnector, SearchSourceConnectorType @@ -27,6 +27,7 @@ BASE_NAME_FOR_TYPE = { SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord", SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence", SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable", + SearchSourceConnectorType.MCP_CONNECTOR: "Model Context Protocol (MCP)", } @@ -75,7 +76,7 @@ def extract_identifier_from_credentials( if ".atlassian.net" in hostname: return hostname.replace(".atlassian.net", "") return hostname - except Exception: + except (ValueError, TypeError, AttributeError): pass return None diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index e3e7583f8..83a00b4e4 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -57,6 +57,9 @@ dependencies = [ "chonkie[all]>=1.5.0", "langgraph-checkpoint-postgres>=3.0.2", "psycopg[binary,pool]>=3.3.2", + "mcp>=1.25.0", + "starlette>=0.40.0,<0.51.0", + "sse-starlette>=3.1.1,<3.1.2", ] [dependency-groups] diff --git a/surfsense_backend/uv.lock b/surfsense_backend/uv.lock index 8ec09ddd9..2187a88cb 100644 --- a/surfsense_backend/uv.lock +++ b/surfsense_backend/uv.lock @@ -175,6 +175,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/26/99/fc813cd978842c26c82534010ea849eee9ab3a13ea2b74e95cb9c99e747b/amqp-5.3.1-py3-none-any.whl", hash = 
"sha256:43b3319e1b4e7d1251833a93d672b4af1e40f3d632d479b98661a95f117880a2", size = 50944 }, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303 }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -1568,16 +1577,17 @@ wheels = [ [[package]] name = "fastapi" -version = "0.115.9" +version = "0.128.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "annotated-doc" }, { name = "pydantic" }, { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/dd/d854f85e70f7341b29e3fda754f2833aec197bd355f805238758e3bcd8ed/fastapi-0.115.9.tar.gz", hash = "sha256:9d7da3b196c5eed049bc769f9475cd55509a112fbe031c0ef2f53768ae68d13f", size = 293774 } +sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682 } wheels = [ - { url = "https://files.pythonhosted.org/packages/32/b6/7517af5234378518f27ad35a7b24af9591bc500b8c1780929c1295999eb6/fastapi-0.115.9-py3-none-any.whl", hash = "sha256:4a439d7923e4de796bcc88b64e9754340fcd1574673cbd865ba8a99fe0d28c56", size = 94919 }, + { url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = 
"sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094 }, ] [[package]] @@ -3482,6 +3492,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1b/92/9a45c91089c3cf690b5badd4be81e392ff086ccca8a1d4e3a08463d8a966/matplotlib-3.10.3-cp313-cp313t-win_amd64.whl", hash = "sha256:4f23ffe95c5667ef8a2b56eea9b53db7f43910fa4a2d5472ae0f72b64deab4d5", size = 8139044 }, ] +[[package]] +name = "mcp" +version = "1.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d5/2d/649d80a0ecf6a1f82632ca44bec21c0461a9d9fc8934d38cb5b319f2db5e/mcp-1.25.0.tar.gz", hash = "sha256:56310361ebf0364e2d438e5b45f7668cbb124e158bb358333cd06e49e83a6802", size = 605387 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/fc/6dc7659c2ae5ddf280477011f4213a74f806862856b796ef08f028e664bf/mcp-1.25.0-py3-none-any.whl", hash = "sha256:b37c38144a666add0862614cc79ec276e97d72aa8ca26d622818d4e278b9721a", size = 233076 }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -6382,15 +6417,29 @@ wheels = [ ] [[package]] -name = "starlette" -version = "0.45.3" +name = "sse-starlette" +version = "3.1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, + { name = "starlette" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ff/fb/2984a686808b89a6781526129a4b51266f678b2d2b97ab2d325e56116df8/starlette-0.45.3.tar.gz", hash = 
"sha256:2cbcba2a75806f8a41c722141486f37c28e30a0921c5f6fe4346cb0dcee1302f", size = 2574076 } +sdist = { url = "https://files.pythonhosted.org/packages/62/08/8f554b0e5bad3e4e880521a1686d96c05198471eed860b0eb89b57ea3636/sse_starlette-3.1.1.tar.gz", hash = "sha256:bffa531420c1793ab224f63648c059bcadc412bf9fdb1301ac8de1cf9a67b7fb", size = 24306 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/61/f2b52e107b1fc8944b33ef56bf6ac4ebbe16d91b94d2b87ce013bf63fb84/starlette-0.45.3-py3-none-any.whl", hash = "sha256:dfb6d332576f136ec740296c7e8bb8c8a7125044e7c6da30744718880cdd059d", size = 71507 }, + { url = "https://files.pythonhosted.org/packages/e3/31/4c281581a0f8de137b710a07f65518b34bcf333b201cfa06cfda9af05f8a/sse_starlette-3.1.1-py3-none-any.whl", hash = "sha256:bb38f71ae74cfd86b529907a9fda5632195dfa6ae120f214ea4c890c7ee9d436", size = 12442 }, +] + +[[package]] +name = "starlette" +version = "0.50.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/b8/73a0e6a6e079a9d9cfa64113d771e421640b6f679a52eeb9b32f72d871a1/starlette-0.50.0.tar.gz", hash = "sha256:a2a17b22203254bcbc2e1f926d2d55f3f9497f769416b3190768befe598fa3ca", size = 2646985 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/52/1064f510b141bd54025f9b55105e26d1fa970b9be67ad766380a3c9b74b0/starlette-0.50.0-py3-none-any.whl", hash = "sha256:9e5391843ec9b6e472eed1365a78c8098cfceb7a74bfd4d6b1c0c0095efb3bca", size = 74033 }, ] [[package]] @@ -6443,6 +6492,7 @@ dependencies = [ { name = "litellm" }, { name = "llama-cloud-services" }, { name = "markdownify" }, + { name = "mcp" }, { name = "notion-client" }, { name = "numpy" }, { name = "pgvector" }, @@ -6457,6 +6507,8 @@ dependencies = [ { name = "slack-sdk" }, { name = "soundfile" }, { name = "spacy" }, + { name = "sse-starlette" }, + { name = "starlette" }, { name 
= "static-ffmpeg" }, { name = "tavily-python" }, { name = "trafilatura" }, @@ -6505,6 +6557,7 @@ requires-dist = [ { name = "litellm", specifier = ">=1.80.10" }, { name = "llama-cloud-services", specifier = ">=0.6.25" }, { name = "markdownify", specifier = ">=0.14.1" }, + { name = "mcp", specifier = ">=1.25.0" }, { name = "notion-client", specifier = ">=2.3.0" }, { name = "numpy", specifier = ">=1.24.0" }, { name = "pgvector", specifier = ">=0.3.6" }, @@ -6519,6 +6572,8 @@ requires-dist = [ { name = "slack-sdk", specifier = ">=3.34.0" }, { name = "soundfile", specifier = ">=0.13.1" }, { name = "spacy", specifier = ">=3.8.7" }, + { name = "sse-starlette", specifier = ">=3.1.1,<3.1.2" }, + { name = "starlette", specifier = ">=0.40.0,<0.51.0" }, { name = "static-ffmpeg", specifier = ">=2.13" }, { name = "tavily-python", specifier = ">=0.3.2" }, { name = "trafilatura", specifier = ">=2.0.0" }, diff --git a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx index 566e103ac..94c0626e6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx @@ -79,25 +79,17 @@ export function DocumentsTableShell({ [documents, sortKey, sortDesc] ); - // Filter out SURFSENSE_DOCS for selection purposes - const selectableDocs = React.useMemo( - () => sorted.filter((d) => d.document_type !== "SURFSENSE_DOCS"), - [sorted] - ); - - const allSelectedOnPage = - selectableDocs.length > 0 && selectableDocs.every((d) => selectedIds.has(d.id)); - const someSelectedOnPage = - selectableDocs.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage; + const allSelectedOnPage = sorted.length > 0 && sorted.every((d) => selectedIds.has(d.id)); + const someSelectedOnPage = sorted.some((d) => 
selectedIds.has(d.id)) && !allSelectedOnPage; const toggleAll = (checked: boolean) => { const next = new Set(selectedIds); if (checked) - selectableDocs.forEach((d) => { + sorted.forEach((d) => { next.add(d.id); }); else - selectableDocs.forEach((d) => { + sorted.forEach((d) => { next.delete(d.id); }); setSelectedIds(next); @@ -238,10 +230,9 @@ export function DocumentsTableShell({ const icon = getDocumentTypeIcon(doc.document_type); const title = doc.title; const truncatedTitle = title.length > 30 ? `${title.slice(0, 30)}...` : title; - const isSurfsenseDoc = doc.document_type === "SURFSENSE_DOCS"; return ( !isSurfsenseDoc && toggleOne(doc.id, !!v)} - disabled={isSurfsenseDoc} + checked={selectedIds.has(doc.id)} + onCheckedChange={(v) => toggleOne(doc.id, !!v)} aria-label="Select row" /> diff --git a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx index 54fd490a1..c2ddf6f71 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx @@ -20,7 +20,7 @@ import { DocumentsFilters } from "./components/DocumentsFilters"; import { DocumentsTableShell, type SortKey } from "./components/DocumentsTableShell"; import { PaginationControls } from "./components/PaginationControls"; import { ProcessingIndicator } from "./components/ProcessingIndicator"; -import type { ColumnVisibility, Document } from "./components/types"; +import type { ColumnVisibility } from "./components/types"; function useDebounced(value: T, delay = 250) { const [debounced, setDebounced] = useState(value); @@ -60,39 +60,30 @@ export default function DocumentsTable() { const { data: rawTypeCounts } = useAtomValue(documentTypeCountsAtom); const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom); - // Filter out SURFSENSE_DOCS from active types for regular documents API - 
const regularDocumentTypes = useMemo( - () => activeTypes.filter((t) => t !== "SURFSENSE_DOCS"), - [activeTypes] - ); - - // Check if only SURFSENSE_DOCS is selected (skip regular docs query) - const onlySurfsenseDocsSelected = activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS"; - - // Build query parameters for fetching documents (excluding SURFSENSE_DOCS type) + // Build query parameters for fetching documents const queryParams = useMemo( () => ({ search_space_id: searchSpaceId, page: pageIndex, page_size: pageSize, - ...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }), + ...(activeTypes.length > 0 && { document_types: activeTypes }), }), - [searchSpaceId, pageIndex, pageSize, regularDocumentTypes] + [searchSpaceId, pageIndex, pageSize, activeTypes] ); - // Build search query parameters (excluding SURFSENSE_DOCS type) + // Build search query parameters const searchQueryParams = useMemo( () => ({ search_space_id: searchSpaceId, page: pageIndex, page_size: pageSize, title: debouncedSearch.trim(), - ...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }), + ...(activeTypes.length > 0 && { document_types: activeTypes }), }), - [searchSpaceId, pageIndex, pageSize, regularDocumentTypes, debouncedSearch] + [searchSpaceId, pageIndex, pageSize, activeTypes, debouncedSearch] ); - // Use query for fetching documents (disabled when only SURFSENSE_DOCS is selected) + // Use query for fetching documents const { data: documentsResponse, isLoading: isDocumentsLoading, @@ -102,10 +93,10 @@ export default function DocumentsTable() { queryKey: cacheKeys.documents.globalQueryParams(queryParams), queryFn: () => documentsApiService.getDocuments({ queryParams }), staleTime: 3 * 60 * 1000, // 3 minutes - enabled: !!searchSpaceId && !debouncedSearch.trim() && !onlySurfsenseDocsSelected, + enabled: !!searchSpaceId && !debouncedSearch.trim(), }); - // Use query for searching documents (disabled when only SURFSENSE_DOCS is 
selected) + // Use query for searching documents const { data: searchResponse, isLoading: isSearchLoading, @@ -115,114 +106,20 @@ export default function DocumentsTable() { queryKey: cacheKeys.documents.globalQueryParams(searchQueryParams), queryFn: () => documentsApiService.searchDocuments({ queryParams: searchQueryParams }), staleTime: 3 * 60 * 1000, // 3 minutes - enabled: !!searchSpaceId && !!debouncedSearch.trim() && !onlySurfsenseDocsSelected, + enabled: !!searchSpaceId && !!debouncedSearch.trim(), }); - // Determine if we should show SurfSense docs (when no type filter or SURFSENSE_DOCS is selected) - const showSurfsenseDocs = - activeTypes.length === 0 || activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum); - - // Use query for fetching SurfSense docs - const { - data: surfsenseDocsResponse, - isLoading: isSurfsenseDocsLoading, - refetch: refetchSurfsenseDocs, - } = useQuery({ - queryKey: ["surfsense-docs", debouncedSearch, pageIndex, pageSize], - queryFn: () => - documentsApiService.getSurfsenseDocs({ - queryParams: { - page: pageIndex, - page_size: pageSize, - title: debouncedSearch.trim() || undefined, - }, - }), - staleTime: 3 * 60 * 1000, // 3 minutes - enabled: showSurfsenseDocs, - }); - - // Transform SurfSense docs to match the Document type - const surfsenseDocsAsDocuments: Document[] = useMemo(() => { - if (!surfsenseDocsResponse?.items) return []; - return surfsenseDocsResponse.items.map((doc) => ({ - id: doc.id, - title: doc.title, - document_type: "SURFSENSE_DOCS", - document_metadata: { source: doc.source }, - content: doc.content, - created_at: doc.created_at || doc.updated_at || new Date().toISOString(), - search_space_id: -1, // Special value for global docs - })); - }, [surfsenseDocsResponse]); - - // Merge type counts with SURFSENSE_DOCS count - const typeCounts = useMemo(() => { - const counts = { ...(rawTypeCounts || {}) }; - if (surfsenseDocsResponse?.total) { - counts.SURFSENSE_DOCS = surfsenseDocsResponse.total; - } - return 
counts; - }, [rawTypeCounts, surfsenseDocsResponse?.total]); - // Extract documents and total based on search state - const regularDocuments = debouncedSearch.trim() + const documents = debouncedSearch.trim() ? searchResponse?.items || [] : documentsResponse?.items || []; - const regularTotal = debouncedSearch.trim() - ? searchResponse?.total || 0 - : documentsResponse?.total || 0; + const total = debouncedSearch.trim() ? searchResponse?.total || 0 : documentsResponse?.total || 0; - // Merge regular documents with SurfSense docs - const documents = useMemo(() => { - // If filtering by type and not including SURFSENSE_DOCS, only show regular docs - if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) { - return regularDocuments; - } - // If filtering only by SURFSENSE_DOCS, only show surfsense docs - if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") { - return surfsenseDocsAsDocuments; - } - // Otherwise, merge both (surfsense docs first) - return [...surfsenseDocsAsDocuments, ...regularDocuments]; - }, [regularDocuments, surfsenseDocsAsDocuments, activeTypes]); + const loading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading; + const error = debouncedSearch.trim() ? searchError : documentsError; - const total = useMemo(() => { - if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) { - return regularTotal; - } - if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") { - return surfsenseDocsResponse?.total || 0; - } - return regularTotal + (surfsenseDocsResponse?.total || 0); - }, [regularTotal, surfsenseDocsResponse?.total, activeTypes]); - - const loading = useMemo(() => { - // If only SURFSENSE_DOCS selected, only check surfsense loading - if (onlySurfsenseDocsSelected) { - return isSurfsenseDocsLoading; - } - // Otherwise check both regular docs and surfsense docs loading - const regularLoading = debouncedSearch.trim() ? 
isSearchLoading : isDocumentsLoading; - return regularLoading || (showSurfsenseDocs && isSurfsenseDocsLoading); - }, [ - onlySurfsenseDocsSelected, - isSurfsenseDocsLoading, - debouncedSearch, - isSearchLoading, - isDocumentsLoading, - showSurfsenseDocs, - ]); - - const error = useMemo(() => { - // If only SURFSENSE_DOCS selected, no regular docs errors - if (onlySurfsenseDocsSelected) { - return null; - } - return debouncedSearch.trim() ? searchError : documentsError; - }, [onlySurfsenseDocsSelected, debouncedSearch, searchError, documentsError]); - - // Display server-filtered results directly - const displayDocs = documents || []; + // Display results directly + const displayDocs = documents; const displayTotal = total; const pageStart = pageIndex * pageSize; const pageEnd = Math.min(pageStart + pageSize, displayTotal); @@ -242,33 +139,16 @@ export default function DocumentsTable() { if (isRefreshing) return; setIsRefreshing(true); try { - const refetchPromises: Promise[] = []; - // Only refetch regular documents if not in "only surfsense docs" mode - if (!onlySurfsenseDocsSelected) { - if (debouncedSearch.trim()) { - refetchPromises.push(refetchSearch()); - } else { - refetchPromises.push(refetchDocuments()); - } + if (debouncedSearch.trim()) { + await refetchSearch(); + } else { + await refetchDocuments(); } - if (showSurfsenseDocs) { - refetchPromises.push(refetchSurfsenseDocs()); - } - await Promise.all(refetchPromises); toast.success(t("refresh_success") || "Documents refreshed"); } finally { setIsRefreshing(false); } - }, [ - debouncedSearch, - refetchSearch, - refetchDocuments, - refetchSurfsenseDocs, - showSurfsenseDocs, - onlySurfsenseDocsSelected, - t, - isRefreshing, - ]); + }, [debouncedSearch, refetchSearch, refetchDocuments, t, isRefreshing]); // Set up smart polling for active tasks - only polls when tasks are in progress const { summary } = useLogsSummary(searchSpaceId, 24, { @@ -385,7 +265,7 @@ export default function DocumentsTable() { 
createAttachmentAdapter(), []); @@ -306,12 +323,6 @@ export default function NewChatPage() { if (steps.length > 0) { restoredThinkingSteps.set(`msg-${msg.id}`, steps); } - // Hydrate write_todos plan state from persisted tool calls - // Disabled for now - // const writeTodosCalls = extractWriteTodosFromContent(msg.content); - // for (const todoData of writeTodosCalls) { - // hydratePlanState(todoData); - // } } if (msg.role === "user") { const docs = extractMentionedDocuments(msg.content); @@ -448,13 +459,27 @@ export default function NewChatPage() { // Add user message to state const userMsgId = `msg-user-${Date.now()}`; + + // Include author metadata for shared chats + const authorMetadata = + currentThread?.visibility === "SEARCH_SPACE" && currentUser + ? { + custom: { + author: { + displayName: currentUser.display_name ?? null, + avatarUrl: currentUser.avatar_url ?? null, + }, + }, + } + : undefined; + const userMessage: ThreadMessageLike = { id: userMsgId, role: "user", content: message.content, createdAt: new Date(), - // Include attachments so they can be displayed attachments: message.attachments || [], + metadata: authorMetadata, }; setMessages((prev) => [...prev, userMessage]); @@ -884,6 +909,8 @@ export default function NewChatPage() { setMentionedDocuments, setMessageDocumentsMap, queryClient, + currentThread, + currentUser, ] ); diff --git a/surfsense_web/app/dashboard/user/settings/components/ApiKeyContent.tsx b/surfsense_web/app/dashboard/user/settings/components/ApiKeyContent.tsx new file mode 100644 index 000000000..6bf10a78f --- /dev/null +++ b/surfsense_web/app/dashboard/user/settings/components/ApiKeyContent.tsx @@ -0,0 +1,122 @@ +"use client"; + +import { Check, Copy, Key, Menu, Shield } from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; +import { useTranslations } from "next-intl"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Button } from "@/components/ui/button"; +import { 
Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import { useApiKey } from "@/hooks/use-api-key"; + +interface ApiKeyContentProps { + onMenuClick: () => void; +} + +export function ApiKeyContent({ onMenuClick }: ApiKeyContentProps) { + const t = useTranslations("userSettings"); + const { apiKey, isLoading, copied, copyToClipboard } = useApiKey(); + + return ( + +
+
+ + +
+ + + + +
+

+ {t("api_key_title")} +

+

{t("api_key_description")}

+
+
+
+
+ + + + + + {t("api_key_warning_title")} + {t("api_key_warning_description")} + + +
+

{t("your_api_key")}

+ {isLoading ? ( +
+ ) : apiKey ? ( +
+
+ {apiKey} +
+ + + + + + {copied ? t("copied") : t("copy")} + + +
+ ) : ( +

{t("no_api_key")}

+ )} +
+ +
+

{t("usage_title")}

+

{t("usage_description")}

+
+									Authorization: Bearer {apiKey || "YOUR_API_KEY"}
+								
+
+ + +
+
+ + ); +} diff --git a/surfsense_web/app/dashboard/user/settings/components/ProfileContent.tsx b/surfsense_web/app/dashboard/user/settings/components/ProfileContent.tsx new file mode 100644 index 000000000..511a09fd1 --- /dev/null +++ b/surfsense_web/app/dashboard/user/settings/components/ProfileContent.tsx @@ -0,0 +1,181 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { Loader2, Menu, User } from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; +import { useTranslations } from "next-intl"; +import { useEffect, useState } from "react"; +import { toast } from "sonner"; +import { updateUserMutationAtom } from "@/atoms/user/user-mutation.atoms"; +import { currentUserAtom } from "@/atoms/user/user-query.atoms"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; + +interface ProfileContentProps { + onMenuClick: () => void; +} + +function AvatarDisplay({ url, fallback }: { url?: string; fallback: string }) { + const [hasError, setHasError] = useState(false); + + useEffect(() => { + setHasError(false); + }, [url]); + + if (url && !hasError) { + return ( + Avatar setHasError(true)} + /> + ); + } + + return ( +
+ {fallback} +
+ ); +} + +export function ProfileContent({ onMenuClick }: ProfileContentProps) { + const t = useTranslations("userSettings"); + const { data: user, isLoading: isUserLoading } = useAtomValue(currentUserAtom); + const { mutateAsync: updateUser, isPending } = useAtomValue(updateUserMutationAtom); + + const [displayName, setDisplayName] = useState(""); + + useEffect(() => { + if (user) { + setDisplayName(user.display_name || ""); + } + }, [user]); + + const getInitials = (email: string) => { + const name = email.split("@")[0]; + return name.slice(0, 2).toUpperCase(); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + + try { + await updateUser({ + display_name: displayName || null, + }); + toast.success(t("profile_saved")); + } catch { + toast.error(t("profile_save_error")); + } + }; + + const hasChanges = displayName !== (user?.display_name || ""); + + return ( + +
+
+ + +
+ + + + +
+

+ {t("profile_title")} +

+

{t("profile_description")}

+
+
+
+
+ + + + {isUserLoading ? ( +
+ +
+ ) : ( +
+
+
+
+ + +
+ +
+ + setDisplayName(e.target.value)} + /> +

+ {t("profile_display_name_hint")} +

+
+ +
+ + +
+
+
+ +
+ +
+
+ )} +
+
+
+
+
+ ); +} diff --git a/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx b/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx new file mode 100644 index 000000000..b7040b4e3 --- /dev/null +++ b/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx @@ -0,0 +1,154 @@ +"use client"; + +import type { LucideIcon } from "lucide-react"; +import { ArrowLeft, ChevronRight, X } from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; +import { useTranslations } from "next-intl"; +import { Button } from "@/components/ui/button"; +import { cn } from "@/lib/utils"; + +export interface SettingsNavItem { + id: string; + label: string; + description: string; + icon: LucideIcon; +} + +interface UserSettingsSidebarProps { + activeSection: string; + onSectionChange: (section: string) => void; + onBackToApp: () => void; + isOpen: boolean; + onClose: () => void; + navItems: SettingsNavItem[]; +} + +export function UserSettingsSidebar({ + activeSection, + onSectionChange, + onBackToApp, + isOpen, + onClose, + navItems, +}: UserSettingsSidebarProps) { + const t = useTranslations("userSettings"); + + const handleNavClick = (sectionId: string) => { + onSectionChange(sectionId); + onClose(); + }; + + return ( + <> + + {isOpen && ( + + )} + + + + + ); +} diff --git a/surfsense_web/app/dashboard/user/settings/page.tsx b/surfsense_web/app/dashboard/user/settings/page.tsx index bf88e65e5..8e04ce37a 100644 --- a/surfsense_web/app/dashboard/user/settings/page.tsx +++ b/surfsense_web/app/dashboard/user/settings/page.tsx @@ -1,286 +1,27 @@ "use client"; -import { - ArrowLeft, - Check, - ChevronRight, - Copy, - Key, - type LucideIcon, - Menu, - Shield, - X, -} from "lucide-react"; -import { AnimatePresence, motion } from "motion/react"; +import { Key, User } from "lucide-react"; +import { motion } from "motion/react"; import { useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; 
import { useCallback, useState } from "react"; -import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; -import { Button } from "@/components/ui/button"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import { useApiKey } from "@/hooks/use-api-key"; -import { cn } from "@/lib/utils"; - -interface SettingsNavItem { - id: string; - label: string; - description: string; - icon: LucideIcon; -} - -function UserSettingsSidebar({ - activeSection, - onSectionChange, - onBackToApp, - isOpen, - onClose, - navItems, -}: { - activeSection: string; - onSectionChange: (section: string) => void; - onBackToApp: () => void; - isOpen: boolean; - onClose: () => void; - navItems: SettingsNavItem[]; -}) { - const t = useTranslations("userSettings"); - - const handleNavClick = (sectionId: string) => { - onSectionChange(sectionId); - onClose(); - }; - - return ( - <> - - {isOpen && ( - - )} - - - - - ); -} - -function ApiKeyContent({ onMenuClick }: { onMenuClick: () => void }) { - const t = useTranslations("userSettings"); - const { apiKey, isLoading, copied, copyToClipboard } = useApiKey(); - - return ( - -
-
- - -
- - - - -
-

- {t("api_key_title")} -

-

{t("api_key_description")}

-
-
-
-
- - - - - - {t("api_key_warning_title")} - {t("api_key_warning_description")} - - -
-

{t("your_api_key")}

- {isLoading ? ( -
- ) : apiKey ? ( -
-
- {apiKey} -
- - - - - - {copied ? t("copied") : t("copy")} - - -
- ) : ( -

{t("no_api_key")}

- )} -
- -
-

{t("usage_title")}

-

{t("usage_description")}

-
-									Authorization: Bearer {apiKey || "YOUR_API_KEY"}
-								
-
- - -
-
- - ); -} +import { ApiKeyContent } from "./components/ApiKeyContent"; +import { ProfileContent } from "./components/ProfileContent"; +import { type SettingsNavItem, UserSettingsSidebar } from "./components/UserSettingsSidebar"; export default function UserSettingsPage() { const t = useTranslations("userSettings"); const router = useRouter(); - const [activeSection, setActiveSection] = useState("api-key"); + const [activeSection, setActiveSection] = useState("profile"); const [isSidebarOpen, setIsSidebarOpen] = useState(false); const navItems: SettingsNavItem[] = [ + { + id: "profile", + label: t("profile_nav_label"), + description: t("profile_nav_description"), + icon: User, + }, { id: "api-key", label: t("api_key_nav_label"), @@ -310,6 +51,9 @@ export default function UserSettingsPage() { onClose={() => setIsSidebarOpen(false)} navItems={navItems} /> + {activeSection === "profile" && ( + setIsSidebarOpen(true)} /> + )} {activeSection === "api-key" && ( setIsSidebarOpen(true)} /> )} diff --git a/surfsense_web/atoms/user/user-mutation.atoms.ts b/surfsense_web/atoms/user/user-mutation.atoms.ts new file mode 100644 index 000000000..caf4436a5 --- /dev/null +++ b/surfsense_web/atoms/user/user-mutation.atoms.ts @@ -0,0 +1,18 @@ +import { atomWithMutation, queryClientAtom } from "jotai-tanstack-query"; +import type { UpdateUserRequest } from "@/contracts/types/user.types"; +import { userApiService } from "@/lib/apis/user-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; + +export const updateUserMutationAtom = atomWithMutation((get) => { + const queryClient = get(queryClientAtom); + + return { + mutationKey: cacheKeys.user.current(), + mutationFn: async (request: UpdateUserRequest) => { + return userApiService.updateMe(request); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: cacheKeys.user.current() }); + }, + }; +}); diff --git a/surfsense_web/components/assistant-ui/document-upload-popup.tsx 
b/surfsense_web/components/assistant-ui/document-upload-popup.tsx index 453c6abde..1023c5c40 100644 --- a/surfsense_web/components/assistant-ui/document-upload-popup.tsx +++ b/surfsense_web/components/assistant-ui/document-upload-popup.tsx @@ -96,38 +96,37 @@ const DocumentUploadPopupContent: FC<{ return ( - + Upload Document - {/* Fixed Header */} -
- {/* Upload header */} -
-
- -
-
-

Upload Documents

-

- Upload and sync your documents to your search space -

+ {/* Scrollable container for mobile */} +
+ {/* Header - scrolls with content on mobile */} +
+ {/* Upload header */} +
+
+ +
+
+

+ Upload Documents +

+

+ Upload and sync your documents to your search space +

+
+ + {/* Content */} +
+ +
- {/* Scrollable Content */} -
-
-
- -
-
- {/* Bottom fade shadow */} -
-
+ {/* Bottom fade shadow - hidden on very small screens */} +
); diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index bf46e3d97..9f844ba2b 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -19,9 +19,7 @@ import { ChevronRightIcon, CopyIcon, DownloadIcon, - FileText, Loader2, - PencilIcon, RefreshCwIcon, SquareIcon, } from "lucide-react"; @@ -31,7 +29,6 @@ import { createPortal } from "react-dom"; import { mentionedDocumentIdsAtom, mentionedDocumentsAtom, - messageDocumentsMapAtom, } from "@/atoms/chat/mentioned-documents.atom"; import { globalNewLLMConfigsAtom, @@ -39,11 +36,7 @@ import { newLLMConfigsAtom, } from "@/atoms/new-llm-config/new-llm-config-query.atoms"; import { currentUserAtom } from "@/atoms/user/user-query.atoms"; -import { - ComposerAddAttachment, - ComposerAttachments, - UserMessageAttachments, -} from "@/components/assistant-ui/attachment"; +import { ComposerAddAttachment, ComposerAttachments } from "@/components/assistant-ui/attachment"; import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup"; import { InlineMentionEditor, @@ -56,6 +49,7 @@ import { } from "@/components/assistant-ui/thinking-steps"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; +import { UserMessage } from "@/components/assistant-ui/user-message"; import { DocumentMentionPicker, type DocumentMentionPickerRef, @@ -639,70 +633,6 @@ const AssistantActionBar: FC = () => { ); }; -const UserMessage: FC = () => { - const messageId = useAssistantState(({ message }) => message?.id); - const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); - const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; - const hasAttachments = useAssistantState( - ({ message }) => message?.attachments && message.attachments.length > 0 - ); - - return ( - -
- {/* Display attachments and mentioned documents */} - {(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && ( -
- {/* Attachments (images show as thumbnails, documents as chips) */} - - {/* Mentioned documents as chips */} - {mentionedDocs?.map((doc) => ( - - - {doc.title} - - ))} -
- )} - {/* Message bubble with action bar positioned relative to it */} -
-
- -
-
- -
-
-
- - -
- ); -}; - -const UserActionBar: FC = () => { - return ( - - - - - - - - ); -}; - const EditComposer: FC = () => { return ( diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index 745542304..15b5461b6 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -1,16 +1,54 @@ import { ActionBarPrimitive, MessagePrimitive, useAssistantState } from "@assistant-ui/react"; import { useAtomValue } from "jotai"; import { FileText, PencilIcon } from "lucide-react"; -import type { FC } from "react"; +import { type FC, useState } from "react"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { UserMessageAttachments } from "@/components/assistant-ui/attachment"; import { BranchPicker } from "@/components/assistant-ui/branch-picker"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; +interface AuthorMetadata { + displayName: string | null; + avatarUrl: string | null; +} + +const UserAvatar: FC = ({ displayName, avatarUrl }) => { + const [hasError, setHasError] = useState(false); + + const initials = displayName + ? displayName + .split(" ") + .map((n) => n[0]) + .join("") + .toUpperCase() + .slice(0, 2) + : "U"; + + if (avatarUrl && !hasError) { + return ( + {displayName setHasError(true)} + /> + ); + } + + return ( +
+ {initials} +
+ ); +}; + export const UserMessage: FC = () => { const messageId = useAssistantState(({ message }) => message?.id); const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; + const metadata = useAssistantState(({ message }) => message?.metadata); + const author = metadata?.custom?.author as AuthorMetadata | undefined; const hasAttachments = useAssistantState( ({ message }) => message?.attachments && message.attachments.length > 0 ); @@ -20,34 +58,42 @@ export const UserMessage: FC = () => { className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2" data-role="user" > -
- {/* Display attachments and mentioned documents */} - {(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && ( -
- {/* Attachments (images show as thumbnails, documents as chips) */} - - {/* Mentioned documents as chips */} - {mentionedDocs?.map((doc) => ( - - - {doc.title} - - ))} -
- )} - {/* Message bubble with action bar positioned relative to it */} -
-
- -
-
- +
+
+ {/* Display attachments and mentioned documents */} + {(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && ( +
+ {/* Attachments (images show as thumbnails, documents as chips) */} + + {/* Mentioned documents as chips */} + {mentionedDocs?.map((doc) => ( + + + {doc.title} + + ))} +
+ )} + {/* Message bubble with action bar positioned relative to it */} +
+
+ +
+
+ +
+ {/* User avatar - only shown in shared chats */} + {author && ( +
+ +
+ )}
diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 3d4e5630d..95ff5d782 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -354,7 +354,11 @@ export function LayoutDataProvider({ onChatDelete={handleChatDelete} onViewAllSharedChats={handleViewAllSharedChats} onViewAllPrivateChats={handleViewAllPrivateChats} - user={{ email: user?.email || "", name: user?.email?.split("@")[0] }} + user={{ + email: user?.email || "", + name: user?.display_name || user?.email?.split("@")[0], + avatarUrl: user?.avatar_url || undefined, + }} onSettings={handleSettings} onManageMembers={handleManageMembers} onUserSettings={handleUserSettings} diff --git a/surfsense_web/components/layout/types/layout.types.ts b/surfsense_web/components/layout/types/layout.types.ts index 73ac98fa5..3eac64e60 100644 --- a/surfsense_web/components/layout/types/layout.types.ts +++ b/surfsense_web/components/layout/types/layout.types.ts @@ -12,6 +12,7 @@ export interface SearchSpace { export interface User { email: string; name?: string; + avatarUrl?: string; } export interface NavItem { diff --git a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx index d3e97c8eb..f67dbf7c6 100644 --- a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx +++ b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx @@ -61,6 +61,39 @@ function getInitials(email: string): string { return name.slice(0, 2).toUpperCase(); } +/** + * User avatar component - shows image if available, otherwise falls back to initials + */ +function UserAvatar({ + avatarUrl, + initials, + bgColor, +}: { + avatarUrl?: string; + initials: string; + bgColor: string; +}) { + if (avatarUrl) { + return ( + User avatar + ); + } + + return ( +
+ {initials} +
+ ); +} + export function SidebarUserProfile({ user, onUserSettings, @@ -88,12 +121,7 @@ export function SidebarUserProfile({ "focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" )} > -
- {initials} -
+ {displayName} @@ -104,12 +132,7 @@ export function SidebarUserProfile({
-
- {initials} -
+

{displayName}

{user.email}

@@ -149,13 +172,7 @@ export function SidebarUserProfile({ "focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" )} > - {/* Avatar */} -
- {initials} -
+ {/* Name and email */}
@@ -171,12 +188,7 @@ export function SidebarUserProfile({
-
- {initials} -
+

{displayName}

{user.email}

diff --git a/surfsense_web/components/new-chat/document-mention-picker.tsx b/surfsense_web/components/new-chat/document-mention-picker.tsx index e89885b1d..ba9e4ea95 100644 --- a/surfsense_web/components/new-chat/document-mention-picker.tsx +++ b/surfsense_web/components/new-chat/document-mention-picker.tsx @@ -215,6 +215,16 @@ export const DocumentMentionPicker = forwardRef< isSurfsenseDocsLoading) && currentPage === 0; + // Split documents into SurfSense docs and user docs for grouped rendering + const surfsenseDocsList = useMemo( + () => actualDocuments.filter((doc) => doc.document_type === "SURFSENSE_DOCS"), + [actualDocuments] + ); + const userDocsList = useMemo( + () => actualDocuments.filter((doc) => doc.document_type !== "SURFSENSE_DOCS"), + [actualDocuments] + ); + // Track already selected documents using unique key (document_type:id) to avoid ID collisions const selectedKeys = useMemo( () => new Set(initialSelectedDocuments.map((d) => `${d.document_type}:${d.id}`)), @@ -324,47 +334,102 @@ export const DocumentMentionPicker = forwardRef<
) : (
- {actualDocuments.map((doc) => { - const docKey = `${doc.document_type}:${doc.id}`; - const isAlreadySelected = selectedKeys.has(docKey); - const selectableIndex = selectableDocuments.findIndex( - (d) => d.document_type === doc.document_type && d.id === doc.id - ); - const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex; + {/* SurfSense Documentation Section */} + {surfsenseDocsList.length > 0 && ( + <> +
+ SurfSense Docs +
+ {surfsenseDocsList.map((doc) => { + const docKey = `${doc.document_type}:${doc.id}`; + const isAlreadySelected = selectedKeys.has(docKey); + const selectableIndex = selectableDocuments.findIndex( + (d) => d.document_type === doc.document_type && d.id === doc.id + ); + const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex; + + return ( + + ); + })} + + )} + + {/* User Documents Section */} + {userDocsList.length > 0 && ( + <> +
+ Your Documents +
+ {userDocsList.map((doc) => { + const docKey = `${doc.document_type}:${doc.id}`; + const isAlreadySelected = selectedKeys.has(docKey); + const selectableIndex = selectableDocuments.findIndex( + (d) => d.document_type === doc.document_type && d.id === doc.id + ); + const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex; + + return ( + + ); + })} + + )} - return ( - - ); - })} {/* Loading indicator for additional pages */} {isLoadingMore && (
diff --git a/surfsense_web/components/sources/DocumentUploadTab.tsx b/surfsense_web/components/sources/DocumentUploadTab.tsx index 0b7f7b51f..cc27d326a 100644 --- a/surfsense_web/components/sources/DocumentUploadTab.tsx +++ b/surfsense_web/components/sources/DocumentUploadTab.tsx @@ -110,6 +110,11 @@ const FILE_TYPE_CONFIG: Record> = { const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5"; +// Upload limits +const MAX_FILES = 10; +const MAX_TOTAL_SIZE_MB = 200; +const MAX_TOTAL_SIZE_BYTES = MAX_TOTAL_SIZE_MB * 1024 * 1024; + export function DocumentUploadTab({ searchSpaceId, onSuccess, @@ -134,15 +139,40 @@ export function DocumentUploadTab({ [acceptedFileTypes] ); - const onDrop = useCallback((acceptedFiles: File[]) => { - setFiles((prev) => [...prev, ...acceptedFiles]); - }, []); + const onDrop = useCallback( + (acceptedFiles: File[]) => { + setFiles((prev) => { + const newFiles = [...prev, ...acceptedFiles]; + + // Check file count limit + if (newFiles.length > MAX_FILES) { + toast.error(t("max_files_exceeded"), { + description: t("max_files_exceeded_desc", { max: MAX_FILES }), + }); + return prev; + } + + // Check total size limit + const newTotalSize = newFiles.reduce((sum, file) => sum + file.size, 0); + if (newTotalSize > MAX_TOTAL_SIZE_BYTES) { + toast.error(t("max_size_exceeded"), { + description: t("max_size_exceeded_desc", { max: MAX_TOTAL_SIZE_MB }), + }); + return prev; + } + + return newFiles; + }); + }, + [t] + ); const { getRootProps, getInputProps, isDragActive } = useDropzone({ onDrop, accept: acceptedFileTypes, - maxSize: 50 * 1024 * 1024, + maxSize: 50 * 1024 * 1024, // 50MB per file noClick: false, + disabled: files.length >= MAX_FILES, }); // Handle file input click to prevent event bubbling that might reopen dialog @@ -160,6 +190,15 @@ export function DocumentUploadTab({ const totalFileSize = files.reduce((total, file) => total + file.size, 0); + // Check if limits are reached + const isFileCountLimitReached = files.length 
>= MAX_FILES; + const isSizeLimitReached = totalFileSize >= MAX_TOTAL_SIZE_BYTES; + const remainingFiles = MAX_FILES - files.length; + const remainingSizeMB = Math.max( + 0, + (MAX_TOTAL_SIZE_BYTES - totalFileSize) / (1024 * 1024) + ).toFixed(1); + // Track accordion state changes const handleAccordionChange = useCallback( (value: string) => { @@ -210,7 +249,8 @@ export function DocumentUploadTab({ - {t("file_size_limit")} + {t("file_size_limit")}{" "} + {t("upload_limits", { maxFiles: MAX_FILES, maxSizeMB: MAX_TOTAL_SIZE_MB })} @@ -221,7 +261,11 @@ export function DocumentUploadTab({
- {isDragActive ? ( + {isFileCountLimitReached ? ( +
+ +
+

+ {t("file_limit_reached")} +

+

+ {t("file_limit_reached_desc", { max: MAX_FILES })} +

+
+
+ ) : isDragActive ? ( {t("drag_drop")}

{t("or_browse")}

+ {files.length > 0 && ( +

+ {t("remaining_capacity", { files: remainingFiles, sizeMB: remainingSizeMB })} +

+ )} +
+ )} + {!isFileCountLimitReached && ( +
+
)} -
- -
diff --git a/surfsense_web/contracts/types/user.types.ts b/surfsense_web/contracts/types/user.types.ts index f2d1f0ffc..85fee49a8 100644 --- a/surfsense_web/contracts/types/user.types.ts +++ b/surfsense_web/contracts/types/user.types.ts @@ -8,6 +8,8 @@ export const user = z.object({ is_verified: z.boolean(), pages_limit: z.number(), pages_used: z.number(), + display_name: z.string().nullish(), + avatar_url: z.string().nullish(), }); /** @@ -15,5 +17,20 @@ export const user = z.object({ */ export const getMeResponse = user; +/** + * Update current user request + */ +export const updateUserRequest = z.object({ + display_name: z.string().nullish(), + avatar_url: z.string().nullish(), +}); + +/** + * Update current user response + */ +export const updateUserResponse = user; + export type User = z.infer; export type GetMeResponse = z.infer; +export type UpdateUserRequest = z.infer; +export type UpdateUserResponse = z.infer; diff --git a/surfsense_web/lib/apis/user-api.service.ts b/surfsense_web/lib/apis/user-api.service.ts index ea46ac116..083fd8dee 100644 --- a/surfsense_web/lib/apis/user-api.service.ts +++ b/surfsense_web/lib/apis/user-api.service.ts @@ -1,4 +1,8 @@ -import { getMeResponse } from "@/contracts/types/user.types"; +import { + getMeResponse, + type UpdateUserRequest, + updateUserResponse, +} from "@/contracts/types/user.types"; import { baseApiService } from "./base-api.service"; class UserApiService { @@ -8,6 +12,15 @@ class UserApiService { getMe = async () => { return baseApiService.get(`/users/me`, getMeResponse); }; + + /** + * Update current authenticated user + */ + updateMe = async (request: UpdateUserRequest) => { + return baseApiService.patch(`/users/me`, updateUserResponse, { + body: request, + }); + }; } export const userApiService = new UserApiService(); diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index 5c65ad47e..23dd35800 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts 
+++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -31,6 +31,9 @@ export interface MessageRecord { role: "user" | "assistant" | "system"; content: unknown; created_at: string; + author_id?: string | null; + author_display_name?: string | null; + author_avatar_url?: string | null; } export interface ThreadListResponse { diff --git a/surfsense_web/lib/env-config.ts b/surfsense_web/lib/env-config.ts index 5e35b160c..6201a0425 100644 --- a/surfsense_web/lib/env-config.ts +++ b/surfsense_web/lib/env-config.ts @@ -1,10 +1,10 @@ /** * Environment configuration for the frontend. - * + * * This file centralizes access to NEXT_PUBLIC_* environment variables. * For Docker deployments, these placeholders are replaced at container startup * via sed in the entrypoint script. - * + * * IMPORTANT: Do not use template literals or complex expressions with these values * as it may prevent the sed replacement from working correctly. */ @@ -24,5 +24,5 @@ export const ETL_SERVICE = process.env.NEXT_PUBLIC_ETL_SERVICE || "DOCLING"; // Helper to check if local auth is enabled export const isLocalAuth = () => AUTH_TYPE === "LOCAL"; -// Helper to check if Google auth is enabled +// Helper to check if Google auth is enabled export const isGoogleAuth = () => AUTH_TYPE === "GOOGLE"; diff --git a/surfsense_web/messages/en.json b/surfsense_web/messages/en.json index ae7f98843..b6eaf8824 100644 --- a/surfsense_web/messages/en.json +++ b/surfsense_web/messages/en.json @@ -109,6 +109,17 @@ "title": "User Settings", "description": "Manage your account settings and API access", "back_to_app": "Back to app", + "profile_nav_label": "Profile", + "profile_nav_description": "Manage your display name and avatar", + "profile_title": "Profile", + "profile_description": "Update your personal information", + "profile_avatar": "Profile Picture", + "profile_display_name": "Display Name", + "profile_display_name_hint": "This is how your name appears across the app", + "profile_email": "Email", + 
"profile_save": "Save Changes", + "profile_saved": "Profile updated successfully", + "profile_save_error": "Failed to update profile", "api_key_nav_label": "API Key", "api_key_nav_description": "Manage your API access token", "api_key_title": "API Key", @@ -367,6 +378,7 @@ "title": "Upload Documents", "subtitle": "Upload your files to make them searchable and accessible through AI-powered conversations.", "file_size_limit": "Maximum file size: 50MB per file. Supported formats vary based on your ETL service configuration.", + "upload_limits": "Upload limit: {maxFiles} files, {maxSizeMB}MB total.", "drop_files": "Drop files here", "drag_drop": "Drag & drop files here", "or_browse": "or click to browse", @@ -382,7 +394,14 @@ "upload_error": "Upload Error", "upload_error_desc": "Error uploading files", "supported_file_types": "Supported File Types", - "file_types_desc": "These file types are supported based on your current ETL service configuration." + "file_types_desc": "These file types are supported based on your current ETL service configuration.", + "max_files_exceeded": "File Limit Exceeded", + "max_files_exceeded_desc": "You can upload a maximum of {max} files at a time.", + "max_size_exceeded": "Size Limit Exceeded", + "max_size_exceeded_desc": "Total file size cannot exceed {max}MB.", + "file_limit_reached": "Maximum Files Reached", + "file_limit_reached_desc": "Remove some files to add more (max {max} files).", + "remaining_capacity": "{files} files remaining • {sizeMB}MB available" }, "add_webpage": { "title": "Add Webpages for Crawling", diff --git a/surfsense_web/messages/zh.json b/surfsense_web/messages/zh.json index 1404c176f..b48e3e9c7 100644 --- a/surfsense_web/messages/zh.json +++ b/surfsense_web/messages/zh.json @@ -363,6 +363,7 @@ "title": "上传文档", "subtitle": "上传您的文件,使其可通过 AI 对话进行搜索和访问。", "file_size_limit": "最大文件大小:每个文件 50MB。支持的格式因您的 ETL 服务配置而异。", + "upload_limits": "上传限制:最多 {maxFiles} 个文件,总大小不超过 {maxSizeMB}MB。", "drop_files": "放下文件到这里", 
"drag_drop": "拖放文件到这里", "or_browse": "或点击浏览", @@ -378,7 +379,14 @@ "upload_error": "上传错误", "upload_error_desc": "上传文件时出错", "supported_file_types": "支持的文件类型", - "file_types_desc": "根据您当前的 ETL 服务配置支持这些文件类型。" + "file_types_desc": "根据您当前的 ETL 服务配置支持这些文件类型。", + "max_files_exceeded": "超过文件数量限制", + "max_files_exceeded_desc": "一次最多只能上传 {max} 个文件。", + "max_size_exceeded": "超过文件大小限制", + "max_size_exceeded_desc": "文件总大小不能超过 {max}MB。", + "file_limit_reached": "已达到最大文件数量", + "file_limit_reached_desc": "移除一些文件以添加更多(最多 {max} 个文件)。", + "remaining_capacity": "剩余 {files} 个文件名额 • 可用 {sizeMB}MB" }, "add_webpage": { "title": "添加网页爬取",