diff --git a/surfsense_backend/alembic/versions/0_initial_schema.py b/surfsense_backend/alembic/versions/0_initial_schema.py new file mode 100644 index 000000000..77bd9dd1b --- /dev/null +++ b/surfsense_backend/alembic/versions/0_initial_schema.py @@ -0,0 +1,54 @@ +"""Initial schema setup + +Revision ID: 0 +Revises: None + +Creates all tables from SQLAlchemy models. Idempotent - safe to run on existing databases. +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +revision: str = "0" +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + from app.db import Base + + connection = op.get_bind() + + # Create tables + op.execute(sa.text("CREATE EXTENSION IF NOT EXISTS vector")) + Base.metadata.create_all(bind=connection) + + # Set up indexes + op.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)" + ) + ) + op.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))" + ) + ) + op.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)" + ) + ) + op.execute( + sa.text( + "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))" + ) + ) + + +def downgrade() -> None: + pass diff --git a/surfsense_backend/alembic/versions/10_update_chattype_enum_to_qna_report_structure.py b/surfsense_backend/alembic/versions/10_update_chattype_enum_to_qna_report_structure.py index 665585a85..a4f6db0b8 100644 --- a/surfsense_backend/alembic/versions/10_update_chattype_enum_to_qna_report_structure.py +++ b/surfsense_backend/alembic/versions/10_update_chattype_enum_to_qna_report_structure.py @@ -6,6 +6,8 @@ Revises: 9 from collections.abc import Sequence +import sqlalchemy as sa + from alembic import op # revision identifiers, used by Alembic. @@ -18,9 +20,37 @@ depends_on: str | Sequence[str] | None = None CHAT_TYPE_ENUM = "chattype" +def enum_exists(enum_name: str) -> bool: + """Check if an enum type exists in the database.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)" + ), + {"enum_name": enum_name}, + ) + return result.scalar() + + +def table_exists(table_name: str) -> bool: + """Check if a table exists in the database.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)" + ), + {"table_name": table_name}, + ) + return result.scalar() + + def upgrade() -> None: """Upgrade schema - replace ChatType enum values with new QNA/REPORT structure.""" + # Skip if chats table or chattype enum doesn't exist (fresh database) + if not table_exists("chats") or not enum_exists(CHAT_TYPE_ENUM): + return + # Old enum name for temporary storage old_enum_name = f"{CHAT_TYPE_ENUM}_old" @@ -72,6 +102,10 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade schema - revert ChatType enum to old GENERAL/DEEP/DEEPER/DEEPEST structure.""" + # Skip if chats table or chattype enum doesn't exist + if not table_exists("chats") or not enum_exists(CHAT_TYPE_ENUM): + return + # Old enum name for temporary storage old_enum_name = f"{CHAT_TYPE_ENUM}_old" diff --git a/surfsense_backend/alembic/versions/1_add_github_connector_enum.py b/surfsense_backend/alembic/versions/1_add_github_connector_enum.py index 235908b1f..6f3ee2a01 100644 --- a/surfsense_backend/alembic/versions/1_add_github_connector_enum.py +++ b/surfsense_backend/alembic/versions/1_add_github_connector_enum.py @@ -7,22 +7,36 @@ Revises: from collections.abc import Sequence +import sqlalchemy as sa + from alembic import op -# Import pgvector if needed for other types, though not for this ENUM change -# import pgvector - - # revision identifiers, used by Alembic. revision: str = "1" -down_revision: str | None = None +down_revision: str | None = "0" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None +def enum_exists(enum_name: str) -> bool: + """Check if an enum type exists in the database.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)" + ), + {"enum_name": enum_name}, + ) + return result.scalar() + + def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### + # Skip if the enum doesn't exist (fresh DB after downgrade - create_db_and_tables will handle it) + if not enum_exists("searchsourceconnectortype"): + return + # Manually add the command to add the enum value # Note: It's generally better to let autogenerate handle this, but we're bypassing it op.execute( @@ -51,6 +65,10 @@ END$$; def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### + # Skip if the enum doesn't exist + if not enum_exists("searchsourceconnectortype"): + return + # Downgrading removal of an enum value is complex and potentially dangerous # if the value is in use. Often omitted or requires manual SQL based on context. # For now, we'll just pass. If you needed to reverse this, you'd likely diff --git a/surfsense_backend/alembic/versions/24_fix_null_chat_types.py b/surfsense_backend/alembic/versions/24_fix_null_chat_types.py index e0d371f1e..e513605f0 100644 --- a/surfsense_backend/alembic/versions/24_fix_null_chat_types.py +++ b/surfsense_backend/alembic/versions/24_fix_null_chat_types.py @@ -7,6 +7,8 @@ Revises: 23 from collections.abc import Sequence +import sqlalchemy as sa + from alembic import op # revision identifiers, used by Alembic. @@ -16,11 +18,27 @@ branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None +def table_exists(table_name: str) -> bool: + """Check if a table exists in the database.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)" + ), + {"table_name": table_name}, + ) + return result.scalar() + + def upgrade() -> None: """ Fix any chats with NULL type values by setting them to QNA. This handles edge cases from previous migrations where type values were not properly migrated. """ + # Skip if chats table doesn't exist (fresh database) + if not table_exists("chats"): + return + # Update any NULL type values to QNA (the default chat type) op.execute( """ diff --git a/surfsense_backend/alembic/versions/34_add_podcast_staleness_detection.py b/surfsense_backend/alembic/versions/34_add_podcast_staleness_detection.py index 4991cd58e..74bb7fe86 100644 --- a/surfsense_backend/alembic/versions/34_add_podcast_staleness_detection.py +++ b/surfsense_backend/alembic/versions/34_add_podcast_staleness_detection.py @@ -10,6 +10,8 @@ Revises: 33 from collections.abc import Sequence +import sqlalchemy as sa + from alembic import op # revision identifiers @@ -19,42 +21,59 @@ branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None +def table_exists(table_name: str) -> bool: + """Check if a table exists in the database.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)" + ), + {"table_name": table_name}, + ) + return result.scalar() + + def upgrade() -> None: """Add columns only if they don't already exist (safe for re-runs).""" # Add 'state_version' column to chats table (default 1) - op.execute(""" - ALTER TABLE chats - ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL - """) + # Skip if chats table doesn't exist (fresh database) + if table_exists("chats"): + op.execute(""" + ALTER TABLE chats + ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL + """) # Add 'chat_state_version' column to podcasts table - op.execute(""" - ALTER TABLE podcasts - ADD COLUMN IF NOT EXISTS chat_state_version BIGINT - """) + if table_exists("podcasts"): + op.execute(""" + ALTER TABLE podcasts + ADD COLUMN IF NOT EXISTS chat_state_version BIGINT + """) - # Add 'chat_id' column to podcasts table - op.execute(""" - ALTER TABLE podcasts - ADD COLUMN IF NOT EXISTS chat_id INTEGER - """) + # Add 'chat_id' column to podcasts table + op.execute(""" + ALTER TABLE podcasts + ADD COLUMN IF NOT EXISTS chat_id INTEGER + """) def downgrade() -> None: """Remove columns only if they exist.""" - op.execute(""" - ALTER TABLE podcasts - DROP COLUMN IF EXISTS chat_state_version - """) + if table_exists("podcasts"): + op.execute(""" + ALTER TABLE podcasts + DROP COLUMN IF EXISTS chat_state_version + """) - op.execute(""" - ALTER TABLE podcasts - DROP COLUMN IF EXISTS chat_id - """) + op.execute(""" + ALTER TABLE podcasts + DROP COLUMN IF EXISTS chat_id + """) - op.execute(""" - ALTER TABLE chats - DROP COLUMN IF EXISTS state_version - """) + if table_exists("chats"): + op.execute(""" + ALTER TABLE chats + DROP COLUMN IF EXISTS state_version + """) diff --git a/surfsense_backend/alembic/versions/49_migrate_old_chats_to_new_chat.py b/surfsense_backend/alembic/versions/49_migrate_old_chats_to_new_chat.py index 61a3ddb48..ef38add26 100644 --- a/surfsense_backend/alembic/versions/49_migrate_old_chats_to_new_chat.py +++ b/surfsense_backend/alembic/versions/49_migrate_old_chats_to_new_chat.py @@ -62,8 +62,25 @@ def parse_timestamp(ts, fallback): return fallback +def table_exists(table_name: str) -> bool: + """Check if a table exists in the database.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)" + ), + {"table_name": table_name}, + ) + return result.scalar() + + def upgrade() -> None: """Migrate old chats to new_chat_threads and remove old tables.""" + # Skip if chats table doesn't exist (fresh database) + if not table_exists("chats"): + print("[Migration 49] Chats table does not exist, skipping migration") + return + connection = op.get_bind() # Get all old chats @@ -176,36 +193,49 @@ def upgrade() -> None: print("[Migration 49] Migration complete!") +def enum_exists(enum_name: str) -> bool: + """Check if an enum type exists in the database.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)" + ), + {"enum_name": enum_name}, + ) + return result.scalar() + + def downgrade() -> None: """Recreate old chats table (data cannot be restored).""" - # Recreate chattype enum + # Skip if chats table already exists + if table_exists("chats"): + print("[Migration 49 Downgrade] Chats table already exists, skipping") + return + + # Recreate chattype enum if it doesn't exist + if not enum_exists("chattype"): + op.execute( + sa.text(""" + CREATE TYPE chattype AS ENUM ('QNA') + """) + ) + + # Recreate chats table using raw SQL to avoid SQLAlchemy trying to create the enum op.execute( sa.text(""" - CREATE TYPE chattype AS ENUM ('QNA') + CREATE TABLE chats ( + id SERIAL PRIMARY KEY, + type chattype NOT NULL, + title VARCHAR NOT NULL, + initial_connectors VARCHAR[], + messages JSON NOT NULL, + state_version BIGINT NOT NULL DEFAULT 1, + search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW() + ) """) ) - - # Recreate chats table - op.create_table( - "chats", - sa.Column("id", sa.Integer(), primary_key=True, index=True), - sa.Column("type", sa.Enum("QNA", name="chattype"), nullable=False), - sa.Column("title", sa.String(), nullable=False, index=True), - sa.Column("initial_connectors", sa.ARRAY(sa.String()), nullable=True), - sa.Column("messages", sa.JSON(), nullable=False), - sa.Column("state_version", sa.BigInteger(), nullable=False, default=1), - sa.Column( - "search_space_id", - sa.Integer(), - sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), - nullable=False, - ), - sa.Column( - "created_at", - sa.TIMESTAMP(timezone=True), - nullable=False, - server_default=sa.func.now(), - ), - ) + op.execute(sa.text("CREATE INDEX ix_chats_id ON chats (id)")) + op.execute(sa.text("CREATE INDEX ix_chats_title ON chats (title)")) print("[Migration 49 Downgrade] Chats table recreated (data not restored)") diff --git a/surfsense_backend/alembic/versions/5263aa4e7f94_allow_multiple_connectors_with_unique_.py b/surfsense_backend/alembic/versions/5263aa4e7f94_allow_multiple_connectors_with_unique_.py new file mode 100644 index 000000000..de9505e3a --- /dev/null +++ b/surfsense_backend/alembic/versions/5263aa4e7f94_allow_multiple_connectors_with_unique_.py @@ -0,0 +1,50 @@ +"""allow_multiple_connectors_with_unique_names + +Revision ID: 5263aa4e7f94 +Revises: a1b2c3d4e5f6 +Create Date: 2026-01-13 12:23:31.481643 + +""" +from collections.abc import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '5263aa4e7f94' +down_revision: str | None = 'a1b2c3d4e5f6' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # Drop the old unique constraint + op.drop_constraint( + 'uq_searchspace_user_connector_type', + 'search_source_connectors', + type_='unique' + ) + + # Create new unique constraint that includes name + op.create_unique_constraint( + 'uq_searchspace_user_connector_type_name', + 'search_source_connectors', + ['search_space_id', 'user_id', 'connector_type', 'name'] + ) + + +def downgrade() -> None: + """Downgrade schema.""" + # Drop the new constraint + op.drop_constraint( + 'uq_searchspace_user_connector_type_name', + 'search_source_connectors', + type_='unique' + ) + + # Restore the old constraint + op.create_unique_constraint( + 'uq_searchspace_user_connector_type', + 'search_source_connectors', + ['search_space_id', 'user_id', 'connector_type'] + ) diff --git a/surfsense_backend/alembic/versions/52_rename_llm_preference_columns.py b/surfsense_backend/alembic/versions/52_rename_llm_preference_columns.py index cd1a1dbbc..08177ca70 100644 --- a/surfsense_backend/alembic/versions/52_rename_llm_preference_columns.py +++ b/surfsense_backend/alembic/versions/52_rename_llm_preference_columns.py @@ -39,7 +39,7 @@ def upgrade(): """ ) - # Rename columns (only if they exist with old names) + # Rename columns (only if source exists and target doesn't already exist) op.execute( """ DO $$ @@ -47,6 +47,9 @@ def upgrade(): IF EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id' + ) AND NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id' ) THEN ALTER TABLE searchspaces RENAME COLUMN fast_llm_id TO agent_llm_id; END IF; @@ -61,6 +64,9 @@ def upgrade(): IF EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id' + ) AND NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id' ) THEN ALTER TABLE searchspaces RENAME COLUMN long_context_llm_id TO document_summary_llm_id; END IF; @@ -100,7 +106,7 @@ def downgrade(): """ ) - # Rename columns back + # Rename columns back (only if source exists and target doesn't already exist) op.execute( """ DO $$ @@ -108,6 +114,9 @@ def downgrade(): IF EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id' + ) AND NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id' ) THEN ALTER TABLE searchspaces RENAME COLUMN agent_llm_id TO fast_llm_id; END IF; @@ -122,6 +131,9 @@ def downgrade(): IF EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id' + ) AND NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id' ) THEN ALTER TABLE searchspaces RENAME COLUMN document_summary_llm_id TO long_context_llm_id; END IF; diff --git a/surfsense_backend/alembic/versions/55_rename_google_drive_connector_to_file.py b/surfsense_backend/alembic/versions/55_rename_google_drive_connector_to_file.py index 9ce57d95f..baaf1991f 100644 --- a/surfsense_backend/alembic/versions/55_rename_google_drive_connector_to_file.py +++ b/surfsense_backend/alembic/versions/55_rename_google_drive_connector_to_file.py @@ -60,14 +60,28 @@ def downgrade() -> None: connection = op.get_bind() - connection.execute( + # Only update if the target enum value exists (it won't on fresh databases) + result = connection.execute( text( """ - UPDATE documents - SET document_type = 'GOOGLE_DRIVE_CONNECTOR' - WHERE document_type = 'GOOGLE_DRIVE_FILE'; + SELECT EXISTS ( + SELECT 1 FROM pg_type t + JOIN pg_enum e ON t.oid = e.enumtypid + WHERE t.typname = 'documenttype' AND e.enumlabel = 'GOOGLE_DRIVE_CONNECTOR' + ); """ ) ) + enum_exists = result.scalar() - connection.commit() + if enum_exists: + connection.execute( + text( + """ + UPDATE documents + SET document_type = 'GOOGLE_DRIVE_CONNECTOR' + WHERE document_type = 'GOOGLE_DRIVE_FILE'; + """ + ) + ) + connection.commit() diff --git a/surfsense_backend/alembic/versions/5_remove_title_char_limit.py b/surfsense_backend/alembic/versions/5_remove_title_char_limit.py index 2e4cd56d1..afdbaa803 100644 --- a/surfsense_backend/alembic/versions/5_remove_title_char_limit.py +++ b/surfsense_backend/alembic/versions/5_remove_title_char_limit.py @@ -18,59 +18,77 @@ branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None -def upgrade() -> None: - # Alter Chat table - op.alter_column( - "chats", - "title", - existing_type=sa.String(200), - type_=sa.String(), - existing_nullable=False, +def table_exists(table_name: str) -> bool: + """Check if a table exists in the database.""" + conn = op.get_bind() + result = conn.execute( + sa.text( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)" + ), + {"table_name": table_name}, ) + return result.scalar() + + +def upgrade() -> None: + # Alter Chat table (may not exist on fresh databases, removed in migration 49) + if table_exists("chats"): + op.alter_column( + "chats", + "title", + existing_type=sa.String(200), + type_=sa.String(), + existing_nullable=False, + ) # Alter Document table - op.alter_column( - "documents", - "title", - existing_type=sa.String(200), - type_=sa.String(), - existing_nullable=False, - ) + if table_exists("documents"): + op.alter_column( + "documents", + "title", + existing_type=sa.String(200), + type_=sa.String(), + existing_nullable=False, + ) # Alter Podcast table - op.alter_column( - "podcasts", - "title", - existing_type=sa.String(200), - type_=sa.String(), - existing_nullable=False, - ) + if table_exists("podcasts"): + op.alter_column( + "podcasts", + "title", + existing_type=sa.String(200), + type_=sa.String(), + existing_nullable=False, + ) def downgrade() -> None: # Revert Chat table - op.alter_column( - "chats", - "title", - existing_type=sa.String(), - type_=sa.String(200), - existing_nullable=False, - ) + if table_exists("chats"): + op.alter_column( + "chats", + "title", + existing_type=sa.String(), + type_=sa.String(200), + existing_nullable=False, + ) # Revert Document table - op.alter_column( - "documents", - "title", - existing_type=sa.String(), - type_=sa.String(200), - existing_nullable=False, - ) + if table_exists("documents"): + op.alter_column( + "documents", + "title", + existing_type=sa.String(), + type_=sa.String(200), + existing_nullable=False, + ) # Revert Podcast table - op.alter_column( - "podcasts", - "title", - existing_type=sa.String(), - type_=sa.String(200), - existing_nullable=False, - ) + if table_exists("podcasts"): + op.alter_column( + "podcasts", + "title", + existing_type=sa.String(), + type_=sa.String(200), + existing_nullable=False, + ) diff --git a/surfsense_backend/alembic/versions/62_add_user_profile_columns.py b/surfsense_backend/alembic/versions/62_add_user_profile_columns.py new file mode 100644 index 000000000..a6fef0c5b --- /dev/null +++ b/surfsense_backend/alembic/versions/62_add_user_profile_columns.py @@ -0,0 +1,72 @@ +"""Add display_name and avatar_url columns to user table + +This migration adds: +- display_name column for user's full name from OAuth +- avatar_url column for user's profile picture URL from OAuth + +Revision ID: 62 +Revises: 61 +""" + +from collections.abc import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "62" +down_revision: str | None = "61" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Add display_name and avatar_url columns to user table.""" + + # Add display_name column (nullable for existing users) + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'user' AND column_name = 'display_name' + ) THEN + ALTER TABLE "user" + ADD COLUMN display_name VARCHAR; + END IF; + END$$; + """ + ) + + # Add avatar_url column (nullable for existing users) + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'user' AND column_name = 'avatar_url' + ) THEN + ALTER TABLE "user" + ADD COLUMN avatar_url VARCHAR; + END IF; + END$$; + """ + ) + + +def downgrade() -> None: + """Remove display_name and avatar_url columns from user table.""" + + op.execute( + """ + ALTER TABLE "user" + DROP COLUMN IF EXISTS avatar_url; + """ + ) + op.execute( + """ + ALTER TABLE "user" + DROP COLUMN IF EXISTS display_name; + """ + ) diff --git a/surfsense_backend/alembic/versions/63_add_message_author_id.py b/surfsense_backend/alembic/versions/63_add_message_author_id.py new file mode 100644 index 000000000..2fc3f0b4c --- /dev/null +++ b/surfsense_backend/alembic/versions/63_add_message_author_id.py @@ -0,0 +1,47 @@ +"""Add author_id column to new_chat_messages table + +Revision ID: 63 +Revises: 62 +""" + +from collections.abc import Sequence + +from alembic import op + +revision: str = "63" +down_revision: str | None = "62" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Add author_id column to new_chat_messages table.""" + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'new_chat_messages' AND column_name = 'author_id' + ) THEN + ALTER TABLE new_chat_messages + ADD COLUMN author_id UUID REFERENCES "user"(id) ON DELETE SET NULL; + + CREATE INDEX IF NOT EXISTS ix_new_chat_messages_author_id + ON new_chat_messages(author_id); + END IF; + END$$; + """ + ) + + +def downgrade() -> None: + """Remove author_id column from new_chat_messages table.""" + op.execute( + """ + DROP INDEX IF EXISTS ix_new_chat_messages_author_id; + ALTER TABLE new_chat_messages + DROP COLUMN IF EXISTS author_id; + """ + ) + diff --git a/surfsense_backend/alembic/versions/a1b2c3d4e5f6_add_mcp_connector_type.py b/surfsense_backend/alembic/versions/a1b2c3d4e5f6_add_mcp_connector_type.py new file mode 100644 index 000000000..e47bb2fa3 --- /dev/null +++ b/surfsense_backend/alembic/versions/a1b2c3d4e5f6_add_mcp_connector_type.py @@ -0,0 +1,37 @@ +"""Add MCP connector type + +Revision ID: a1b2c3d4e5f6 +Revises: 61 +Create Date: 2026-01-09 15:19:51.827647 + +""" +from collections.abc import Sequence + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = 'a1b2c3d4e5f6' +down_revision: str | None = '61' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Add MCP_CONNECTOR to SearchSourceConnectorType enum.""" + # Add new enum value using raw SQL + op.execute( + """ + ALTER TYPE searchsourceconnectortype ADD VALUE IF NOT EXISTS 'MCP_CONNECTOR'; + """ + ) + + +def downgrade() -> None: + """Remove MCP_CONNECTOR from SearchSourceConnectorType enum.""" + # Note: PostgreSQL does not support removing enum values directly. + # To downgrade, you would need to: + # 1. Create a new enum without MCP_CONNECTOR + # 2. Alter the column to use the new enum + # 3. Drop the old enum + # This is left as a manual operation if needed. + pass diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index 6c8deb409..9675521f5 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -20,7 +20,7 @@ from app.agents.new_chat.system_prompt import ( build_configurable_system_prompt, build_surfsense_system_prompt, ) -from app.agents.new_chat.tools import build_tools +from app.agents.new_chat.tools.registry import build_tools_async from app.services.connector_service import ConnectorService # ============================================================================= @@ -28,7 +28,7 @@ from app.services.connector_service import ConnectorService # ============================================================================= -def create_surfsense_deep_agent( +async def create_surfsense_deep_agent( llm: ChatLiteLLM, search_space_id: int, db_session: AsyncSession, @@ -120,8 +120,8 @@ def create_surfsense_deep_agent( "firecrawl_api_key": firecrawl_api_key, } - # Build tools using the registry - tools = build_tools( + # Build tools using the async registry (includes MCP tools) + tools = await build_tools_async( dependencies=dependencies, enabled_tools=enabled_tools, disabled_tools=disabled_tools, diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_client.py b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py new file mode 100644 index 000000000..d91065661 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_client.py @@ -0,0 +1,188 @@ +"""MCP Client Wrapper. + +This module provides a client for communicating with MCP servers via stdio transport. +It handles server lifecycle management, tool discovery, and tool execution. +""" + +import logging +import os +from contextlib import asynccontextmanager +from typing import Any + +from mcp import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + +logger = logging.getLogger(__name__) + + +class MCPClient: + """Client for communicating with an MCP server.""" + + def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None): + """Initialize MCP client. + + Args: + command: Command to spawn the MCP server (e.g., "uvx", "node") + args: Arguments for the command (e.g., ["mcp-server-git"]) + env: Optional environment variables for the server process + + """ + self.command = command + self.args = args + self.env = env or {} + self.session: ClientSession | None = None + + @asynccontextmanager + async def connect(self): + """Connect to the MCP server and manage its lifecycle. + + Yields: + ClientSession: Active MCP session for making requests + + """ + try: + # Merge env vars with current environment + server_env = os.environ.copy() + server_env.update(self.env) + + # Create server parameters with env + server_params = StdioServerParameters( + command=self.command, + args=self.args, + env=server_env + ) + + # Spawn server process and create session + # Note: Cannot combine these context managers because ClientSession + # needs the read/write streams from stdio_client + async with stdio_client(server=server_params) as (read, write): + async with ClientSession(read, write) as session: + # Initialize the connection + await session.initialize() + self.session = session + logger.info( + "Connected to MCP server: %s %s", + self.command, + " ".join(self.args), + ) + yield session + + except Exception as e: + logger.error("Failed to connect to MCP server: %s", e, exc_info=True) + raise + finally: + self.session = None + logger.info("Disconnected from MCP server: %s", self.command) + + async def list_tools(self) -> list[dict[str, Any]]: + """List all tools available from the MCP server. + + Returns: + List of tool definitions with name, description, and input schema + + Raises: + RuntimeError: If not connected to server + + """ + if not self.session: + raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'") + + try: + # Call tools/list RPC method + response = await self.session.list_tools() + + tools = [] + for tool in response.tools: + tools.append({ + "name": tool.name, + "description": tool.description or "", + "input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {}, + }) + + logger.info("Listed %d tools from MCP server", len(tools)) + return tools + + except Exception as e: + logger.error("Failed to list tools from MCP server: %s", e, exc_info=True) + raise + + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: + """Call a tool on the MCP server. + + Args: + tool_name: Name of the tool to call + arguments: Arguments to pass to the tool + + Returns: + Tool execution result + + Raises: + RuntimeError: If not connected to server + + """ + if not self.session: + raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'") + + try: + logger.info("Calling MCP tool '%s' with arguments: %s", tool_name, arguments) + + # Call tools/call RPC method + response = await self.session.call_tool(tool_name, arguments=arguments) + + # Extract content from response + result = [] + for content in response.content: + if hasattr(content, "text"): + result.append(content.text) + elif hasattr(content, "data"): + result.append(str(content.data)) + else: + result.append(str(content)) + + result_str = "\n".join(result) if result else "" + logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200]) + return result_str + + except RuntimeError as e: + # Handle validation errors from MCP server responses + # Some MCP servers (like server-memory) return extra fields not in their schema + if "Invalid structured content" in str(e): + logger.warning("MCP server returned data not matching its schema, but continuing: %s", e) + # Try to extract result from error message or return a success message + return "Operation completed (server returned unexpected format)" + raise + except (ValueError, TypeError, AttributeError, KeyError) as e: + logger.error("Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True) + return f"Error calling tool: {e!s}" + + +async def test_mcp_connection( + command: str, args: list[str], env: dict[str, str] | None = None +) -> dict[str, Any]: + """Test connection to an MCP server and fetch available tools. + + Args: + command: Command to spawn the MCP server + args: Arguments for the command + env: Optional environment variables + + Returns: + Dict with connection status and available tools + + """ + client = MCPClient(command, args, env) + + try: + async with client.connect(): + tools = await client.list_tools() + return { + "status": "success", + "message": f"Connected successfully. Found {len(tools)} tools.", + "tools": tools, + } + except (RuntimeError, ConnectionError, TimeoutError, OSError) as e: + return { + "status": "error", + "message": f"Failed to connect: {e!s}", + "tools": [], + } diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py new file mode 100644 index 000000000..81c7d074f --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -0,0 +1,189 @@ +"""MCP Tool Factory. + +This module creates LangChain tools from MCP servers using the Model Context Protocol. +Tools are dynamically discovered from MCP servers - no manual configuration needed. + +This implements real MCP protocol support similar to Cursor's implementation. +""" + +import logging +from typing import Any + +from langchain_core.tools import StructuredTool +from pydantic import BaseModel, create_model +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.tools.mcp_client import MCPClient +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + + +def _create_dynamic_input_model_from_schema( + tool_name: str, input_schema: dict[str, Any], +) -> type[BaseModel]: + """Create a Pydantic model from MCP tool's JSON schema. + + Args: + tool_name: Name of the tool (used for model class name) + input_schema: JSON schema from MCP server + + Returns: + Pydantic model class for tool input validation + + """ + properties = input_schema.get("properties", {}) + required_fields = input_schema.get("required", []) + + # Build Pydantic field definitions + field_definitions = {} + for param_name, param_schema in properties.items(): + param_description = param_schema.get("description", "") + is_required = param_name in required_fields + + # Use Any type for complex schemas to preserve structure + # This allows the MCP server to do its own validation + from typing import Any as AnyType + + from pydantic import Field + + if is_required: + field_definitions[param_name] = (AnyType, Field(..., description=param_description)) + else: + field_definitions[param_name] = ( + AnyType | None, + Field(None, description=param_description), + ) + + # Create dynamic model + model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input" + return create_model(model_name, **field_definitions) + + +async def _create_mcp_tool_from_definition( + tool_def: dict[str, Any], + mcp_client: MCPClient, +) -> StructuredTool: + """Create a LangChain tool from an MCP tool definition. + + Args: + tool_def: Tool definition from MCP server with name, description, input_schema + mcp_client: MCP client instance for calling the tool + + Returns: + LangChain StructuredTool instance + + """ + tool_name = tool_def.get("name", "unnamed_tool") + tool_description = tool_def.get("description", "No description provided") + input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) + + # Log the actual schema for debugging + logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}") + + # Create dynamic input model from schema + input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) + + async def mcp_tool_call(**kwargs) -> str: + """Execute the MCP tool call via the client.""" + logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}") + + try: + # Connect to server and call tool + async with mcp_client.connect(): + result = await mcp_client.call_tool(tool_name, kwargs) + return str(result) + except Exception as e: + error_msg = f"MCP tool '{tool_name}' failed: {e!s}" + logger.exception(error_msg) + return f"Error: {error_msg}" + + # Create StructuredTool with response_format to preserve exact schema + tool = StructuredTool( + name=tool_name, + description=tool_description, + coroutine=mcp_tool_call, + args_schema=input_model, + # Store the original MCP schema as metadata so we can access it later + metadata={"mcp_input_schema": input_schema}, + ) + + logger.info(f"Created MCP tool: '{tool_name}'") + return tool + + +async def load_mcp_tools( + session: AsyncSession, search_space_id: int, +) -> list[StructuredTool]: + """Load all MCP tools from user's active MCP server connectors. + + This discovers tools dynamically from MCP servers using the protocol. + + Args: + session: Database session + search_space_id: User's search space ID + + Returns: + List of LangChain StructuredTool instances + + """ + try: + # Fetch all MCP connectors for this search space + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.search_space_id == search_space_id, + ), + ) + + tools: list[StructuredTool] = [] + for connector in result.scalars(): + try: + # Extract server config + config = connector.config or {} + server_config = config.get("server_config", {}) + + command = server_config.get("command") + args = server_config.get("args", []) + env = server_config.get("env", {}) + + if not command: + logger.warning(f"MCP connector {connector.id} missing command, skipping") + continue + + # Create MCP client + mcp_client = MCPClient(command, args, env) + + # Connect and discover tools + async with mcp_client.connect(): + tool_definitions = await mcp_client.list_tools() + + logger.info( + f"Discovered {len(tool_definitions)} tools from MCP server " + f"'{command}' (connector {connector.id})" + ) + + # Create LangChain tools from definitions + for tool_def in tool_definitions: + try: + tool = await _create_mcp_tool_from_definition(tool_def, mcp_client) + tools.append(tool) + except Exception as e: + logger.exception( + f"Failed to create tool '{tool_def.get('name')}' " + f"from connector {connector.id}: {e!s}", + ) + + except Exception as e: + logger.exception( + f"Failed to load tools from MCP connector {connector.id}: {e!s}", + ) + + logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}") + return tools + + except Exception as e: + logger.exception(f"Failed to load MCP tools: {e!s}") + return [] diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index c7439bf8f..bb8708b2b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -1,5 +1,4 @@ -""" -Tools registry for SurfSense deep agent. +"""Tools registry for SurfSense deep agent. This module provides a registry pattern for managing tools in the SurfSense agent. It makes it easy for OSS contributors to add new tools by: @@ -37,6 +36,7 @@ Example of adding a new tool: ), """ +import logging from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -46,6 +46,7 @@ from langchain_core.tools import BaseTool from .display_image import create_display_image_tool from .knowledge_base import create_search_knowledge_base_tool from .link_preview import create_link_preview_tool +from .mcp_tool import load_mcp_tools from .podcast import create_generate_podcast_tool from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs import create_search_surfsense_docs_tool @@ -57,8 +58,7 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool @dataclass class ToolDefinition: - """ - Definition of a tool that can be added to the agent. + """Definition of a tool that can be added to the agent. Attributes: name: Unique identifier for the tool @@ -66,6 +66,7 @@ class ToolDefinition: factory: Callable that creates the tool. Receives a dict of dependencies. requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session") enabled_by_default: Whether the tool is enabled when no explicit config is provided + """ name: str @@ -178,8 +179,7 @@ def build_tools( disabled_tools: list[str] | None = None, additional_tools: list[BaseTool] | None = None, ) -> list[BaseTool]: - """ - Build the list of tools for the agent. + """Build the list of tools for the agent. Args: dependencies: Dict containing all possible dependencies: @@ -206,6 +206,7 @@ def build_tools( # Add custom tools tools = build_tools(deps, additional_tools=[my_custom_tool]) + """ # Determine which tools to enable if enabled_tools is not None: @@ -226,8 +227,9 @@ def build_tools( # Check that all required dependencies are provided missing_deps = [dep for dep in tool_def.requires if dep not in dependencies] if missing_deps: + msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}" raise ValueError( - f"Tool '{tool_def.name}' requires dependencies: {missing_deps}" + msg, ) # Create the tool @@ -239,3 +241,61 @@ def build_tools( tools.extend(additional_tools) return tools + + +async def build_tools_async( + dependencies: dict[str, Any], + enabled_tools: list[str] | None = None, + disabled_tools: list[str] | None = None, + additional_tools: list[BaseTool] | None = None, + include_mcp_tools: bool = True, +) -> list[BaseTool]: + """Async version of build_tools that also loads MCP tools from database. + + Design Note: + This function exists because MCP tools require database queries to load user configs, + while built-in tools are created synchronously from static code. + + Alternative: We could make build_tools() itself async and always query the database, + but that would force async everywhere even when only using built-in tools. The current + design keeps the simple case (static tools only) synchronous while supporting dynamic + database-loaded tools through this async wrapper. + + Args: + dependencies: Dict containing all possible dependencies + enabled_tools: Explicit list of tool names to enable. If None, uses defaults. + disabled_tools: List of tool names to disable (applied after enabled_tools). + additional_tools: Extra tools to add (e.g., custom tools not in registry). + include_mcp_tools: Whether to load user's MCP tools from database. + + Returns: + List of configured tool instances ready for the agent, including MCP tools. + + """ + # Build standard tools + tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools) + + # Load MCP tools if requested and dependencies are available + if ( + include_mcp_tools + and "db_session" in dependencies + and "search_space_id" in dependencies + ): + try: + mcp_tools = await load_mcp_tools( + dependencies["db_session"], dependencies["search_space_id"], + ) + tools.extend(mcp_tools) + logging.info( + f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}", + ) + except Exception as e: + # Log error but don't fail - just continue without MCP tools + logging.exception(f"Failed to load MCP tools: {e!s}") + + # Log all tools being returned to agent + logging.info( + f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}", + ) + + return tools diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 00a0c27e6..9b245ba44 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -80,6 +80,7 @@ class SearchSourceConnectorType(str, Enum): WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR" BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR" CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR" + MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools class LiteLLMProvider(str, Enum): @@ -412,8 +413,17 @@ class NewChatMessage(BaseModel, TimestampMixin): index=True, ) - # Relationship + # Track who sent this message (for shared chats) + author_id = Column( + UUID(as_uuid=True), + ForeignKey("user.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + + # Relationships thread = relationship("NewChatThread", back_populates="messages") + author = relationship("User") class Document(BaseModel, TimestampMixin): @@ -611,7 +621,8 @@ class SearchSourceConnector(BaseModel, TimestampMixin): "search_space_id", "user_id", "connector_type", - name="uq_searchspace_user_connector_type", + "name", + name="uq_searchspace_user_connector_type_name", ), ) @@ -919,6 +930,10 @@ if config.AUTH_TYPE == "GOOGLE": ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") + # User profile from OAuth + display_name = Column(String, nullable=True) + avatar_url = Column(String, nullable=True) + else: class User(SQLAlchemyBaseUserTableUUID, Base): @@ -958,6 +973,10 @@ else: ) pages_used = Column(Integer, nullable=False, default=0, server_default="0") + # User profile (can be set manually for non-OAuth users) + display_name = Column(String, nullable=True) + avatar_url = Column(String, nullable=True) + engine = create_async_engine(DATABASE_URL) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index fb5808307..e4dc5714a 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -411,11 +411,9 @@ async def get_thread_messages( Requires CHATS_READ permission. """ try: - # Get thread with messages + # Get thread first result = await session.execute( - select(NewChatThread) - .options(selectinload(NewChatThread.messages)) - .filter(NewChatThread.id == thread_id) + select(NewChatThread).filter(NewChatThread.id == thread_id) ) thread = result.scalars().first() @@ -434,6 +432,15 @@ async def get_thread_messages( # Check thread-level access based on visibility await check_thread_access(session, thread, user) + # Get messages with their authors loaded + messages_result = await session.execute( + select(NewChatMessage) + .options(selectinload(NewChatMessage.author)) + .filter(NewChatMessage.thread_id == thread_id) + .order_by(NewChatMessage.created_at) + ) + db_messages = messages_result.scalars().all() + # Return messages in the format expected by assistant-ui messages = [ NewChatMessageRead( @@ -442,8 +449,11 @@ async def get_thread_messages( role=msg.role, content=msg.content, created_at=msg.created_at, + author_id=msg.author_id, + author_display_name=msg.author.display_name if msg.author else None, + author_avatar_url=msg.author.avatar_url if msg.author else None, ) - for msg in thread.messages + for msg in db_messages ] return ThreadHistoryLoadResponse(messages=messages) @@ -782,6 +792,7 @@ async def append_message( thread_id=thread_id, role=message_role, content=message.content, + author_id=user.id, ) session.add(db_message) diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index 1aca7c43b..9c6c2fc3f 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -7,6 +7,13 @@ PUT /search-source-connectors/{connector_id} - Update a specific connector DELETE /search-source-connectors/{connector_id} - Delete a specific connector POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space +MCP (Model Context Protocol) Connector routes: +POST /connectors/mcp - Create a new MCP connector with custom API tools +GET /connectors/mcp - List all MCP connectors for the current user's search space +GET /connectors/mcp/{connector_id} - Get a specific MCP connector with tools config +PUT /connectors/mcp/{connector_id} - Update an MCP connector's tools config +DELETE /connectors/mcp/{connector_id} - Delete an MCP connector + Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space. Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space. """ @@ -32,6 +39,9 @@ from app.db import ( ) from app.schemas import ( GoogleDriveIndexRequest, + MCPConnectorCreate, + MCPConnectorRead, + MCPConnectorUpdate, SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorRead, @@ -128,18 +138,20 @@ async def create_search_source_connector( # Check if a connector with the same type already exists for this search space # (for non-OAuth connectors that don't support multiple accounts) - result = await session.execute( - select(SearchSourceConnector).filter( - SearchSourceConnector.search_space_id == search_space_id, - SearchSourceConnector.connector_type == connector.connector_type, - ) - ) - existing_connector = result.scalars().first() - if existing_connector: - raise HTTPException( - status_code=409, - detail=f"A connector with type {connector.connector_type} already exists in this search space.", + # Exception: MCP_CONNECTOR can have multiple instances with different names + if connector.connector_type != SearchSourceConnectorType.MCP_CONNECTOR: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.connector_type == connector.connector_type, + ) ) + existing_connector = result.scalars().first() + if existing_connector: + raise HTTPException( + status_code=409, + detail=f"A connector with type {connector.connector_type} already exists in this search space.", + ) # Prepare connector data connector_data = connector.model_dump() @@ -2054,3 +2066,348 @@ async def run_bookstack_indexing( indexing_function=index_bookstack_pages, update_timestamp_func=_update_connector_timestamp_by_id, ) + + +# ============================================================================= +# MCP Connector Routes +# ============================================================================= + + +@router.post("/connectors/mcp", response_model=MCPConnectorRead, status_code=201) +async def create_mcp_connector( + connector_data: MCPConnectorCreate, + search_space_id: int = Query(..., description="Search space ID"), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Create a new MCP (Model Context Protocol) connector. + + MCP connectors allow users to connect to MCP servers (like in Cursor). + Tools are auto-discovered from the server - no manual configuration needed. + + Args: + connector_data: MCP server configuration (command, args, env) + search_space_id: ID of the search space to attach the connector to + session: Database session + user: Current authenticated user + + Returns: + Created MCP connector with server configuration + + Raises: + HTTPException: If search space not found or permission denied + """ + try: + # Check user has permission to create connectors + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_CREATE.value, + "You don't have permission to create connectors in this search space", + ) + + # Create the connector with server config + db_connector = SearchSourceConnector( + name=connector_data.name, + connector_type=SearchSourceConnectorType.MCP_CONNECTOR, + is_indexable=False, # MCP connectors are not indexable + config={"server_config": connector_data.server_config.model_dump()}, + periodic_indexing_enabled=False, + indexing_frequency_minutes=None, + search_space_id=search_space_id, + user_id=user.id, + ) + + session.add(db_connector) + await session.commit() + await session.refresh(db_connector) + + logger.info( + f"Created MCP connector {db_connector.id} for server '{connector_data.server_config.command}' " + f"for user {user.id} in search space {search_space_id}" + ) + + # Convert to read schema + connector_read = SearchSourceConnectorRead.model_validate(db_connector) + return MCPConnectorRead.from_connector(connector_read) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to create MCP connector: {e!s}", exc_info=True) + await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to create MCP connector: {e!s}" + ) from e + + +@router.get("/connectors/mcp", response_model=list[MCPConnectorRead]) +async def list_mcp_connectors( + search_space_id: int = Query(..., description="Search space ID"), + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + List all MCP connectors for a search space. + + Args: + search_space_id: ID of the search space + session: Database session + user: Current authenticated user + + Returns: + List of MCP connectors with their tool configurations + """ + try: + # Check user has permission to read connectors + await check_permission( + session, + user, + search_space_id, + Permission.CONNECTORS_READ.value, + "You don't have permission to view connectors in this search space", + ) + + # Fetch MCP connectors + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + SearchSourceConnector.search_space_id == search_space_id, + ) + ) + + connectors = result.scalars().all() + return [ + MCPConnectorRead.from_connector(SearchSourceConnectorRead.model_validate(c)) + for c in connectors + ] + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to list MCP connectors: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to list MCP connectors: {e!s}" + ) from e + + +@router.get("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead) +async def get_mcp_connector( + connector_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Get a specific MCP connector by ID. + + Args: + connector_id: ID of the connector + session: Database session + user: Current authenticated user + + Returns: + MCP connector with tool configurations + """ + try: + # Fetch connector + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + # Check user has permission to read connectors + await check_permission( + session, + user, + connector.search_space_id, + Permission.CONNECTORS_READ.value, + "You don't have permission to view this connector", + ) + + connector_read = SearchSourceConnectorRead.model_validate(connector) + return MCPConnectorRead.from_connector(connector_read) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get MCP connector: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to get MCP connector: {e!s}" + ) from e + + +@router.put("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead) +async def update_mcp_connector( + connector_id: int, + connector_update: MCPConnectorUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Update an MCP connector. + + Args: + connector_id: ID of the connector to update + connector_update: Updated connector data + session: Database session + user: Current authenticated user + + Returns: + Updated MCP connector + """ + try: + # Fetch connector + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + # Check user has permission to update connectors + await check_permission( + session, + user, + connector.search_space_id, + Permission.CONNECTORS_UPDATE.value, + "You don't have permission to update this connector", + ) + + # Update fields + if connector_update.name is not None: + connector.name = connector_update.name + + if connector_update.server_config is not None: + connector.config = { + "server_config": connector_update.server_config.model_dump() + } + + connector.updated_at = datetime.now(UTC) + + await session.commit() + await session.refresh(connector) + + logger.info(f"Updated MCP connector {connector_id}") + + connector_read = SearchSourceConnectorRead.model_validate(connector) + return MCPConnectorRead.from_connector(connector_read) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to update MCP connector: {e!s}", exc_info=True) + await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to update MCP connector: {e!s}" + ) from e + + +@router.delete("/connectors/mcp/{connector_id}", status_code=204) +async def delete_mcp_connector( + connector_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Delete an MCP connector. + + Args: + connector_id: ID of the connector to delete + session: Database session + user: Current authenticated user + """ + try: + # Fetch connector + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.MCP_CONNECTOR, + ) + ) + connector = result.scalars().first() + + if not connector: + raise HTTPException(status_code=404, detail="MCP connector not found") + + # Check user has permission to delete connectors + await check_permission( + session, + user, + connector.search_space_id, + Permission.CONNECTORS_DELETE.value, + "You don't have permission to delete this connector", + ) + + await session.delete(connector) + await session.commit() + + logger.info(f"Deleted MCP connector {connector_id}") + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to delete MCP connector: {e!s}", exc_info=True) + await session.rollback() + raise HTTPException( + status_code=500, detail=f"Failed to delete MCP connector: {e!s}" + ) from e + + +@router.post("/connectors/mcp/test") +async def test_mcp_server_connection( + server_config: dict = Body(...), + user: User = Depends(current_active_user), +): + """ + Test connection to an MCP server and fetch available tools. + + This endpoint allows users to test their MCP server configuration + before saving it, similar to Cursor's flow. + + Args: + server_config: Server configuration with command, args, env + user: Current authenticated user + + Returns: + Connection status and list of available tools + """ + try: + from app.agents.new_chat.tools.mcp_client import test_mcp_connection + + command = server_config.get("command") + args = server_config.get("args", []) + env = server_config.get("env", {}) + + if not command: + raise HTTPException(status_code=400, detail="Server command is required") + + # Test the connection + result = await test_mcp_connection(command, args, env) + + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to test MCP connection: {e!s}", exc_info=True) + return { + "status": "error", + "message": f"Failed to test connection: {e!s}", + "tools": [], + } diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index a8bde7ed9..076ac5915 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -55,6 +55,10 @@ from .rbac_schemas import ( UserSearchSpaceAccess, ) from .search_source_connector import ( + MCPConnectorCreate, + MCPConnectorRead, + MCPConnectorUpdate, + MCPServerConfig, SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorRead, @@ -108,6 +112,11 @@ __all__ = [ "LogFilter", "LogRead", "LogUpdate", + # Search source connector schemas + "MCPConnectorCreate", + "MCPConnectorRead", + "MCPConnectorUpdate", + "MCPServerConfig", "MembershipRead", "MembershipReadWithUser", "MembershipUpdate", @@ -135,7 +144,6 @@ __all__ = [ "RoleCreate", "RoleRead", "RoleUpdate", - # Search source connector schemas "SearchSourceConnectorBase", "SearchSourceConnectorCreate", "SearchSourceConnectorRead", diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index e6dbcd920..3734b0470 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -38,6 +38,9 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel): """Schema for reading a message.""" thread_id: int + author_id: UUID | None = None + author_display_name: str | None = None + author_avatar_url: str | None = None model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/search_source_connector.py b/surfsense_backend/app/schemas/search_source_connector.py index dbe4dce1f..e27cc775c 100644 --- a/surfsense_backend/app/schemas/search_source_connector.py +++ b/surfsense_backend/app/schemas/search_source_connector.py @@ -23,7 +23,7 @@ class SearchSourceConnectorBase(BaseModel): @field_validator("config") @classmethod def validate_config_for_connector_type( - cls, config: dict[str, Any], values: dict[str, Any] + cls, config: dict[str, Any], values: dict[str, Any], ) -> dict[str, Any]: connector_type = values.data.get("connector_type") return validate_connector_config(connector_type, config) @@ -38,15 +38,18 @@ class SearchSourceConnectorBase(BaseModel): """ if self.periodic_indexing_enabled: if not self.is_indexable: + msg = "periodic_indexing_enabled can only be True for indexable connectors" raise ValueError( - "periodic_indexing_enabled can only be True for indexable connectors" + msg, ) if self.indexing_frequency_minutes is None: + msg = "indexing_frequency_minutes is required when periodic_indexing_enabled is True" raise ValueError( - "indexing_frequency_minutes is required when periodic_indexing_enabled is True" + msg, ) if self.indexing_frequency_minutes <= 0: - raise ValueError("indexing_frequency_minutes must be greater than 0") + msg = "indexing_frequency_minutes must be greater than 0" + raise ValueError(msg) return self @@ -70,3 +73,63 @@ class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampMod user_id: uuid.UUID model_config = ConfigDict(from_attributes=True) + + +# ============================================================================= +# MCP-specific schemas +# ============================================================================= + + +class MCPServerConfig(BaseModel): + """Configuration for an MCP server connection (similar to Cursor's config).""" + + command: str # e.g., "uvx", "node", "python" + args: list[str] = [] # e.g., ["mcp-server-git", "--repository", "/path"] + env: dict[str, str] = {} # Environment variables for the server process + transport: str = "stdio" # "stdio" | "sse" | "http" (stdio is most common) + + +class MCPConnectorCreate(BaseModel): + """Schema for creating an MCP connector.""" + + name: str + server_config: MCPServerConfig + + +class MCPConnectorUpdate(BaseModel): + """Schema for updating an MCP connector.""" + + name: str | None = None + server_config: MCPServerConfig | None = None + + +class MCPConnectorRead(BaseModel): + """Schema for reading an MCP connector with server config.""" + + id: int + name: str + connector_type: SearchSourceConnectorType + server_config: MCPServerConfig + search_space_id: int + user_id: uuid.UUID + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) + + @classmethod + def from_connector(cls, connector: SearchSourceConnectorRead) -> "MCPConnectorRead": + """Convert from base SearchSourceConnectorRead.""" + config = connector.config or {} + server_config = MCPServerConfig(**config.get("server_config", {})) + + return cls( + id=connector.id, + name=connector.name, + connector_type=connector.connector_type, + server_config=server_config, + search_space_id=connector.search_space_id, + user_id=connector.user_id, + created_at=connector.created_at, + updated_at=connector.updated_at, + ) diff --git a/surfsense_backend/app/schemas/users.py b/surfsense_backend/app/schemas/users.py index a8e0cfac8..88d0a4f37 100644 --- a/surfsense_backend/app/schemas/users.py +++ b/surfsense_backend/app/schemas/users.py @@ -6,6 +6,8 @@ from fastapi_users import schemas class UserRead(schemas.BaseUser[uuid.UUID]): pages_limit: int pages_used: int + display_name: str | None = None + avatar_url: str | None = None class UserCreate(schemas.BaseUserCreate): @@ -13,4 +15,5 @@ class UserCreate(schemas.BaseUserCreate): class UserUpdate(schemas.BaseUserUpdate): - pass + display_name: str | None = None + avatar_url: str | None = None diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index a74f134dc..5f8cd638b 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -237,7 +237,7 @@ async def stream_new_chat( checkpointer = await get_checkpointer() # Create the deep agent with checkpointer and configurable prompts - agent = create_surfsense_deep_agent( + agent = await create_surfsense_deep_agent( llm=llm, search_space_id=search_space_id, db_session=session, diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index dd284307f..e86eb752b 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -1,6 +1,7 @@ import logging import uuid +import httpx from fastapi import Depends, Request, Response from fastapi.responses import JSONResponse, RedirectResponse from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models @@ -46,6 +47,71 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = SECRET verification_token_secret = SECRET + async def oauth_callback( + self, + oauth_name: str, + access_token: str, + account_id: str, + account_email: str, + expires_at: int | None = None, + refresh_token: str | None = None, + request: Request | None = None, + *, + associate_by_email: bool = False, + is_verified_by_default: bool = False, + ) -> User: + """ + Override OAuth callback to capture Google profile data (name, avatar). + """ + # Call parent implementation to create/get user + user = await super().oauth_callback( + oauth_name, + access_token, + account_id, + account_email, + expires_at, + refresh_token, + request, + associate_by_email=associate_by_email, + is_verified_by_default=is_verified_by_default, + ) + + # Fetch and store Google profile data if not already set + if oauth_name == "google" and (not user.display_name or not user.avatar_url): + try: + async with httpx.AsyncClient() as client: + response = await client.get( + "https://people.googleapis.com/v1/people/me", + params={"personFields": "names,photos"}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + response.raise_for_status() + profile = response.json() + + update_dict = {} + + # Extract name from names array + names = profile.get("names", []) + if not user.display_name and names: + display_name = names[0].get("displayName") + if display_name: + update_dict["display_name"] = display_name + + # Extract photo URL from photos array + photos = profile.get("photos", []) + if not user.avatar_url and photos: + photo_url = photos[0].get("url") + if photo_url: + update_dict["avatar_url"] = photo_url + + if update_dict: + user = await self.user_db.update(user, update_dict) + + except Exception as e: + logger.warning(f"Failed to fetch Google profile: {e}") + + return user + async def on_after_register(self, user: User, request: Request | None = None): """ Called after a user registers. Creates a default search space for the user diff --git a/surfsense_backend/app/utils/connector_naming.py b/surfsense_backend/app/utils/connector_naming.py index 731f419d6..a2b748a3a 100644 --- a/surfsense_backend/app/utils/connector_naming.py +++ b/surfsense_backend/app/utils/connector_naming.py @@ -8,9 +8,9 @@ from typing import Any from urllib.parse import urlparse from uuid import UUID -from sqlalchemy import func +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select +from sqlalchemy.sql import func from app.db import SearchSourceConnector, SearchSourceConnectorType @@ -27,6 +27,7 @@ BASE_NAME_FOR_TYPE = { SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord", SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence", SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable", + SearchSourceConnectorType.MCP_CONNECTOR: "Model Context Protocol (MCP)", } @@ -75,7 +76,7 @@ def extract_identifier_from_credentials( if ".atlassian.net" in hostname: return hostname.replace(".atlassian.net", "") return hostname - except Exception: + except (ValueError, TypeError, AttributeError): pass return None diff --git a/surfsense_backend/pyproject.toml b/surfsense_backend/pyproject.toml index e3e7583f8..83a00b4e4 100644 --- a/surfsense_backend/pyproject.toml +++ b/surfsense_backend/pyproject.toml @@ -57,6 +57,9 @@ dependencies = [ "chonkie[all]>=1.5.0", "langgraph-checkpoint-postgres>=3.0.2", "psycopg[binary,pool]>=3.3.2", + "mcp>=1.25.0", + "starlette>=0.40.0,<0.51.0", + "sse-starlette>=3.1.1,<3.1.2", ] [dependency-groups] diff --git a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx index 566e103ac..94c0626e6 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentsTableShell.tsx @@ -79,25 +79,17 @@ export function DocumentsTableShell({ [documents, sortKey, sortDesc] ); - // Filter out SURFSENSE_DOCS for selection purposes - const selectableDocs = React.useMemo( - () => sorted.filter((d) => d.document_type !== "SURFSENSE_DOCS"), - [sorted] - ); - - const allSelectedOnPage = - selectableDocs.length > 0 && selectableDocs.every((d) => selectedIds.has(d.id)); - const someSelectedOnPage = - selectableDocs.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage; + const allSelectedOnPage = sorted.length > 0 && sorted.every((d) => selectedIds.has(d.id)); + const someSelectedOnPage = sorted.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage; const toggleAll = (checked: boolean) => { const next = new Set(selectedIds); if (checked) - selectableDocs.forEach((d) => { + sorted.forEach((d) => { next.add(d.id); }); else - selectableDocs.forEach((d) => { + sorted.forEach((d) => { next.delete(d.id); }); setSelectedIds(next); @@ -238,10 +230,9 @@ export function DocumentsTableShell({ const icon = getDocumentTypeIcon(doc.document_type); const title = doc.title; const truncatedTitle = title.length > 30 ? `${title.slice(0, 30)}...` : title; - const isSurfsenseDoc = doc.document_type === "SURFSENSE_DOCS"; return ( !isSurfsenseDoc && toggleOne(doc.id, !!v)} - disabled={isSurfsenseDoc} + checked={selectedIds.has(doc.id)} + onCheckedChange={(v) => toggleOne(doc.id, !!v)} aria-label="Select row" /> diff --git a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx index 01891f05b..52eb3546c 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/documents/(manage)/page.tsx @@ -18,7 +18,7 @@ import { cacheKeys } from "@/lib/query-client/cache-keys"; import { DocumentsFilters } from "./components/DocumentsFilters"; import { DocumentsTableShell, type SortKey } from "./components/DocumentsTableShell"; import { PaginationControls } from "./components/PaginationControls"; -import type { ColumnVisibility, Document } from "./components/types"; +import type { ColumnVisibility } from "./components/types"; function useDebounced(value: T, delay = 250) { const [debounced, setDebounced] = useState(value); @@ -58,39 +58,30 @@ export default function DocumentsTable() { const { data: rawTypeCounts } = useAtomValue(documentTypeCountsAtom); const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom); - // Filter out SURFSENSE_DOCS from active types for regular documents API - const regularDocumentTypes = useMemo( - () => activeTypes.filter((t) => t !== "SURFSENSE_DOCS"), - [activeTypes] - ); - - // Check if only SURFSENSE_DOCS is selected (skip regular docs query) - const onlySurfsenseDocsSelected = activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS"; - - // Build query parameters for fetching documents (excluding SURFSENSE_DOCS type) + // Build query parameters for fetching documents const queryParams = useMemo( () => ({ search_space_id: searchSpaceId, page: pageIndex, page_size: pageSize, - ...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }), + ...(activeTypes.length > 0 && { document_types: activeTypes }), }), - [searchSpaceId, pageIndex, pageSize, regularDocumentTypes] + [searchSpaceId, pageIndex, pageSize, activeTypes] ); - // Build search query parameters (excluding SURFSENSE_DOCS type) + // Build search query parameters const searchQueryParams = useMemo( () => ({ search_space_id: searchSpaceId, page: pageIndex, page_size: pageSize, title: debouncedSearch.trim(), - ...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }), + ...(activeTypes.length > 0 && { document_types: activeTypes }), }), - [searchSpaceId, pageIndex, pageSize, regularDocumentTypes, debouncedSearch] + [searchSpaceId, pageIndex, pageSize, activeTypes, debouncedSearch] ); - // Use query for fetching documents (disabled when only SURFSENSE_DOCS is selected) + // Use query for fetching documents const { data: documentsResponse, isLoading: isDocumentsLoading, @@ -100,10 +91,10 @@ export default function DocumentsTable() { queryKey: cacheKeys.documents.globalQueryParams(queryParams), queryFn: () => documentsApiService.getDocuments({ queryParams }), staleTime: 3 * 60 * 1000, // 3 minutes - enabled: !!searchSpaceId && !debouncedSearch.trim() && !onlySurfsenseDocsSelected, + enabled: !!searchSpaceId && !debouncedSearch.trim(), }); - // Use query for searching documents (disabled when only SURFSENSE_DOCS is selected) + // Use query for searching documents const { data: searchResponse, isLoading: isSearchLoading, @@ -113,7 +104,7 @@ export default function DocumentsTable() { queryKey: cacheKeys.documents.globalQueryParams(searchQueryParams), queryFn: () => documentsApiService.searchDocuments({ queryParams: searchQueryParams }), staleTime: 3 * 60 * 1000, // 3 minutes - enabled: !!searchSpaceId && !!debouncedSearch.trim() && !onlySurfsenseDocsSelected, + enabled: !!searchSpaceId && !!debouncedSearch.trim(), }); // Determine if we should show SurfSense docs (when no type filter or SURFSENSE_DOCS is selected) @@ -163,64 +154,16 @@ export default function DocumentsTable() { }, [rawTypeCounts, surfsenseDocsResponse?.total]); // Extract documents and total based on search state - const regularDocuments = debouncedSearch.trim() + const documents = debouncedSearch.trim() ? searchResponse?.items || [] : documentsResponse?.items || []; - const regularTotal = debouncedSearch.trim() - ? searchResponse?.total || 0 - : documentsResponse?.total || 0; + const total = debouncedSearch.trim() ? searchResponse?.total || 0 : documentsResponse?.total || 0; - // Merge regular documents with SurfSense docs - const documents = useMemo(() => { - // If filtering by type and not including SURFSENSE_DOCS, only show regular docs - if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) { - return regularDocuments; - } - // If filtering only by SURFSENSE_DOCS, only show surfsense docs - if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") { - return surfsenseDocsAsDocuments; - } - // Otherwise, merge both (surfsense docs first) - return [...surfsenseDocsAsDocuments, ...regularDocuments]; - }, [regularDocuments, surfsenseDocsAsDocuments, activeTypes]); + const loading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading; + const error = debouncedSearch.trim() ? searchError : documentsError; - const total = useMemo(() => { - if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) { - return regularTotal; - } - if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") { - return surfsenseDocsResponse?.total || 0; - } - return regularTotal + (surfsenseDocsResponse?.total || 0); - }, [regularTotal, surfsenseDocsResponse?.total, activeTypes]); - - const loading = useMemo(() => { - // If only SURFSENSE_DOCS selected, only check surfsense loading - if (onlySurfsenseDocsSelected) { - return isSurfsenseDocsLoading; - } - // Otherwise check both regular docs and surfsense docs loading - const regularLoading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading; - return regularLoading || (showSurfsenseDocs && isSurfsenseDocsLoading); - }, [ - onlySurfsenseDocsSelected, - isSurfsenseDocsLoading, - debouncedSearch, - isSearchLoading, - isDocumentsLoading, - showSurfsenseDocs, - ]); - - const error = useMemo(() => { - // If only SURFSENSE_DOCS selected, no regular docs errors - if (onlySurfsenseDocsSelected) { - return null; - } - return debouncedSearch.trim() ? searchError : documentsError; - }, [onlySurfsenseDocsSelected, debouncedSearch, searchError, documentsError]); - - // Display server-filtered results directly - const displayDocs = documents || []; + // Display results directly + const displayDocs = documents; const displayTotal = total; const pageStart = pageIndex * pageSize; const pageEnd = Math.min(pageStart + pageSize, displayTotal); @@ -240,33 +183,16 @@ export default function DocumentsTable() { if (isRefreshing) return; setIsRefreshing(true); try { - const refetchPromises: Promise[] = []; - // Only refetch regular documents if not in "only surfsense docs" mode - if (!onlySurfsenseDocsSelected) { - if (debouncedSearch.trim()) { - refetchPromises.push(refetchSearch()); - } else { - refetchPromises.push(refetchDocuments()); - } + if (debouncedSearch.trim()) { + await refetchSearch(); + } else { + await refetchDocuments(); } - if (showSurfsenseDocs) { - refetchPromises.push(refetchSurfsenseDocs()); - } - await Promise.all(refetchPromises); toast.success(t("refresh_success") || "Documents refreshed"); } finally { setIsRefreshing(false); } - }, [ - debouncedSearch, - refetchSearch, - refetchDocuments, - refetchSurfsenseDocs, - showSurfsenseDocs, - onlySurfsenseDocsSelected, - t, - isRefreshing, - ]); + }, [debouncedSearch, refetchSearch, refetchDocuments, t, isRefreshing]); // Create a delete function for single document deletion const deleteDocument = useCallback( @@ -357,7 +283,7 @@ export default function DocumentsTable() { createAttachmentAdapter(), []); @@ -306,12 +323,6 @@ export default function NewChatPage() { if (steps.length > 0) { restoredThinkingSteps.set(`msg-${msg.id}`, steps); } - // Hydrate write_todos plan state from persisted tool calls - // Disabled for now - // const writeTodosCalls = extractWriteTodosFromContent(msg.content); - // for (const todoData of writeTodosCalls) { - // hydratePlanState(todoData); - // } } if (msg.role === "user") { const docs = extractMentionedDocuments(msg.content); @@ -448,13 +459,27 @@ export default function NewChatPage() { // Add user message to state const userMsgId = `msg-user-${Date.now()}`; + + // Include author metadata for shared chats + const authorMetadata = + currentThread?.visibility === "SEARCH_SPACE" && currentUser + ? { + custom: { + author: { + displayName: currentUser.display_name ?? null, + avatarUrl: currentUser.avatar_url ?? null, + }, + }, + } + : undefined; + const userMessage: ThreadMessageLike = { id: userMsgId, role: "user", content: message.content, createdAt: new Date(), - // Include attachments so they can be displayed attachments: message.attachments || [], + metadata: authorMetadata, }; setMessages((prev) => [...prev, userMessage]); @@ -884,6 +909,8 @@ export default function NewChatPage() { setMentionedDocuments, setMessageDocumentsMap, queryClient, + currentThread, + currentUser, ] ); diff --git a/surfsense_web/app/dashboard/user/settings/components/ApiKeyContent.tsx b/surfsense_web/app/dashboard/user/settings/components/ApiKeyContent.tsx new file mode 100644 index 000000000..40e7b1d34 --- /dev/null +++ b/surfsense_web/app/dashboard/user/settings/components/ApiKeyContent.tsx @@ -0,0 +1,123 @@ +"use client"; + +import { Check, Copy, Key, Menu, Shield } from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; +import { useTranslations } from "next-intl"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { Button } from "@/components/ui/button"; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; +import { useApiKey } from "@/hooks/use-api-key"; + +interface ApiKeyContentProps { + onMenuClick: () => void; +} + +export function ApiKeyContent({ onMenuClick }: ApiKeyContentProps) { + const t = useTranslations("userSettings"); + const { apiKey, isLoading, copied, copyToClipboard } = useApiKey(); + + return ( + +
+
+ + +
+ + + + +
+

+ {t("api_key_title")} +

+

{t("api_key_description")}

+
+
+
+
+ + + + + + {t("api_key_warning_title")} + {t("api_key_warning_description")} + + +
+

{t("your_api_key")}

+ {isLoading ? ( +
+ ) : apiKey ? ( +
+
+ {apiKey} +
+ + + + + + {copied ? t("copied") : t("copy")} + + +
+ ) : ( +

{t("no_api_key")}

+ )} +
+ +
+

{t("usage_title")}

+

{t("usage_description")}

+
+									Authorization: Bearer {apiKey || "YOUR_API_KEY"}
+								
+
+ + +
+
+ + ); +} + diff --git a/surfsense_web/app/dashboard/user/settings/components/ProfileContent.tsx b/surfsense_web/app/dashboard/user/settings/components/ProfileContent.tsx new file mode 100644 index 000000000..fab978b49 --- /dev/null +++ b/surfsense_web/app/dashboard/user/settings/components/ProfileContent.tsx @@ -0,0 +1,181 @@ +"use client"; + +import { useAtomValue } from "jotai"; +import { Loader2, Menu, User } from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; +import { useTranslations } from "next-intl"; +import { useEffect, useState } from "react"; +import { toast } from "sonner"; +import { currentUserAtom } from "@/atoms/user/user-query.atoms"; +import { updateUserMutationAtom } from "@/atoms/user/user-mutation.atoms"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; + +interface ProfileContentProps { + onMenuClick: () => void; +} + +function AvatarDisplay({ url, fallback }: { url?: string; fallback: string }) { + const [hasError, setHasError] = useState(false); + + useEffect(() => { + setHasError(false); + }, [url]); + + if (url && !hasError) { + return ( + Avatar setHasError(true)} + /> + ); + } + + return ( +
+ {fallback} +
+ ); +} + +export function ProfileContent({ onMenuClick }: ProfileContentProps) { + const t = useTranslations("userSettings"); + const { data: user, isLoading: isUserLoading } = useAtomValue(currentUserAtom); + const { mutateAsync: updateUser, isPending } = useAtomValue(updateUserMutationAtom); + + const [displayName, setDisplayName] = useState(""); + + useEffect(() => { + if (user) { + setDisplayName(user.display_name || ""); + } + }, [user]); + + const getInitials = (email: string) => { + const name = email.split("@")[0]; + return name.slice(0, 2).toUpperCase(); + }; + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + + try { + await updateUser({ + display_name: displayName || null, + }); + toast.success(t("profile_saved")); + } catch { + toast.error(t("profile_save_error")); + } + }; + + const hasChanges = displayName !== (user?.display_name || ""); + + return ( + +
+
+ + +
+ + + + +
+

+ {t("profile_title")} +

+

{t("profile_description")}

+
+
+
+
+ + + + {isUserLoading ? ( +
+ +
+ ) : ( +
+
+
+
+ + +
+ +
+ + setDisplayName(e.target.value)} + /> +

+ {t("profile_display_name_hint")} +

+
+ +
+ + +
+
+
+ +
+ +
+
+ )} +
+
+
+
+
+ ); +} diff --git a/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx b/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx new file mode 100644 index 000000000..e25d318f3 --- /dev/null +++ b/surfsense_web/app/dashboard/user/settings/components/UserSettingsSidebar.tsx @@ -0,0 +1,155 @@ +"use client"; + +import { ArrowLeft, ChevronRight, X } from "lucide-react"; +import type { LucideIcon } from "lucide-react"; +import { AnimatePresence, motion } from "motion/react"; +import { useTranslations } from "next-intl"; +import { Button } from "@/components/ui/button"; +import { cn } from "@/lib/utils"; + +export interface SettingsNavItem { + id: string; + label: string; + description: string; + icon: LucideIcon; +} + +interface UserSettingsSidebarProps { + activeSection: string; + onSectionChange: (section: string) => void; + onBackToApp: () => void; + isOpen: boolean; + onClose: () => void; + navItems: SettingsNavItem[]; +} + +export function UserSettingsSidebar({ + activeSection, + onSectionChange, + onBackToApp, + isOpen, + onClose, + navItems, +}: UserSettingsSidebarProps) { + const t = useTranslations("userSettings"); + + const handleNavClick = (sectionId: string) => { + onSectionChange(sectionId); + onClose(); + }; + + return ( + <> + + {isOpen && ( + + )} + + + + + ); +} + diff --git a/surfsense_web/app/dashboard/user/settings/page.tsx b/surfsense_web/app/dashboard/user/settings/page.tsx index bf88e65e5..973b39076 100644 --- a/surfsense_web/app/dashboard/user/settings/page.tsx +++ b/surfsense_web/app/dashboard/user/settings/page.tsx @@ -1,286 +1,27 @@ "use client"; -import { - ArrowLeft, - Check, - ChevronRight, - Copy, - Key, - type LucideIcon, - Menu, - Shield, - X, -} from "lucide-react"; -import { AnimatePresence, motion } from "motion/react"; +import { Key, User } from "lucide-react"; +import { motion } from "motion/react"; import { useRouter } from "next/navigation"; import { useTranslations } from "next-intl"; import { useCallback, useState } from "react"; -import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; -import { Button } from "@/components/ui/button"; -import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; -import { useApiKey } from "@/hooks/use-api-key"; -import { cn } from "@/lib/utils"; - -interface SettingsNavItem { - id: string; - label: string; - description: string; - icon: LucideIcon; -} - -function UserSettingsSidebar({ - activeSection, - onSectionChange, - onBackToApp, - isOpen, - onClose, - navItems, -}: { - activeSection: string; - onSectionChange: (section: string) => void; - onBackToApp: () => void; - isOpen: boolean; - onClose: () => void; - navItems: SettingsNavItem[]; -}) { - const t = useTranslations("userSettings"); - - const handleNavClick = (sectionId: string) => { - onSectionChange(sectionId); - onClose(); - }; - - return ( - <> - - {isOpen && ( - - )} - - - - - ); -} - -function ApiKeyContent({ onMenuClick }: { onMenuClick: () => void }) { - const t = useTranslations("userSettings"); - const { apiKey, isLoading, copied, copyToClipboard } = useApiKey(); - - return ( - -
-
- - -
- - - - -
-

- {t("api_key_title")} -

-

{t("api_key_description")}

-
-
-
-
- - - - - - {t("api_key_warning_title")} - {t("api_key_warning_description")} - - -
-

{t("your_api_key")}

- {isLoading ? ( -
- ) : apiKey ? ( -
-
- {apiKey} -
- - - - - - {copied ? t("copied") : t("copy")} - - -
- ) : ( -

{t("no_api_key")}

- )} -
- -
-

{t("usage_title")}

-

{t("usage_description")}

-
-									Authorization: Bearer {apiKey || "YOUR_API_KEY"}
-								
-
- - -
-
- - ); -} +import { ApiKeyContent } from "./components/ApiKeyContent"; +import { ProfileContent } from "./components/ProfileContent"; +import { UserSettingsSidebar, type SettingsNavItem } from "./components/UserSettingsSidebar"; export default function UserSettingsPage() { const t = useTranslations("userSettings"); const router = useRouter(); - const [activeSection, setActiveSection] = useState("api-key"); + const [activeSection, setActiveSection] = useState("profile"); const [isSidebarOpen, setIsSidebarOpen] = useState(false); const navItems: SettingsNavItem[] = [ + { + id: "profile", + label: t("profile_nav_label"), + description: t("profile_nav_description"), + icon: User, + }, { id: "api-key", label: t("api_key_nav_label"), @@ -310,6 +51,9 @@ export default function UserSettingsPage() { onClose={() => setIsSidebarOpen(false)} navItems={navItems} /> + {activeSection === "profile" && ( + setIsSidebarOpen(true)} /> + )} {activeSection === "api-key" && ( setIsSidebarOpen(true)} /> )} diff --git a/surfsense_web/atoms/user/user-mutation.atoms.ts b/surfsense_web/atoms/user/user-mutation.atoms.ts new file mode 100644 index 000000000..02a9f2146 --- /dev/null +++ b/surfsense_web/atoms/user/user-mutation.atoms.ts @@ -0,0 +1,19 @@ +import { atomWithMutation, queryClientAtom } from "jotai-tanstack-query"; +import type { UpdateUserRequest } from "@/contracts/types/user.types"; +import { userApiService } from "@/lib/apis/user-api.service"; +import { cacheKeys } from "@/lib/query-client/cache-keys"; + +export const updateUserMutationAtom = atomWithMutation((get) => { + const queryClient = get(queryClientAtom); + + return { + mutationKey: cacheKeys.user.current(), + mutationFn: async (request: UpdateUserRequest) => { + return userApiService.updateMe(request); + }, + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: cacheKeys.user.current() }); + }, + }; +}); + diff --git a/surfsense_web/components/assistant-ui/thread.tsx b/surfsense_web/components/assistant-ui/thread.tsx index bf46e3d97..2507fb8a9 100644 --- a/surfsense_web/components/assistant-ui/thread.tsx +++ b/surfsense_web/components/assistant-ui/thread.tsx @@ -19,9 +19,7 @@ import { ChevronRightIcon, CopyIcon, DownloadIcon, - FileText, Loader2, - PencilIcon, RefreshCwIcon, SquareIcon, } from "lucide-react"; @@ -31,7 +29,6 @@ import { createPortal } from "react-dom"; import { mentionedDocumentIdsAtom, mentionedDocumentsAtom, - messageDocumentsMapAtom, } from "@/atoms/chat/mentioned-documents.atom"; import { globalNewLLMConfigsAtom, @@ -42,8 +39,8 @@ import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { ComposerAddAttachment, ComposerAttachments, - UserMessageAttachments, } from "@/components/assistant-ui/attachment"; +import { UserMessage } from "@/components/assistant-ui/user-message"; import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup"; import { InlineMentionEditor, @@ -639,69 +636,6 @@ const AssistantActionBar: FC = () => { ); }; -const UserMessage: FC = () => { - const messageId = useAssistantState(({ message }) => message?.id); - const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); - const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; - const hasAttachments = useAssistantState( - ({ message }) => message?.attachments && message.attachments.length > 0 - ); - - return ( - -
- {/* Display attachments and mentioned documents */} - {(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && ( -
- {/* Attachments (images show as thumbnails, documents as chips) */} - - {/* Mentioned documents as chips */} - {mentionedDocs?.map((doc) => ( - - - {doc.title} - - ))} -
- )} - {/* Message bubble with action bar positioned relative to it */} -
-
- -
-
- -
-
-
- - -
- ); -}; - -const UserActionBar: FC = () => { - return ( - - - - - - - - ); -}; const EditComposer: FC = () => { return ( diff --git a/surfsense_web/components/assistant-ui/user-message.tsx b/surfsense_web/components/assistant-ui/user-message.tsx index 745542304..15b5461b6 100644 --- a/surfsense_web/components/assistant-ui/user-message.tsx +++ b/surfsense_web/components/assistant-ui/user-message.tsx @@ -1,16 +1,54 @@ import { ActionBarPrimitive, MessagePrimitive, useAssistantState } from "@assistant-ui/react"; import { useAtomValue } from "jotai"; import { FileText, PencilIcon } from "lucide-react"; -import type { FC } from "react"; +import { type FC, useState } from "react"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { UserMessageAttachments } from "@/components/assistant-ui/attachment"; import { BranchPicker } from "@/components/assistant-ui/branch-picker"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; +interface AuthorMetadata { + displayName: string | null; + avatarUrl: string | null; +} + +const UserAvatar: FC = ({ displayName, avatarUrl }) => { + const [hasError, setHasError] = useState(false); + + const initials = displayName + ? displayName + .split(" ") + .map((n) => n[0]) + .join("") + .toUpperCase() + .slice(0, 2) + : "U"; + + if (avatarUrl && !hasError) { + return ( + {displayName setHasError(true)} + /> + ); + } + + return ( +
+ {initials} +
+ ); +}; + export const UserMessage: FC = () => { const messageId = useAssistantState(({ message }) => message?.id); const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; + const metadata = useAssistantState(({ message }) => message?.metadata); + const author = metadata?.custom?.author as AuthorMetadata | undefined; const hasAttachments = useAssistantState( ({ message }) => message?.attachments && message.attachments.length > 0 ); @@ -20,34 +58,42 @@ export const UserMessage: FC = () => { className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2" data-role="user" > -
- {/* Display attachments and mentioned documents */} - {(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && ( -
- {/* Attachments (images show as thumbnails, documents as chips) */} - - {/* Mentioned documents as chips */} - {mentionedDocs?.map((doc) => ( - - - {doc.title} - - ))} -
- )} - {/* Message bubble with action bar positioned relative to it */} -
-
- -
-
- +
+
+ {/* Display attachments and mentioned documents */} + {(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && ( +
+ {/* Attachments (images show as thumbnails, documents as chips) */} + + {/* Mentioned documents as chips */} + {mentionedDocs?.map((doc) => ( + + + {doc.title} + + ))} +
+ )} + {/* Message bubble with action bar positioned relative to it */} +
+
+ +
+
+ +
+ {/* User avatar - only shown in shared chats */} + {author && ( +
+ +
+ )}
diff --git a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx index 3d4e5630d..95ff5d782 100644 --- a/surfsense_web/components/layout/providers/LayoutDataProvider.tsx +++ b/surfsense_web/components/layout/providers/LayoutDataProvider.tsx @@ -354,7 +354,11 @@ export function LayoutDataProvider({ onChatDelete={handleChatDelete} onViewAllSharedChats={handleViewAllSharedChats} onViewAllPrivateChats={handleViewAllPrivateChats} - user={{ email: user?.email || "", name: user?.email?.split("@")[0] }} + user={{ + email: user?.email || "", + name: user?.display_name || user?.email?.split("@")[0], + avatarUrl: user?.avatar_url || undefined, + }} onSettings={handleSettings} onManageMembers={handleManageMembers} onUserSettings={handleUserSettings} diff --git a/surfsense_web/components/layout/types/layout.types.ts b/surfsense_web/components/layout/types/layout.types.ts index 73ac98fa5..3eac64e60 100644 --- a/surfsense_web/components/layout/types/layout.types.ts +++ b/surfsense_web/components/layout/types/layout.types.ts @@ -12,6 +12,7 @@ export interface SearchSpace { export interface User { email: string; name?: string; + avatarUrl?: string; } export interface NavItem { diff --git a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx index d3e97c8eb..f67dbf7c6 100644 --- a/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx +++ b/surfsense_web/components/layout/ui/sidebar/SidebarUserProfile.tsx @@ -61,6 +61,39 @@ function getInitials(email: string): string { return name.slice(0, 2).toUpperCase(); } +/** + * User avatar component - shows image if available, otherwise falls back to initials + */ +function UserAvatar({ + avatarUrl, + initials, + bgColor, +}: { + avatarUrl?: string; + initials: string; + bgColor: string; +}) { + if (avatarUrl) { + return ( + User avatar + ); + } + + return ( +
+ {initials} +
+ ); +} + export function SidebarUserProfile({ user, onUserSettings, @@ -88,12 +121,7 @@ export function SidebarUserProfile({ "focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" )} > -
- {initials} -
+ {displayName} @@ -104,12 +132,7 @@ export function SidebarUserProfile({
-
- {initials} -
+

{displayName}

{user.email}

@@ -149,13 +172,7 @@ export function SidebarUserProfile({ "focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" )} > - {/* Avatar */} -
- {initials} -
+ {/* Name and email */}
@@ -171,12 +188,7 @@ export function SidebarUserProfile({
-
- {initials} -
+

{displayName}

{user.email}

diff --git a/surfsense_web/components/new-chat/document-mention-picker.tsx b/surfsense_web/components/new-chat/document-mention-picker.tsx index e89885b1d..ba9e4ea95 100644 --- a/surfsense_web/components/new-chat/document-mention-picker.tsx +++ b/surfsense_web/components/new-chat/document-mention-picker.tsx @@ -215,6 +215,16 @@ export const DocumentMentionPicker = forwardRef< isSurfsenseDocsLoading) && currentPage === 0; + // Split documents into SurfSense docs and user docs for grouped rendering + const surfsenseDocsList = useMemo( + () => actualDocuments.filter((doc) => doc.document_type === "SURFSENSE_DOCS"), + [actualDocuments] + ); + const userDocsList = useMemo( + () => actualDocuments.filter((doc) => doc.document_type !== "SURFSENSE_DOCS"), + [actualDocuments] + ); + // Track already selected documents using unique key (document_type:id) to avoid ID collisions const selectedKeys = useMemo( () => new Set(initialSelectedDocuments.map((d) => `${d.document_type}:${d.id}`)), @@ -324,47 +334,102 @@ export const DocumentMentionPicker = forwardRef<
) : (
- {actualDocuments.map((doc) => { - const docKey = `${doc.document_type}:${doc.id}`; - const isAlreadySelected = selectedKeys.has(docKey); - const selectableIndex = selectableDocuments.findIndex( - (d) => d.document_type === doc.document_type && d.id === doc.id - ); - const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex; + {/* SurfSense Documentation Section */} + {surfsenseDocsList.length > 0 && ( + <> +
+ SurfSense Docs +
+ {surfsenseDocsList.map((doc) => { + const docKey = `${doc.document_type}:${doc.id}`; + const isAlreadySelected = selectedKeys.has(docKey); + const selectableIndex = selectableDocuments.findIndex( + (d) => d.document_type === doc.document_type && d.id === doc.id + ); + const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex; + + return ( + + ); + })} + + )} + + {/* User Documents Section */} + {userDocsList.length > 0 && ( + <> +
+ Your Documents +
+ {userDocsList.map((doc) => { + const docKey = `${doc.document_type}:${doc.id}`; + const isAlreadySelected = selectedKeys.has(docKey); + const selectableIndex = selectableDocuments.findIndex( + (d) => d.document_type === doc.document_type && d.id === doc.id + ); + const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex; + + return ( + + ); + })} + + )} - return ( - - ); - })} {/* Loading indicator for additional pages */} {isLoadingMore && (
diff --git a/surfsense_web/contracts/types/user.types.ts b/surfsense_web/contracts/types/user.types.ts index f2d1f0ffc..85fee49a8 100644 --- a/surfsense_web/contracts/types/user.types.ts +++ b/surfsense_web/contracts/types/user.types.ts @@ -8,6 +8,8 @@ export const user = z.object({ is_verified: z.boolean(), pages_limit: z.number(), pages_used: z.number(), + display_name: z.string().nullish(), + avatar_url: z.string().nullish(), }); /** @@ -15,5 +17,20 @@ export const user = z.object({ */ export const getMeResponse = user; +/** + * Update current user request + */ +export const updateUserRequest = z.object({ + display_name: z.string().nullish(), + avatar_url: z.string().nullish(), +}); + +/** + * Update current user response + */ +export const updateUserResponse = user; + export type User = z.infer; export type GetMeResponse = z.infer; +export type UpdateUserRequest = z.infer; +export type UpdateUserResponse = z.infer; diff --git a/surfsense_web/lib/apis/user-api.service.ts b/surfsense_web/lib/apis/user-api.service.ts index ea46ac116..94914ebaa 100644 --- a/surfsense_web/lib/apis/user-api.service.ts +++ b/surfsense_web/lib/apis/user-api.service.ts @@ -1,4 +1,8 @@ -import { getMeResponse } from "@/contracts/types/user.types"; +import { + getMeResponse, + updateUserResponse, + type UpdateUserRequest, +} from "@/contracts/types/user.types"; import { baseApiService } from "./base-api.service"; class UserApiService { @@ -8,6 +12,15 @@ class UserApiService { getMe = async () => { return baseApiService.get(`/users/me`, getMeResponse); }; + + /** + * Update current authenticated user + */ + updateMe = async (request: UpdateUserRequest) => { + return baseApiService.patch(`/users/me`, updateUserResponse, { + body: request, + }); + }; } export const userApiService = new UserApiService(); diff --git a/surfsense_web/lib/chat/thread-persistence.ts b/surfsense_web/lib/chat/thread-persistence.ts index 5c65ad47e..23dd35800 100644 --- a/surfsense_web/lib/chat/thread-persistence.ts +++ b/surfsense_web/lib/chat/thread-persistence.ts @@ -31,6 +31,9 @@ export interface MessageRecord { role: "user" | "assistant" | "system"; content: unknown; created_at: string; + author_id?: string | null; + author_display_name?: string | null; + author_avatar_url?: string | null; } export interface ThreadListResponse { diff --git a/surfsense_web/lib/env-config.ts b/surfsense_web/lib/env-config.ts index 5e35b160c..6201a0425 100644 --- a/surfsense_web/lib/env-config.ts +++ b/surfsense_web/lib/env-config.ts @@ -1,10 +1,10 @@ /** * Environment configuration for the frontend. - * + * * This file centralizes access to NEXT_PUBLIC_* environment variables. * For Docker deployments, these placeholders are replaced at container startup * via sed in the entrypoint script. - * + * * IMPORTANT: Do not use template literals or complex expressions with these values * as it may prevent the sed replacement from working correctly. */ @@ -24,5 +24,5 @@ export const ETL_SERVICE = process.env.NEXT_PUBLIC_ETL_SERVICE || "DOCLING"; // Helper to check if local auth is enabled export const isLocalAuth = () => AUTH_TYPE === "LOCAL"; -// Helper to check if Google auth is enabled +// Helper to check if Google auth is enabled export const isGoogleAuth = () => AUTH_TYPE === "GOOGLE"; diff --git a/surfsense_web/messages/en.json b/surfsense_web/messages/en.json index 84bacbfd2..8fe5dc524 100644 --- a/surfsense_web/messages/en.json +++ b/surfsense_web/messages/en.json @@ -109,6 +109,17 @@ "title": "User Settings", "description": "Manage your account settings and API access", "back_to_app": "Back to app", + "profile_nav_label": "Profile", + "profile_nav_description": "Manage your display name and avatar", + "profile_title": "Profile", + "profile_description": "Update your personal information", + "profile_avatar": "Profile Picture", + "profile_display_name": "Display Name", + "profile_display_name_hint": "This is how your name appears across the app", + "profile_email": "Email", + "profile_save": "Save Changes", + "profile_saved": "Profile updated successfully", + "profile_save_error": "Failed to update profile", "api_key_nav_label": "API Key", "api_key_nav_description": "Manage your API access token", "api_key_title": "API Key",