Merge pull request #699 from MODSetter/dev

feat: fixed file uploads timeouts, migrations and various ux updates
This commit is contained in:
Rohan Verma 2026-01-15 00:11:26 -08:00 committed by GitHub
commit 20a25ef02a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 2746 additions and 820 deletions

View file

@ -0,0 +1,54 @@
"""Initial schema setup
Revision ID: 0
Revises: None
Creates all tables from SQLAlchemy models. Idempotent - safe to run on existing databases.
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "0"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
from app.db import Base
connection = op.get_bind()
# Create tables
op.execute(sa.text("CREATE EXTENSION IF NOT EXISTS vector"))
Base.metadata.create_all(bind=connection)
# Set up indexes
op.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)"
)
)
op.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))"
)
)
op.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)"
)
)
op.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))"
)
)
def downgrade() -> None:
pass

View file

@ -6,6 +6,8 @@ Revises: 9
from collections.abc import Sequence from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
@ -18,9 +20,35 @@ depends_on: str | Sequence[str] | None = None
CHAT_TYPE_ENUM = "chattype" CHAT_TYPE_ENUM = "chattype"
def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"),
{"enum_name": enum_name},
)
return result.scalar()
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None: def upgrade() -> None:
"""Upgrade schema - replace ChatType enum values with new QNA/REPORT structure.""" """Upgrade schema - replace ChatType enum values with new QNA/REPORT structure."""
# Skip if chats table or chattype enum doesn't exist (fresh database)
if not table_exists("chats") or not enum_exists(CHAT_TYPE_ENUM):
return
# Old enum name for temporary storage # Old enum name for temporary storage
old_enum_name = f"{CHAT_TYPE_ENUM}_old" old_enum_name = f"{CHAT_TYPE_ENUM}_old"
@ -72,6 +100,10 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
"""Downgrade schema - revert ChatType enum to old GENERAL/DEEP/DEEPER/DEEPEST structure.""" """Downgrade schema - revert ChatType enum to old GENERAL/DEEP/DEEPER/DEEPEST structure."""
# Skip if chats table or chattype enum doesn't exist
if not table_exists("chats") or not enum_exists(CHAT_TYPE_ENUM):
return
# Old enum name for temporary storage # Old enum name for temporary storage
old_enum_name = f"{CHAT_TYPE_ENUM}_old" old_enum_name = f"{CHAT_TYPE_ENUM}_old"

View file

@ -7,22 +7,34 @@ Revises:
from collections.abc import Sequence from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op from alembic import op
# Import pgvector if needed for other types, though not for this ENUM change
# import pgvector
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "1" revision: str = "1"
down_revision: str | None = None down_revision: str | None = "0"
branch_labels: str | Sequence[str] | None = None branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None
def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"),
{"enum_name": enum_name},
)
return result.scalar()
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
# Skip if the enum doesn't exist (fresh DB after downgrade - create_db_and_tables will handle it)
if not enum_exists("searchsourceconnectortype"):
return
# Manually add the command to add the enum value # Manually add the command to add the enum value
# Note: It's generally better to let autogenerate handle this, but we're bypassing it # Note: It's generally better to let autogenerate handle this, but we're bypassing it
op.execute( op.execute(
@ -51,6 +63,10 @@ END$$;
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
# Skip if the enum doesn't exist
if not enum_exists("searchsourceconnectortype"):
return
# Downgrading removal of an enum value is complex and potentially dangerous # Downgrading removal of an enum value is complex and potentially dangerous
# if the value is in use. Often omitted or requires manual SQL based on context. # if the value is in use. Often omitted or requires manual SQL based on context.
# For now, we'll just pass. If you needed to reverse this, you'd likely # For now, we'll just pass. If you needed to reverse this, you'd likely

View file

@ -7,6 +7,8 @@ Revises: 23
from collections.abc import Sequence from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
@ -16,11 +18,27 @@ branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None: def upgrade() -> None:
""" """
Fix any chats with NULL type values by setting them to QNA. Fix any chats with NULL type values by setting them to QNA.
This handles edge cases from previous migrations where type values were not properly migrated. This handles edge cases from previous migrations where type values were not properly migrated.
""" """
# Skip if chats table doesn't exist (fresh database)
if not table_exists("chats"):
return
# Update any NULL type values to QNA (the default chat type) # Update any NULL type values to QNA (the default chat type)
op.execute( op.execute(
""" """

View file

@ -10,6 +10,8 @@ Revises: 33
from collections.abc import Sequence from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers # revision identifiers
@ -19,42 +21,59 @@ branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None: def upgrade() -> None:
"""Add columns only if they don't already exist (safe for re-runs).""" """Add columns only if they don't already exist (safe for re-runs)."""
# Add 'state_version' column to chats table (default 1) # Add 'state_version' column to chats table (default 1)
op.execute(""" # Skip if chats table doesn't exist (fresh database)
ALTER TABLE chats if table_exists("chats"):
ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL op.execute("""
""") ALTER TABLE chats
ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL
""")
# Add 'chat_state_version' column to podcasts table # Add 'chat_state_version' column to podcasts table
op.execute(""" if table_exists("podcasts"):
ALTER TABLE podcasts op.execute("""
ADD COLUMN IF NOT EXISTS chat_state_version BIGINT ALTER TABLE podcasts
""") ADD COLUMN IF NOT EXISTS chat_state_version BIGINT
""")
# Add 'chat_id' column to podcasts table # Add 'chat_id' column to podcasts table
op.execute(""" op.execute("""
ALTER TABLE podcasts ALTER TABLE podcasts
ADD COLUMN IF NOT EXISTS chat_id INTEGER ADD COLUMN IF NOT EXISTS chat_id INTEGER
""") """)
def downgrade() -> None: def downgrade() -> None:
"""Remove columns only if they exist.""" """Remove columns only if they exist."""
op.execute(""" if table_exists("podcasts"):
ALTER TABLE podcasts op.execute("""
DROP COLUMN IF EXISTS chat_state_version ALTER TABLE podcasts
""") DROP COLUMN IF EXISTS chat_state_version
""")
op.execute(""" op.execute("""
ALTER TABLE podcasts ALTER TABLE podcasts
DROP COLUMN IF EXISTS chat_id DROP COLUMN IF EXISTS chat_id
""") """)
op.execute(""" if table_exists("chats"):
ALTER TABLE chats op.execute("""
DROP COLUMN IF EXISTS state_version ALTER TABLE chats
""") DROP COLUMN IF EXISTS state_version
""")

View file

@ -62,8 +62,25 @@ def parse_timestamp(ts, fallback):
return fallback return fallback
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None: def upgrade() -> None:
"""Migrate old chats to new_chat_threads and remove old tables.""" """Migrate old chats to new_chat_threads and remove old tables."""
# Skip if chats table doesn't exist (fresh database)
if not table_exists("chats"):
print("[Migration 49] Chats table does not exist, skipping migration")
return
connection = op.get_bind() connection = op.get_bind()
# Get all old chats # Get all old chats
@ -176,36 +193,47 @@ def upgrade() -> None:
print("[Migration 49] Migration complete!") print("[Migration 49] Migration complete!")
def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text("SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"),
{"enum_name": enum_name},
)
return result.scalar()
def downgrade() -> None: def downgrade() -> None:
"""Recreate old chats table (data cannot be restored).""" """Recreate old chats table (data cannot be restored)."""
# Recreate chattype enum # Skip if chats table already exists
if table_exists("chats"):
print("[Migration 49 Downgrade] Chats table already exists, skipping")
return
# Recreate chattype enum if it doesn't exist
if not enum_exists("chattype"):
op.execute(
sa.text("""
CREATE TYPE chattype AS ENUM ('QNA')
""")
)
# Recreate chats table using raw SQL to avoid SQLAlchemy trying to create the enum
op.execute( op.execute(
sa.text(""" sa.text("""
CREATE TYPE chattype AS ENUM ('QNA') CREATE TABLE chats (
id SERIAL PRIMARY KEY,
type chattype NOT NULL,
title VARCHAR NOT NULL,
initial_connectors VARCHAR[],
messages JSON NOT NULL,
state_version BIGINT NOT NULL DEFAULT 1,
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
)
""") """)
) )
op.execute(sa.text("CREATE INDEX ix_chats_id ON chats (id)"))
# Recreate chats table op.execute(sa.text("CREATE INDEX ix_chats_title ON chats (title)"))
op.create_table(
"chats",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("type", sa.Enum("QNA", name="chattype"), nullable=False),
sa.Column("title", sa.String(), nullable=False, index=True),
sa.Column("initial_connectors", sa.ARRAY(sa.String()), nullable=True),
sa.Column("messages", sa.JSON(), nullable=False),
sa.Column("state_version", sa.BigInteger(), nullable=False, default=1),
sa.Column(
"search_space_id",
sa.Integer(),
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
print("[Migration 49 Downgrade] Chats table recreated (data not restored)") print("[Migration 49 Downgrade] Chats table recreated (data not restored)")

View file

@ -39,7 +39,7 @@ def upgrade():
""" """
) )
# Rename columns (only if they exist with old names) # Rename columns (only if source exists and target doesn't already exist)
op.execute( op.execute(
""" """
DO $$ DO $$
@ -47,6 +47,9 @@ def upgrade():
IF EXISTS ( IF EXISTS (
SELECT 1 FROM information_schema.columns SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id' WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id'
) THEN ) THEN
ALTER TABLE searchspaces RENAME COLUMN fast_llm_id TO agent_llm_id; ALTER TABLE searchspaces RENAME COLUMN fast_llm_id TO agent_llm_id;
END IF; END IF;
@ -61,6 +64,9 @@ def upgrade():
IF EXISTS ( IF EXISTS (
SELECT 1 FROM information_schema.columns SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id' WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id'
) THEN ) THEN
ALTER TABLE searchspaces RENAME COLUMN long_context_llm_id TO document_summary_llm_id; ALTER TABLE searchspaces RENAME COLUMN long_context_llm_id TO document_summary_llm_id;
END IF; END IF;
@ -100,7 +106,7 @@ def downgrade():
""" """
) )
# Rename columns back # Rename columns back (only if source exists and target doesn't already exist)
op.execute( op.execute(
""" """
DO $$ DO $$
@ -108,6 +114,9 @@ def downgrade():
IF EXISTS ( IF EXISTS (
SELECT 1 FROM information_schema.columns SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id' WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id'
) THEN ) THEN
ALTER TABLE searchspaces RENAME COLUMN agent_llm_id TO fast_llm_id; ALTER TABLE searchspaces RENAME COLUMN agent_llm_id TO fast_llm_id;
END IF; END IF;
@ -122,6 +131,9 @@ def downgrade():
IF EXISTS ( IF EXISTS (
SELECT 1 FROM information_schema.columns SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id' WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id'
) THEN ) THEN
ALTER TABLE searchspaces RENAME COLUMN document_summary_llm_id TO long_context_llm_id; ALTER TABLE searchspaces RENAME COLUMN document_summary_llm_id TO long_context_llm_id;
END IF; END IF;

View file

@ -60,14 +60,28 @@ def downgrade() -> None:
connection = op.get_bind() connection = op.get_bind()
connection.execute( # Only update if the target enum value exists (it won't on fresh databases)
result = connection.execute(
text( text(
""" """
UPDATE documents SELECT EXISTS (
SET document_type = 'GOOGLE_DRIVE_CONNECTOR' SELECT 1 FROM pg_type t
WHERE document_type = 'GOOGLE_DRIVE_FILE'; JOIN pg_enum e ON t.oid = e.enumtypid
WHERE t.typname = 'documenttype' AND e.enumlabel = 'GOOGLE_DRIVE_CONNECTOR'
);
""" """
) )
) )
enum_exists = result.scalar()
connection.commit() if enum_exists:
connection.execute(
text(
"""
UPDATE documents
SET document_type = 'GOOGLE_DRIVE_CONNECTOR'
WHERE document_type = 'GOOGLE_DRIVE_FILE';
"""
)
)
connection.commit()

View file

@ -18,59 +18,77 @@ branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None
def upgrade() -> None: def table_exists(table_name: str) -> bool:
# Alter Chat table """Check if a table exists in the database."""
op.alter_column( conn = op.get_bind()
"chats", result = conn.execute(
"title", sa.text(
existing_type=sa.String(200), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
type_=sa.String(), ),
existing_nullable=False, {"table_name": table_name},
) )
return result.scalar()
def upgrade() -> None:
# Alter Chat table (may not exist on fresh databases, removed in migration 49)
if table_exists("chats"):
op.alter_column(
"chats",
"title",
existing_type=sa.String(200),
type_=sa.String(),
existing_nullable=False,
)
# Alter Document table # Alter Document table
op.alter_column( if table_exists("documents"):
"documents", op.alter_column(
"title", "documents",
existing_type=sa.String(200), "title",
type_=sa.String(), existing_type=sa.String(200),
existing_nullable=False, type_=sa.String(),
) existing_nullable=False,
)
# Alter Podcast table # Alter Podcast table
op.alter_column( if table_exists("podcasts"):
"podcasts", op.alter_column(
"title", "podcasts",
existing_type=sa.String(200), "title",
type_=sa.String(), existing_type=sa.String(200),
existing_nullable=False, type_=sa.String(),
) existing_nullable=False,
)
def downgrade() -> None: def downgrade() -> None:
# Revert Chat table # Revert Chat table
op.alter_column( if table_exists("chats"):
"chats", op.alter_column(
"title", "chats",
existing_type=sa.String(), "title",
type_=sa.String(200), existing_type=sa.String(),
existing_nullable=False, type_=sa.String(200),
) existing_nullable=False,
)
# Revert Document table # Revert Document table
op.alter_column( if table_exists("documents"):
"documents", op.alter_column(
"title", "documents",
existing_type=sa.String(), "title",
type_=sa.String(200), existing_type=sa.String(),
existing_nullable=False, type_=sa.String(200),
) existing_nullable=False,
)
# Revert Podcast table # Revert Podcast table
op.alter_column( if table_exists("podcasts"):
"podcasts", op.alter_column(
"title", "podcasts",
existing_type=sa.String(), "title",
type_=sa.String(200), existing_type=sa.String(),
existing_nullable=False, type_=sa.String(200),
) existing_nullable=False,
)

View file

@ -0,0 +1,38 @@
"""Add MCP connector type
Revision ID: 62
Revises: 61
Create Date: 2026-01-09 15:19:51.827647
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "62"
down_revision: str | None = "61"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add MCP_CONNECTOR to SearchSourceConnectorType enum."""
# Add new enum value using raw SQL
op.execute(
"""
ALTER TYPE searchsourceconnectortype ADD VALUE IF NOT EXISTS 'MCP_CONNECTOR';
"""
)
def downgrade() -> None:
"""Remove MCP_CONNECTOR from SearchSourceConnectorType enum."""
# Note: PostgreSQL does not support removing enum values directly.
# To downgrade, you would need to:
# 1. Create a new enum without MCP_CONNECTOR
# 2. Alter the column to use the new enum
# 3. Drop the old enum
# This is left as a manual operation if needed.
pass

View file

@ -0,0 +1,97 @@
"""allow_multiple_connectors_with_unique_names
Revision ID: 63
Revises: 62
Create Date: 2026-01-13 12:23:31.481643
"""
from collections.abc import Sequence
from sqlalchemy import text
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "63"
down_revision: str | None = "62"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
connection = op.get_bind()
# Check if old constraint exists before trying to drop it
old_constraint_exists = connection.execute(
text("""
SELECT 1 FROM information_schema.table_constraints
WHERE table_name='search_source_connectors'
AND constraint_type='UNIQUE'
AND constraint_name='uq_searchspace_user_connector_type'
""")
).scalar()
if old_constraint_exists:
op.drop_constraint(
"uq_searchspace_user_connector_type",
"search_source_connectors",
type_="unique",
)
# Check if new constraint already exists before creating it
new_constraint_exists = connection.execute(
text("""
SELECT 1 FROM information_schema.table_constraints
WHERE table_name='search_source_connectors'
AND constraint_type='UNIQUE'
AND constraint_name='uq_searchspace_user_connector_type_name'
""")
).scalar()
if not new_constraint_exists:
op.create_unique_constraint(
"uq_searchspace_user_connector_type_name",
"search_source_connectors",
["search_space_id", "user_id", "connector_type", "name"],
)
def downgrade() -> None:
"""Downgrade schema."""
connection = op.get_bind()
# Check if new constraint exists before trying to drop it
new_constraint_exists = connection.execute(
text("""
SELECT 1 FROM information_schema.table_constraints
WHERE table_name='search_source_connectors'
AND constraint_type='UNIQUE'
AND constraint_name='uq_searchspace_user_connector_type_name'
""")
).scalar()
if new_constraint_exists:
op.drop_constraint(
"uq_searchspace_user_connector_type_name",
"search_source_connectors",
type_="unique",
)
# Check if old constraint already exists before creating it
old_constraint_exists = connection.execute(
text("""
SELECT 1 FROM information_schema.table_constraints
WHERE table_name='search_source_connectors'
AND constraint_type='UNIQUE'
AND constraint_name='uq_searchspace_user_connector_type'
""")
).scalar()
if not old_constraint_exists:
op.create_unique_constraint(
"uq_searchspace_user_connector_type",
"search_source_connectors",
["search_space_id", "user_id", "connector_type"],
)

View file

@ -0,0 +1,72 @@
"""Add display_name and avatar_url columns to user table
This migration adds:
- display_name column for user's full name from OAuth
- avatar_url column for user's profile picture URL from OAuth
Revision ID: 64
Revises: 63
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "64"
down_revision: str | None = "63"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add display_name and avatar_url columns to user table."""
# Add display_name column (nullable for existing users)
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'user' AND column_name = 'display_name'
) THEN
ALTER TABLE "user"
ADD COLUMN display_name VARCHAR;
END IF;
END$$;
"""
)
# Add avatar_url column (nullable for existing users)
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'user' AND column_name = 'avatar_url'
) THEN
ALTER TABLE "user"
ADD COLUMN avatar_url VARCHAR;
END IF;
END$$;
"""
)
def downgrade() -> None:
"""Remove display_name and avatar_url columns from user table."""
op.execute(
"""
ALTER TABLE "user"
DROP COLUMN IF EXISTS avatar_url;
"""
)
op.execute(
"""
ALTER TABLE "user"
DROP COLUMN IF EXISTS display_name;
"""
)

View file

@ -0,0 +1,46 @@
"""Add author_id column to new_chat_messages table
Revision ID: 65
Revises: 64
"""
from collections.abc import Sequence
from alembic import op
revision: str = "65"
down_revision: str | None = "64"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add author_id column to new_chat_messages table."""
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'new_chat_messages' AND column_name = 'author_id'
) THEN
ALTER TABLE new_chat_messages
ADD COLUMN author_id UUID REFERENCES "user"(id) ON DELETE SET NULL;
CREATE INDEX IF NOT EXISTS ix_new_chat_messages_author_id
ON new_chat_messages(author_id);
END IF;
END$$;
"""
)
def downgrade() -> None:
"""Remove author_id column from new_chat_messages table."""
op.execute(
"""
DROP INDEX IF EXISTS ix_new_chat_messages_author_id;
ALTER TABLE new_chat_messages
DROP COLUMN IF EXISTS author_id;
"""
)

View file

@ -20,7 +20,7 @@ from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt, build_configurable_system_prompt,
build_surfsense_system_prompt, build_surfsense_system_prompt,
) )
from app.agents.new_chat.tools import build_tools from app.agents.new_chat.tools.registry import build_tools_async
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
# ============================================================================= # =============================================================================
@ -28,7 +28,7 @@ from app.services.connector_service import ConnectorService
# ============================================================================= # =============================================================================
def create_surfsense_deep_agent( async def create_surfsense_deep_agent(
llm: ChatLiteLLM, llm: ChatLiteLLM,
search_space_id: int, search_space_id: int,
db_session: AsyncSession, db_session: AsyncSession,
@ -120,8 +120,8 @@ def create_surfsense_deep_agent(
"firecrawl_api_key": firecrawl_api_key, "firecrawl_api_key": firecrawl_api_key,
} }
# Build tools using the registry # Build tools using the async registry (includes MCP tools)
tools = build_tools( tools = await build_tools_async(
dependencies=dependencies, dependencies=dependencies,
enabled_tools=enabled_tools, enabled_tools=enabled_tools,
disabled_tools=disabled_tools, disabled_tools=disabled_tools,

View file

@ -0,0 +1,203 @@
"""MCP Client Wrapper.
This module provides a client for communicating with MCP servers via stdio transport.
It handles server lifecycle management, tool discovery, and tool execution.
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
logger = logging.getLogger(__name__)
class MCPClient:
"""Client for communicating with an MCP server."""
def __init__(
self, command: str, args: list[str], env: dict[str, str] | None = None
):
"""Initialize MCP client.
Args:
command: Command to spawn the MCP server (e.g., "uvx", "node")
args: Arguments for the command (e.g., ["mcp-server-git"])
env: Optional environment variables for the server process
"""
self.command = command
self.args = args
self.env = env or {}
self.session: ClientSession | None = None
@asynccontextmanager
async def connect(self):
"""Connect to the MCP server and manage its lifecycle.
Yields:
ClientSession: Active MCP session for making requests
"""
try:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command, args=self.args, env=server_env
)
# Spawn server process and create session
# Note: Cannot combine these context managers because ClientSession
# needs the read/write streams from stdio_client
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
self.session = session
logger.info(
"Connected to MCP server: %s %s",
self.command,
" ".join(self.args),
)
yield session
except Exception as e:
logger.error("Failed to connect to MCP server: %s", e, exc_info=True)
raise
finally:
self.session = None
logger.info("Disconnected from MCP server: %s", self.command)
async def list_tools(self) -> list[dict[str, Any]]:
"""List all tools available from the MCP server.
Returns:
List of tool definitions with name, description, and input schema
Raises:
RuntimeError: If not connected to server
"""
if not self.session:
raise RuntimeError(
"Not connected to MCP server. Use 'async with client.connect():'"
)
try:
# Call tools/list RPC method
response = await self.session.list_tools()
tools = []
for tool in response.tools:
tools.append(
{
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema
if hasattr(tool, "inputSchema")
else {},
}
)
logger.info("Listed %d tools from MCP server", len(tools))
return tools
except Exception as e:
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
raise
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call a tool on the MCP server.
Args:
tool_name: Name of the tool to call
arguments: Arguments to pass to the tool
Returns:
Tool execution result
Raises:
RuntimeError: If not connected to server
"""
if not self.session:
raise RuntimeError(
"Not connected to MCP server. Use 'async with client.connect():'"
)
try:
logger.info(
"Calling MCP tool '%s' with arguments: %s", tool_name, arguments
)
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
# Extract content from response
result = []
for content in response.content:
if hasattr(content, "text"):
result.append(content.text)
elif hasattr(content, "data"):
result.append(str(content.data))
else:
result.append(str(content))
result_str = "\n".join(result) if result else ""
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
return result_str
except RuntimeError as e:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning(
"MCP server returned data not matching its schema, but continuing: %s",
e,
)
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:
logger.error(
"Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
)
return f"Error calling tool: {e!s}"
async def test_mcp_connection(
command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
"""Test connection to an MCP server and fetch available tools.
Args:
command: Command to spawn the MCP server
args: Arguments for the command
env: Optional environment variables
Returns:
Dict with connection status and available tools
"""
client = MCPClient(command, args, env)
try:
async with client.connect():
tools = await client.list_tools()
return {
"status": "success",
"message": f"Connected successfully. Found {len(tools)} tools.",
"tools": tools,
}
except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
return {
"status": "error",
"message": f"Failed to connect: {e!s}",
"tools": [],
}

View file

@ -0,0 +1,198 @@
"""MCP Tool Factory.
This module creates LangChain tools from MCP servers using the Model Context Protocol.
Tools are dynamically discovered from MCP servers - no manual configuration needed.
This implements real MCP protocol support similar to Cursor's implementation.
"""
import logging
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, create_model
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
def _create_dynamic_input_model_from_schema(
tool_name: str,
input_schema: dict[str, Any],
) -> type[BaseModel]:
"""Create a Pydantic model from MCP tool's JSON schema.
Args:
tool_name: Name of the tool (used for model class name)
input_schema: JSON schema from MCP server
Returns:
Pydantic model class for tool input validation
"""
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
# Build Pydantic field definitions
field_definitions = {}
for param_name, param_schema in properties.items():
param_description = param_schema.get("description", "")
is_required = param_name in required_fields
# Use Any type for complex schemas to preserve structure
# This allows the MCP server to do its own validation
from typing import Any as AnyType
from pydantic import Field
if is_required:
field_definitions[param_name] = (
AnyType,
Field(..., description=param_description),
)
else:
field_definitions[param_name] = (
AnyType | None,
Field(None, description=param_description),
)
# Create dynamic model
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
return create_model(model_name, **field_definitions)
async def _create_mcp_tool_from_definition(
tool_def: dict[str, Any],
mcp_client: MCPClient,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition.
Args:
tool_def: Tool definition from MCP server with name, description, input_schema
mcp_client: MCP client instance for calling the tool
Returns:
LangChain StructuredTool instance
"""
tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
# Log the actual schema for debugging
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
# Create dynamic input model from schema
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
async def mcp_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via the client."""
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
try:
# Connect to server and call tool
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, kwargs)
return str(result)
except Exception as e:
error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
logger.exception(error_msg)
return f"Error: {error_msg}"
# Create StructuredTool with response_format to preserve exact schema
tool = StructuredTool(
name=tool_name,
description=tool_description,
coroutine=mcp_tool_call,
args_schema=input_model,
# Store the original MCP schema as metadata so we can access it later
metadata={"mcp_input_schema": input_schema},
)
logger.info(f"Created MCP tool: '{tool_name}'")
return tool
async def load_mcp_tools(
session: AsyncSession,
search_space_id: int,
) -> list[StructuredTool]:
"""Load all MCP tools from user's active MCP server connectors.
This discovers tools dynamically from MCP servers using the protocol.
Args:
session: Database session
search_space_id: User's search space ID
Returns:
List of LangChain StructuredTool instances
"""
try:
# Fetch all MCP connectors for this search space
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.search_space_id == search_space_id,
),
)
tools: list[StructuredTool] = []
for connector in result.scalars():
try:
# Extract server config
config = connector.config or {}
server_config = config.get("server_config", {})
command = server_config.get("command")
args = server_config.get("args", [])
env = server_config.get("env", {})
if not command:
logger.warning(
f"MCP connector {connector.id} missing command, skipping"
)
continue
# Create MCP client
mcp_client = MCPClient(command, args, env)
# Connect and discover tools
async with mcp_client.connect():
tool_definitions = await mcp_client.list_tools()
logger.info(
f"Discovered {len(tool_definitions)} tools from MCP server "
f"'{command}' (connector {connector.id})"
)
# Create LangChain tools from definitions
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition(
tool_def, mcp_client
)
tools.append(tool)
except Exception as e:
logger.exception(
f"Failed to create tool '{tool_def.get('name')}' "
f"from connector {connector.id}: {e!s}",
)
except Exception as e:
logger.exception(
f"Failed to load tools from MCP connector {connector.id}: {e!s}",
)
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
return tools
except Exception as e:
logger.exception(f"Failed to load MCP tools: {e!s}")
return []

View file

@ -1,5 +1,4 @@
""" """Tools registry for SurfSense deep agent.
Tools registry for SurfSense deep agent.
This module provides a registry pattern for managing tools in the SurfSense agent. This module provides a registry pattern for managing tools in the SurfSense agent.
It makes it easy for OSS contributors to add new tools by: It makes it easy for OSS contributors to add new tools by:
@ -37,6 +36,7 @@ Example of adding a new tool:
), ),
""" """
import logging
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@ -46,6 +46,7 @@ from langchain_core.tools import BaseTool
from .display_image import create_display_image_tool from .display_image import create_display_image_tool
from .knowledge_base import create_search_knowledge_base_tool from .knowledge_base import create_search_knowledge_base_tool
from .link_preview import create_link_preview_tool from .link_preview import create_link_preview_tool
from .mcp_tool import load_mcp_tools
from .podcast import create_generate_podcast_tool from .podcast import create_generate_podcast_tool
from .scrape_webpage import create_scrape_webpage_tool from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool from .search_surfsense_docs import create_search_surfsense_docs_tool
@ -57,8 +58,7 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
@dataclass @dataclass
class ToolDefinition: class ToolDefinition:
""" """Definition of a tool that can be added to the agent.
Definition of a tool that can be added to the agent.
Attributes: Attributes:
name: Unique identifier for the tool name: Unique identifier for the tool
@ -66,6 +66,7 @@ class ToolDefinition:
factory: Callable that creates the tool. Receives a dict of dependencies. factory: Callable that creates the tool. Receives a dict of dependencies.
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session") requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
enabled_by_default: Whether the tool is enabled when no explicit config is provided enabled_by_default: Whether the tool is enabled when no explicit config is provided
""" """
name: str name: str
@ -178,8 +179,7 @@ def build_tools(
disabled_tools: list[str] | None = None, disabled_tools: list[str] | None = None,
additional_tools: list[BaseTool] | None = None, additional_tools: list[BaseTool] | None = None,
) -> list[BaseTool]: ) -> list[BaseTool]:
""" """Build the list of tools for the agent.
Build the list of tools for the agent.
Args: Args:
dependencies: Dict containing all possible dependencies: dependencies: Dict containing all possible dependencies:
@ -206,6 +206,7 @@ def build_tools(
# Add custom tools # Add custom tools
tools = build_tools(deps, additional_tools=[my_custom_tool]) tools = build_tools(deps, additional_tools=[my_custom_tool])
""" """
# Determine which tools to enable # Determine which tools to enable
if enabled_tools is not None: if enabled_tools is not None:
@ -226,8 +227,9 @@ def build_tools(
# Check that all required dependencies are provided # Check that all required dependencies are provided
missing_deps = [dep for dep in tool_def.requires if dep not in dependencies] missing_deps = [dep for dep in tool_def.requires if dep not in dependencies]
if missing_deps: if missing_deps:
msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
raise ValueError( raise ValueError(
f"Tool '{tool_def.name}' requires dependencies: {missing_deps}" msg,
) )
# Create the tool # Create the tool
@ -239,3 +241,62 @@ def build_tools(
tools.extend(additional_tools) tools.extend(additional_tools)
return tools return tools
async def build_tools_async(
dependencies: dict[str, Any],
enabled_tools: list[str] | None = None,
disabled_tools: list[str] | None = None,
additional_tools: list[BaseTool] | None = None,
include_mcp_tools: bool = True,
) -> list[BaseTool]:
"""Async version of build_tools that also loads MCP tools from database.
Design Note:
This function exists because MCP tools require database queries to load user configs,
while built-in tools are created synchronously from static code.
Alternative: We could make build_tools() itself async and always query the database,
but that would force async everywhere even when only using built-in tools. The current
design keeps the simple case (static tools only) synchronous while supporting dynamic
database-loaded tools through this async wrapper.
Args:
dependencies: Dict containing all possible dependencies
enabled_tools: Explicit list of tool names to enable. If None, uses defaults.
disabled_tools: List of tool names to disable (applied after enabled_tools).
additional_tools: Extra tools to add (e.g., custom tools not in registry).
include_mcp_tools: Whether to load user's MCP tools from database.
Returns:
List of configured tool instances ready for the agent, including MCP tools.
"""
# Build standard tools
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
# Load MCP tools if requested and dependencies are available
if (
include_mcp_tools
and "db_session" in dependencies
and "search_space_id" in dependencies
):
try:
mcp_tools = await load_mcp_tools(
dependencies["db_session"],
dependencies["search_space_id"],
)
tools.extend(mcp_tools)
logging.info(
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
)
except Exception as e:
# Log error but don't fail - just continue without MCP tools
logging.exception(f"Failed to load MCP tools: {e!s}")
# Log all tools being returned to agent
logging.info(
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
)
return tools

View file

@ -80,6 +80,7 @@ class SearchSourceConnectorType(str, Enum):
WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR" WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR"
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR" BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR" CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR"
MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools
class LiteLLMProvider(str, Enum): class LiteLLMProvider(str, Enum):
@ -412,8 +413,17 @@ class NewChatMessage(BaseModel, TimestampMixin):
index=True, index=True,
) )
# Relationship # Track who sent this message (for shared chats)
author_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Relationships
thread = relationship("NewChatThread", back_populates="messages") thread = relationship("NewChatThread", back_populates="messages")
author = relationship("User")
class Document(BaseModel, TimestampMixin): class Document(BaseModel, TimestampMixin):
@ -605,7 +615,8 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
"search_space_id", "search_space_id",
"user_id", "user_id",
"connector_type", "connector_type",
name="uq_searchspace_user_connector_type", "name",
name="uq_searchspace_user_connector_type_name",
), ),
) )
@ -874,6 +885,10 @@ if config.AUTH_TYPE == "GOOGLE":
) )
pages_used = Column(Integer, nullable=False, default=0, server_default="0") pages_used = Column(Integer, nullable=False, default=0, server_default="0")
# User profile from OAuth
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
else: else:
class User(SQLAlchemyBaseUserTableUUID, Base): class User(SQLAlchemyBaseUserTableUUID, Base):
@ -907,6 +922,10 @@ else:
) )
pages_used = Column(Integer, nullable=False, default=0, server_default="0") pages_used = Column(Integer, nullable=False, default=0, server_default="0")
# User profile (can be set manually for non-OAuth users)
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
engine = create_async_engine(DATABASE_URL) engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False) async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -411,11 +411,9 @@ async def get_thread_messages(
Requires CHATS_READ permission. Requires CHATS_READ permission.
""" """
try: try:
# Get thread with messages # Get thread first
result = await session.execute( result = await session.execute(
select(NewChatThread) select(NewChatThread).filter(NewChatThread.id == thread_id)
.options(selectinload(NewChatThread.messages))
.filter(NewChatThread.id == thread_id)
) )
thread = result.scalars().first() thread = result.scalars().first()
@ -434,6 +432,15 @@ async def get_thread_messages(
# Check thread-level access based on visibility # Check thread-level access based on visibility
await check_thread_access(session, thread, user) await check_thread_access(session, thread, user)
# Get messages with their authors loaded
messages_result = await session.execute(
select(NewChatMessage)
.options(selectinload(NewChatMessage.author))
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at)
)
db_messages = messages_result.scalars().all()
# Return messages in the format expected by assistant-ui # Return messages in the format expected by assistant-ui
messages = [ messages = [
NewChatMessageRead( NewChatMessageRead(
@ -442,8 +449,11 @@ async def get_thread_messages(
role=msg.role, role=msg.role,
content=msg.content, content=msg.content,
created_at=msg.created_at, created_at=msg.created_at,
author_id=msg.author_id,
author_display_name=msg.author.display_name if msg.author else None,
author_avatar_url=msg.author.avatar_url if msg.author else None,
) )
for msg in thread.messages for msg in db_messages
] ]
return ThreadHistoryLoadResponse(messages=messages) return ThreadHistoryLoadResponse(messages=messages)
@ -782,6 +792,7 @@ async def append_message(
thread_id=thread_id, thread_id=thread_id,
role=message_role, role=message_role,
content=message.content, content=message.content,
author_id=user.id,
) )
session.add(db_message) session.add(db_message)

View file

@ -7,6 +7,13 @@ PUT /search-source-connectors/{connector_id} - Update a specific connector
DELETE /search-source-connectors/{connector_id} - Delete a specific connector DELETE /search-source-connectors/{connector_id} - Delete a specific connector
POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space
MCP (Model Context Protocol) Connector routes:
POST /connectors/mcp - Create a new MCP connector with custom API tools
GET /connectors/mcp - List all MCP connectors for the current user's search space
GET /connectors/mcp/{connector_id} - Get a specific MCP connector with tools config
PUT /connectors/mcp/{connector_id} - Update an MCP connector's tools config
DELETE /connectors/mcp/{connector_id} - Delete an MCP connector
Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space. Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space.
Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space. Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space.
""" """
@ -32,6 +39,9 @@ from app.db import (
) )
from app.schemas import ( from app.schemas import (
GoogleDriveIndexRequest, GoogleDriveIndexRequest,
MCPConnectorCreate,
MCPConnectorRead,
MCPConnectorUpdate,
SearchSourceConnectorBase, SearchSourceConnectorBase,
SearchSourceConnectorCreate, SearchSourceConnectorCreate,
SearchSourceConnectorRead, SearchSourceConnectorRead,
@ -127,18 +137,20 @@ async def create_search_source_connector(
# Check if a connector with the same type already exists for this search space # Check if a connector with the same type already exists for this search space
# (for non-OAuth connectors that don't support multiple accounts) # (for non-OAuth connectors that don't support multiple accounts)
result = await session.execute( # Exception: MCP_CONNECTOR can have multiple instances with different names
select(SearchSourceConnector).filter( if connector.connector_type != SearchSourceConnectorType.MCP_CONNECTOR:
SearchSourceConnector.search_space_id == search_space_id, result = await session.execute(
SearchSourceConnector.connector_type == connector.connector_type, select(SearchSourceConnector).filter(
) SearchSourceConnector.search_space_id == search_space_id,
) SearchSourceConnector.connector_type == connector.connector_type,
existing_connector = result.scalars().first() )
if existing_connector:
raise HTTPException(
status_code=409,
detail=f"A connector with type {connector.connector_type} already exists in this search space.",
) )
existing_connector = result.scalars().first()
if existing_connector:
raise HTTPException(
status_code=409,
detail=f"A connector with type {connector.connector_type} already exists in this search space.",
)
# Prepare connector data # Prepare connector data
connector_data = connector.model_dump() connector_data = connector.model_dump()
@ -1964,3 +1976,348 @@ async def run_bookstack_indexing(
f"Critical error in run_bookstack_indexing for connector {connector_id}: {e}", f"Critical error in run_bookstack_indexing for connector {connector_id}: {e}",
exc_info=True, exc_info=True,
) )
# =============================================================================
# MCP Connector Routes
# =============================================================================
@router.post("/connectors/mcp", response_model=MCPConnectorRead, status_code=201)
async def create_mcp_connector(
connector_data: MCPConnectorCreate,
search_space_id: int = Query(..., description="Search space ID"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Create a new MCP (Model Context Protocol) connector.
MCP connectors allow users to connect to MCP servers (like in Cursor).
Tools are auto-discovered from the server - no manual configuration needed.
Args:
connector_data: MCP server configuration (command, args, env)
search_space_id: ID of the search space to attach the connector to
session: Database session
user: Current authenticated user
Returns:
Created MCP connector with server configuration
Raises:
HTTPException: If search space not found or permission denied
"""
try:
# Check user has permission to create connectors
await check_permission(
session,
user,
search_space_id,
Permission.CONNECTORS_CREATE.value,
"You don't have permission to create connectors in this search space",
)
# Create the connector with server config
db_connector = SearchSourceConnector(
name=connector_data.name,
connector_type=SearchSourceConnectorType.MCP_CONNECTOR,
is_indexable=False, # MCP connectors are not indexable
config={"server_config": connector_data.server_config.model_dump()},
periodic_indexing_enabled=False,
indexing_frequency_minutes=None,
search_space_id=search_space_id,
user_id=user.id,
)
session.add(db_connector)
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Created MCP connector {db_connector.id} for server '{connector_data.server_config.command}' "
f"for user {user.id} in search space {search_space_id}"
)
# Convert to read schema
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
return MCPConnectorRead.from_connector(connector_read)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to create MCP connector: {e!s}", exc_info=True)
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to create MCP connector: {e!s}"
) from e
@router.get("/connectors/mcp", response_model=list[MCPConnectorRead])
async def list_mcp_connectors(
search_space_id: int = Query(..., description="Search space ID"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
List all MCP connectors for a search space.
Args:
search_space_id: ID of the search space
session: Database session
user: Current authenticated user
Returns:
List of MCP connectors with their tool configurations
"""
try:
# Check user has permission to read connectors
await check_permission(
session,
user,
search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to view connectors in this search space",
)
# Fetch MCP connectors
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.search_space_id == search_space_id,
)
)
connectors = result.scalars().all()
return [
MCPConnectorRead.from_connector(SearchSourceConnectorRead.model_validate(c))
for c in connectors
]
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to list MCP connectors: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to list MCP connectors: {e!s}"
) from e
@router.get("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
async def get_mcp_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Get a specific MCP connector by ID.
Args:
connector_id: ID of the connector
session: Database session
user: Current authenticated user
Returns:
MCP connector with tool configurations
"""
try:
# Fetch connector
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(status_code=404, detail="MCP connector not found")
# Check user has permission to read connectors
await check_permission(
session,
user,
connector.search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to view this connector",
)
connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get MCP connector: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to get MCP connector: {e!s}"
) from e
@router.put("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
async def update_mcp_connector(
connector_id: int,
connector_update: MCPConnectorUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Update an MCP connector.
Args:
connector_id: ID of the connector to update
connector_update: Updated connector data
session: Database session
user: Current authenticated user
Returns:
Updated MCP connector
"""
try:
# Fetch connector
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(status_code=404, detail="MCP connector not found")
# Check user has permission to update connectors
await check_permission(
session,
user,
connector.search_space_id,
Permission.CONNECTORS_UPDATE.value,
"You don't have permission to update this connector",
)
# Update fields
if connector_update.name is not None:
connector.name = connector_update.name
if connector_update.server_config is not None:
connector.config = {
"server_config": connector_update.server_config.model_dump()
}
connector.updated_at = datetime.now(UTC)
await session.commit()
await session.refresh(connector)
logger.info(f"Updated MCP connector {connector_id}")
connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update MCP connector: {e!s}", exc_info=True)
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to update MCP connector: {e!s}"
) from e
@router.delete("/connectors/mcp/{connector_id}", status_code=204)
async def delete_mcp_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Delete an MCP connector.
Args:
connector_id: ID of the connector to delete
session: Database session
user: Current authenticated user
"""
try:
# Fetch connector
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(status_code=404, detail="MCP connector not found")
# Check user has permission to delete connectors
await check_permission(
session,
user,
connector.search_space_id,
Permission.CONNECTORS_DELETE.value,
"You don't have permission to delete this connector",
)
await session.delete(connector)
await session.commit()
logger.info(f"Deleted MCP connector {connector_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete MCP connector: {e!s}", exc_info=True)
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to delete MCP connector: {e!s}"
) from e
@router.post("/connectors/mcp/test")
async def test_mcp_server_connection(
server_config: dict = Body(...),
user: User = Depends(current_active_user),
):
"""
Test connection to an MCP server and fetch available tools.
This endpoint allows users to test their MCP server configuration
before saving it, similar to Cursor's flow.
Args:
server_config: Server configuration with command, args, env
user: Current authenticated user
Returns:
Connection status and list of available tools
"""
try:
from app.agents.new_chat.tools.mcp_client import test_mcp_connection
command = server_config.get("command")
args = server_config.get("args", [])
env = server_config.get("env", {})
if not command:
raise HTTPException(status_code=400, detail="Server command is required")
# Test the connection
result = await test_mcp_connection(command, args, env)
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to test MCP connection: {e!s}", exc_info=True)
return {
"status": "error",
"message": f"Failed to test connection: {e!s}",
"tools": [],
}

View file

@ -55,6 +55,10 @@ from .rbac_schemas import (
UserSearchSpaceAccess, UserSearchSpaceAccess,
) )
from .search_source_connector import ( from .search_source_connector import (
MCPConnectorCreate,
MCPConnectorRead,
MCPConnectorUpdate,
MCPServerConfig,
SearchSourceConnectorBase, SearchSourceConnectorBase,
SearchSourceConnectorCreate, SearchSourceConnectorCreate,
SearchSourceConnectorRead, SearchSourceConnectorRead,
@ -108,6 +112,11 @@ __all__ = [
"LogFilter", "LogFilter",
"LogRead", "LogRead",
"LogUpdate", "LogUpdate",
# Search source connector schemas
"MCPConnectorCreate",
"MCPConnectorRead",
"MCPConnectorUpdate",
"MCPServerConfig",
"MembershipRead", "MembershipRead",
"MembershipReadWithUser", "MembershipReadWithUser",
"MembershipUpdate", "MembershipUpdate",
@ -135,7 +144,6 @@ __all__ = [
"RoleCreate", "RoleCreate",
"RoleRead", "RoleRead",
"RoleUpdate", "RoleUpdate",
# Search source connector schemas
"SearchSourceConnectorBase", "SearchSourceConnectorBase",
"SearchSourceConnectorCreate", "SearchSourceConnectorCreate",
"SearchSourceConnectorRead", "SearchSourceConnectorRead",

View file

@ -38,6 +38,9 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
"""Schema for reading a message.""" """Schema for reading a message."""
thread_id: int thread_id: int
author_id: UUID | None = None
author_display_name: str | None = None
author_avatar_url: str | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View file

@ -23,7 +23,9 @@ class SearchSourceConnectorBase(BaseModel):
@field_validator("config") @field_validator("config")
@classmethod @classmethod
def validate_config_for_connector_type( def validate_config_for_connector_type(
cls, config: dict[str, Any], values: dict[str, Any] cls,
config: dict[str, Any],
values: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
connector_type = values.data.get("connector_type") connector_type = values.data.get("connector_type")
return validate_connector_config(connector_type, config) return validate_connector_config(connector_type, config)
@ -38,15 +40,18 @@ class SearchSourceConnectorBase(BaseModel):
""" """
if self.periodic_indexing_enabled: if self.periodic_indexing_enabled:
if not self.is_indexable: if not self.is_indexable:
msg = "periodic_indexing_enabled can only be True for indexable connectors"
raise ValueError( raise ValueError(
"periodic_indexing_enabled can only be True for indexable connectors" msg,
) )
if self.indexing_frequency_minutes is None: if self.indexing_frequency_minutes is None:
msg = "indexing_frequency_minutes is required when periodic_indexing_enabled is True"
raise ValueError( raise ValueError(
"indexing_frequency_minutes is required when periodic_indexing_enabled is True" msg,
) )
if self.indexing_frequency_minutes <= 0: if self.indexing_frequency_minutes <= 0:
raise ValueError("indexing_frequency_minutes must be greater than 0") msg = "indexing_frequency_minutes must be greater than 0"
raise ValueError(msg)
return self return self
@ -70,3 +75,63 @@ class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampMod
user_id: uuid.UUID user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# =============================================================================
# MCP-specific schemas
# =============================================================================
class MCPServerConfig(BaseModel):
"""Configuration for an MCP server connection (similar to Cursor's config)."""
command: str # e.g., "uvx", "node", "python"
args: list[str] = [] # e.g., ["mcp-server-git", "--repository", "/path"]
env: dict[str, str] = {} # Environment variables for the server process
transport: str = "stdio" # "stdio" | "sse" | "http" (stdio is most common)
class MCPConnectorCreate(BaseModel):
"""Schema for creating an MCP connector."""
name: str
server_config: MCPServerConfig
class MCPConnectorUpdate(BaseModel):
"""Schema for updating an MCP connector."""
name: str | None = None
server_config: MCPServerConfig | None = None
class MCPConnectorRead(BaseModel):
"""Schema for reading an MCP connector with server config."""
id: int
name: str
connector_type: SearchSourceConnectorType
server_config: MCPServerConfig
search_space_id: int
user_id: uuid.UUID
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
@classmethod
def from_connector(cls, connector: SearchSourceConnectorRead) -> "MCPConnectorRead":
"""Convert from base SearchSourceConnectorRead."""
config = connector.config or {}
server_config = MCPServerConfig(**config.get("server_config", {}))
return cls(
id=connector.id,
name=connector.name,
connector_type=connector.connector_type,
server_config=server_config,
search_space_id=connector.search_space_id,
user_id=connector.user_id,
created_at=connector.created_at,
updated_at=connector.updated_at,
)

View file

@ -6,6 +6,8 @@ from fastapi_users import schemas
class UserRead(schemas.BaseUser[uuid.UUID]): class UserRead(schemas.BaseUser[uuid.UUID]):
pages_limit: int pages_limit: int
pages_used: int pages_used: int
display_name: str | None = None
avatar_url: str | None = None
class UserCreate(schemas.BaseUserCreate): class UserCreate(schemas.BaseUserCreate):
@ -13,4 +15,5 @@ class UserCreate(schemas.BaseUserCreate):
class UserUpdate(schemas.BaseUserUpdate): class UserUpdate(schemas.BaseUserUpdate):
pass display_name: str | None = None
avatar_url: str | None = None

View file

@ -237,7 +237,7 @@ async def stream_new_chat(
checkpointer = await get_checkpointer() checkpointer = await get_checkpointer()
# Create the deep agent with checkpointer and configurable prompts # Create the deep agent with checkpointer and configurable prompts
agent = create_surfsense_deep_agent( agent = await create_surfsense_deep_agent(
llm=llm, llm=llm,
search_space_id=search_space_id, search_space_id=search_space_id,
db_session=session, db_session=session,

View file

@ -2,11 +2,14 @@
File document processors for different ETL services (Unstructured, LlamaCloud, Docling). File document processors for different ETL services (Unstructured, LlamaCloud, Docling).
""" """
import asyncio
import contextlib import contextlib
import logging import logging
import ssl
import warnings import warnings
from logging import ERROR, getLogger from logging import ERROR, getLogger
import httpx
from fastapi import HTTPException from fastapi import HTTPException
from langchain_core.documents import Document as LangChainDocument from langchain_core.documents import Document as LangChainDocument
from litellm import atranscription from litellm import atranscription
@ -31,6 +34,122 @@ from .base import (
) )
from .markdown_processor import add_received_markdown_file_document from .markdown_processor import add_received_markdown_file_document
# Constants for LlamaCloud retry configuration
LLAMACLOUD_MAX_RETRIES = 3
LLAMACLOUD_BASE_DELAY = 5 # Base delay in seconds for exponential backoff
LLAMACLOUD_RETRYABLE_EXCEPTIONS = (
ssl.SSLError,
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadTimeout,
httpx.WriteTimeout,
ConnectionError,
TimeoutError,
)
async def parse_with_llamacloud_retry(
file_path: str,
estimated_pages: int,
task_logger: TaskLoggingService | None = None,
log_entry: Log | None = None,
):
"""
Parse a file with LlamaCloud with retry logic for transient SSL/connection errors.
Args:
file_path: Path to the file to parse
estimated_pages: Estimated number of pages for timeout calculation
task_logger: Optional task logger for progress updates
log_entry: Optional log entry for progress updates
Returns:
LlamaParse result object
Raises:
Exception: If all retries fail
"""
from llama_cloud_services import LlamaParse
from llama_cloud_services.parse.utils import ResultType
# Calculate timeouts based on estimated pages
# Base timeout of 300 seconds + 30 seconds per page for large documents
base_timeout = 300
per_page_timeout = 30
job_timeout = base_timeout + (estimated_pages * per_page_timeout)
# Create custom httpx client with larger timeouts for file uploads
# The SSL error often occurs during large file uploads, so we need generous timeouts
custom_timeout = httpx.Timeout(
connect=60.0, # 60 seconds to establish connection
read=300.0, # 5 minutes to read response
write=300.0, # 5 minutes to write/upload (important for large files)
pool=60.0, # 60 seconds to acquire connection from pool
)
last_exception = None
for attempt in range(1, LLAMACLOUD_MAX_RETRIES + 1):
try:
# Create a fresh httpx client for each attempt
async with httpx.AsyncClient(timeout=custom_timeout) as custom_client:
# Create LlamaParse parser instance with optimized settings
parser = LlamaParse(
api_key=app_config.LLAMA_CLOUD_API_KEY,
num_workers=1, # Use single worker for file processing
verbose=True,
language="en",
result_type=ResultType.MD,
# Timeout settings for large files
max_timeout=max(2000, job_timeout), # Overall max timeout
job_timeout_in_seconds=job_timeout,
job_timeout_extra_time_per_page_in_seconds=per_page_timeout,
# Use our custom client with larger timeouts
custom_client=custom_client,
)
# Parse the file asynchronously
result = await parser.aparse(file_path)
return result
except LLAMACLOUD_RETRYABLE_EXCEPTIONS as e:
last_exception = e
error_type = type(e).__name__
if attempt < LLAMACLOUD_MAX_RETRIES:
# Calculate exponential backoff delay
delay = LLAMACLOUD_BASE_DELAY * (2 ** (attempt - 1))
if task_logger and log_entry:
await task_logger.log_task_progress(
log_entry,
f"LlamaCloud upload failed (attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}), retrying in {delay}s",
{
"error_type": error_type,
"error_message": str(e)[:200],
"attempt": attempt,
"retry_delay": delay,
},
)
else:
logging.warning(
f"LlamaCloud upload failed (attempt {attempt}/{LLAMACLOUD_MAX_RETRIES}): {error_type}. "
f"Retrying in {delay}s..."
)
await asyncio.sleep(delay)
else:
logging.error(
f"LlamaCloud upload failed after {LLAMACLOUD_MAX_RETRIES} attempts: {error_type} - {e}"
)
except Exception:
# Non-retryable exception, raise immediately
raise
# All retries exhausted
raise last_exception or RuntimeError("LlamaCloud parsing failed after all retries")
async def add_received_file_document_using_unstructured( async def add_received_file_document_using_unstructured(
session: AsyncSession, session: AsyncSession,
@ -819,24 +938,18 @@ async def process_file_in_background(
"file_type": "document", "file_type": "document",
"etl_service": "LLAMACLOUD", "etl_service": "LLAMACLOUD",
"processing_stage": "parsing", "processing_stage": "parsing",
"estimated_pages": estimated_pages_before,
}, },
) )
from llama_cloud_services import LlamaParse # Parse file with retry logic for SSL/connection errors (common with large files)
from llama_cloud_services.parse.utils import ResultType result = await parse_with_llamacloud_retry(
file_path=file_path,
# Create LlamaParse parser instance estimated_pages=estimated_pages_before,
parser = LlamaParse( task_logger=task_logger,
api_key=app_config.LLAMA_CLOUD_API_KEY, log_entry=log_entry,
num_workers=1, # Use single worker for file processing
verbose=True,
language="en",
result_type=ResultType.MD,
) )
# Parse the file asynchronously
result = await parser.aparse(file_path)
# Clean up the temp file # Clean up the temp file
import os import os

View file

@ -1,6 +1,7 @@
import logging import logging
import uuid import uuid
import httpx
from fastapi import Depends, Request, Response from fastapi import Depends, Request, Response
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
@ -46,6 +47,71 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = SECRET reset_password_token_secret = SECRET
verification_token_secret = SECRET verification_token_secret = SECRET
async def oauth_callback(
self,
oauth_name: str,
access_token: str,
account_id: str,
account_email: str,
expires_at: int | None = None,
refresh_token: str | None = None,
request: Request | None = None,
*,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
"""
Override OAuth callback to capture Google profile data (name, avatar).
"""
# Call parent implementation to create/get user
user = await super().oauth_callback(
oauth_name,
access_token,
account_id,
account_email,
expires_at,
refresh_token,
request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
# Fetch and store Google profile data if not already set
if oauth_name == "google" and (not user.display_name or not user.avatar_url):
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://people.googleapis.com/v1/people/me",
params={"personFields": "names,photos"},
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
profile = response.json()
update_dict = {}
# Extract name from names array
names = profile.get("names", [])
if not user.display_name and names:
display_name = names[0].get("displayName")
if display_name:
update_dict["display_name"] = display_name
# Extract photo URL from photos array
photos = profile.get("photos", [])
if not user.avatar_url and photos:
photo_url = photos[0].get("url")
if photo_url:
update_dict["avatar_url"] = photo_url
if update_dict:
user = await self.user_db.update(user, update_dict)
except Exception as e:
logger.warning(f"Failed to fetch Google profile: {e}")
return user
async def on_after_register(self, user: User, request: Request | None = None): async def on_after_register(self, user: User, request: Request | None = None):
""" """
Called after a user registers. Creates a default search space for the user Called after a user registers. Creates a default search space for the user

View file

@ -8,9 +8,9 @@ from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import UUID from uuid import UUID
from sqlalchemy import func from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.sql import func
from app.db import SearchSourceConnector, SearchSourceConnectorType from app.db import SearchSourceConnector, SearchSourceConnectorType
@ -27,6 +27,7 @@ BASE_NAME_FOR_TYPE = {
SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord", SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord",
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence", SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence",
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable", SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable",
SearchSourceConnectorType.MCP_CONNECTOR: "Model Context Protocol (MCP)",
} }
@ -75,7 +76,7 @@ def extract_identifier_from_credentials(
if ".atlassian.net" in hostname: if ".atlassian.net" in hostname:
return hostname.replace(".atlassian.net", "") return hostname.replace(".atlassian.net", "")
return hostname return hostname
except Exception: except (ValueError, TypeError, AttributeError):
pass pass
return None return None

View file

@ -57,6 +57,9 @@ dependencies = [
"chonkie[all]>=1.5.0", "chonkie[all]>=1.5.0",
"langgraph-checkpoint-postgres>=3.0.2", "langgraph-checkpoint-postgres>=3.0.2",
"psycopg[binary,pool]>=3.3.2", "psycopg[binary,pool]>=3.3.2",
"mcp>=1.25.0",
"starlette>=0.40.0,<0.51.0",
"sse-starlette>=3.1.1,<3.1.2",
] ]
[dependency-groups] [dependency-groups]

View file

@ -175,6 +175,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/26/99/fc813cd978842c26c82534010ea849eee9ab3a13ea2b74e95cb9c99e747b/amqp-5.3.1-py3-none-any.whl", hash = "sha256:43b3319e1b4e7d1251833a93d672b4af1e40f3d632d479b98661a95f117880a2", size = 50944 }, { url = "https://files.pythonhosted.org/packages/26/99/fc813cd978842c26c82534010ea849eee9ab3a13ea2b74e95cb9c99e747b/amqp-5.3.1-py3-none-any.whl", hash = "sha256:43b3319e1b4e7d1251833a93d672b4af1e40f3d632d479b98661a95f117880a2", size = 50944 },
] ]
[[package]]
name = "annotated-doc"
version = "0.0.4"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303 },
]
[[package]] [[package]]
name = "annotated-types" name = "annotated-types"
version = "0.7.0" version = "0.7.0"
@ -1568,16 +1577,17 @@ wheels = [
[[package]] [[package]]
name = "fastapi" name = "fastapi"
version = "0.115.9" version = "0.128.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "annotated-doc" },
{ name = "pydantic" }, { name = "pydantic" },
{ name = "starlette" }, { name = "starlette" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/ab/dd/d854f85e70f7341b29e3fda754f2833aec197bd355f805238758e3bcd8ed/fastapi-0.115.9.tar.gz", hash = "sha256:9d7da3b196c5eed049bc769f9475cd55509a112fbe031c0ef2f53768ae68d13f", size = 293774 } sdist = { url = "https://files.pythonhosted.org/packages/52/08/8c8508db6c7b9aae8f7175046af41baad690771c9bcde676419965e338c7/fastapi-0.128.0.tar.gz", hash = "sha256:1cc179e1cef10a6be60ffe429f79b829dce99d8de32d7acb7e6c8dfdf7f2645a", size = 365682 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/32/b6/7517af5234378518f27ad35a7b24af9591bc500b8c1780929c1295999eb6/fastapi-0.115.9-py3-none-any.whl", hash = "sha256:4a439d7923e4de796bcc88b64e9754340fcd1574673cbd865ba8a99fe0d28c56", size = 94919 }, { url = "https://files.pythonhosted.org/packages/5c/05/5cbb59154b093548acd0f4c7c474a118eda06da25aa75c616b72d8fcd92a/fastapi-0.128.0-py3-none-any.whl", hash = "sha256:aebd93f9716ee3b4f4fcfe13ffb7cf308d99c9f3ab5622d8877441072561582d", size = 103094 },
] ]
[[package]] [[package]]
@ -3482,6 +3492,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/1b/92/9a45c91089c3cf690b5badd4be81e392ff086ccca8a1d4e3a08463d8a966/matplotlib-3.10.3-cp313-cp313t-win_amd64.whl", hash = "sha256:4f23ffe95c5667ef8a2b56eea9b53db7f43910fa4a2d5472ae0f72b64deab4d5", size = 8139044 }, { url = "https://files.pythonhosted.org/packages/1b/92/9a45c91089c3cf690b5badd4be81e392ff086ccca8a1d4e3a08463d8a966/matplotlib-3.10.3-cp313-cp313t-win_amd64.whl", hash = "sha256:4f23ffe95c5667ef8a2b56eea9b53db7f43910fa4a2d5472ae0f72b64deab4d5", size = 8139044 },
] ]
[[package]]
name = "mcp"
version = "1.25.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "httpx" },
{ name = "httpx-sse" },
{ name = "jsonschema" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pyjwt", extra = ["crypto"] },
{ name = "python-multipart" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "sse-starlette" },
{ name = "starlette" },
{ name = "typing-extensions" },
{ name = "typing-inspection" },
{ name = "uvicorn", marker = "sys_platform != 'emscripten'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d5/2d/649d80a0ecf6a1f82632ca44bec21c0461a9d9fc8934d38cb5b319f2db5e/mcp-1.25.0.tar.gz", hash = "sha256:56310361ebf0364e2d438e5b45f7668cbb124e158bb358333cd06e49e83a6802", size = 605387 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e2/fc/6dc7659c2ae5ddf280477011f4213a74f806862856b796ef08f028e664bf/mcp-1.25.0-py3-none-any.whl", hash = "sha256:b37c38144a666add0862614cc79ec276e97d72aa8ca26d622818d4e278b9721a", size = 233076 },
]
[[package]] [[package]]
name = "mdurl" name = "mdurl"
version = "0.1.2" version = "0.1.2"
@ -6382,15 +6417,29 @@ wheels = [
] ]
[[package]] [[package]]
name = "starlette" name = "sse-starlette"
version = "0.45.3" version = "3.1.1"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },
{ name = "starlette" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/ff/fb/2984a686808b89a6781526129a4b51266f678b2d2b97ab2d325e56116df8/starlette-0.45.3.tar.gz", hash = "sha256:2cbcba2a75806f8a41c722141486f37c28e30a0921c5f6fe4346cb0dcee1302f", size = 2574076 } sdist = { url = "https://files.pythonhosted.org/packages/62/08/8f554b0e5bad3e4e880521a1686d96c05198471eed860b0eb89b57ea3636/sse_starlette-3.1.1.tar.gz", hash = "sha256:bffa531420c1793ab224f63648c059bcadc412bf9fdb1301ac8de1cf9a67b7fb", size = 24306 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/d9/61/f2b52e107b1fc8944b33ef56bf6ac4ebbe16d91b94d2b87ce013bf63fb84/starlette-0.45.3-py3-none-any.whl", hash = "sha256:dfb6d332576f136ec740296c7e8bb8c8a7125044e7c6da30744718880cdd059d", size = 71507 }, { url = "https://files.pythonhosted.org/packages/e3/31/4c281581a0f8de137b710a07f65518b34bcf333b201cfa06cfda9af05f8a/sse_starlette-3.1.1-py3-none-any.whl", hash = "sha256:bb38f71ae74cfd86b529907a9fda5632195dfa6ae120f214ea4c890c7ee9d436", size = 12442 },
]
[[package]]
name = "starlette"
version = "0.50.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ba/b8/73a0e6a6e079a9d9cfa64113d771e421640b6f679a52eeb9b32f72d871a1/starlette-0.50.0.tar.gz", hash = "sha256:a2a17b22203254bcbc2e1f926d2d55f3f9497f769416b3190768befe598fa3ca", size = 2646985 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d9/52/1064f510b141bd54025f9b55105e26d1fa970b9be67ad766380a3c9b74b0/starlette-0.50.0-py3-none-any.whl", hash = "sha256:9e5391843ec9b6e472eed1365a78c8098cfceb7a74bfd4d6b1c0c0095efb3bca", size = 74033 },
] ]
[[package]] [[package]]
@ -6443,6 +6492,7 @@ dependencies = [
{ name = "litellm" }, { name = "litellm" },
{ name = "llama-cloud-services" }, { name = "llama-cloud-services" },
{ name = "markdownify" }, { name = "markdownify" },
{ name = "mcp" },
{ name = "notion-client" }, { name = "notion-client" },
{ name = "numpy" }, { name = "numpy" },
{ name = "pgvector" }, { name = "pgvector" },
@ -6457,6 +6507,8 @@ dependencies = [
{ name = "slack-sdk" }, { name = "slack-sdk" },
{ name = "soundfile" }, { name = "soundfile" },
{ name = "spacy" }, { name = "spacy" },
{ name = "sse-starlette" },
{ name = "starlette" },
{ name = "static-ffmpeg" }, { name = "static-ffmpeg" },
{ name = "tavily-python" }, { name = "tavily-python" },
{ name = "trafilatura" }, { name = "trafilatura" },
@ -6505,6 +6557,7 @@ requires-dist = [
{ name = "litellm", specifier = ">=1.80.10" }, { name = "litellm", specifier = ">=1.80.10" },
{ name = "llama-cloud-services", specifier = ">=0.6.25" }, { name = "llama-cloud-services", specifier = ">=0.6.25" },
{ name = "markdownify", specifier = ">=0.14.1" }, { name = "markdownify", specifier = ">=0.14.1" },
{ name = "mcp", specifier = ">=1.25.0" },
{ name = "notion-client", specifier = ">=2.3.0" }, { name = "notion-client", specifier = ">=2.3.0" },
{ name = "numpy", specifier = ">=1.24.0" }, { name = "numpy", specifier = ">=1.24.0" },
{ name = "pgvector", specifier = ">=0.3.6" }, { name = "pgvector", specifier = ">=0.3.6" },
@ -6519,6 +6572,8 @@ requires-dist = [
{ name = "slack-sdk", specifier = ">=3.34.0" }, { name = "slack-sdk", specifier = ">=3.34.0" },
{ name = "soundfile", specifier = ">=0.13.1" }, { name = "soundfile", specifier = ">=0.13.1" },
{ name = "spacy", specifier = ">=3.8.7" }, { name = "spacy", specifier = ">=3.8.7" },
{ name = "sse-starlette", specifier = ">=3.1.1,<3.1.2" },
{ name = "starlette", specifier = ">=0.40.0,<0.51.0" },
{ name = "static-ffmpeg", specifier = ">=2.13" }, { name = "static-ffmpeg", specifier = ">=2.13" },
{ name = "tavily-python", specifier = ">=0.3.2" }, { name = "tavily-python", specifier = ">=0.3.2" },
{ name = "trafilatura", specifier = ">=2.0.0" }, { name = "trafilatura", specifier = ">=2.0.0" },

View file

@ -79,25 +79,17 @@ export function DocumentsTableShell({
[documents, sortKey, sortDesc] [documents, sortKey, sortDesc]
); );
// Filter out SURFSENSE_DOCS for selection purposes const allSelectedOnPage = sorted.length > 0 && sorted.every((d) => selectedIds.has(d.id));
const selectableDocs = React.useMemo( const someSelectedOnPage = sorted.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage;
() => sorted.filter((d) => d.document_type !== "SURFSENSE_DOCS"),
[sorted]
);
const allSelectedOnPage =
selectableDocs.length > 0 && selectableDocs.every((d) => selectedIds.has(d.id));
const someSelectedOnPage =
selectableDocs.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage;
const toggleAll = (checked: boolean) => { const toggleAll = (checked: boolean) => {
const next = new Set(selectedIds); const next = new Set(selectedIds);
if (checked) if (checked)
selectableDocs.forEach((d) => { sorted.forEach((d) => {
next.add(d.id); next.add(d.id);
}); });
else else
selectableDocs.forEach((d) => { sorted.forEach((d) => {
next.delete(d.id); next.delete(d.id);
}); });
setSelectedIds(next); setSelectedIds(next);
@ -238,10 +230,9 @@ export function DocumentsTableShell({
const icon = getDocumentTypeIcon(doc.document_type); const icon = getDocumentTypeIcon(doc.document_type);
const title = doc.title; const title = doc.title;
const truncatedTitle = title.length > 30 ? `${title.slice(0, 30)}...` : title; const truncatedTitle = title.length > 30 ? `${title.slice(0, 30)}...` : title;
const isSurfsenseDoc = doc.document_type === "SURFSENSE_DOCS";
return ( return (
<motion.tr <motion.tr
key={`${doc.document_type}-${doc.id}`} key={doc.id}
initial={{ opacity: 0, y: 10 }} initial={{ opacity: 0, y: 10 }}
animate={{ animate={{
opacity: 1, opacity: 1,
@ -258,9 +249,8 @@ export function DocumentsTableShell({
> >
<TableCell className="px-4 py-3"> <TableCell className="px-4 py-3">
<Checkbox <Checkbox
checked={selectedIds.has(doc.id) && !isSurfsenseDoc} checked={selectedIds.has(doc.id)}
onCheckedChange={(v) => !isSurfsenseDoc && toggleOne(doc.id, !!v)} onCheckedChange={(v) => toggleOne(doc.id, !!v)}
disabled={isSurfsenseDoc}
aria-label="Select row" aria-label="Select row"
/> />
</TableCell> </TableCell>

View file

@ -20,7 +20,7 @@ import { DocumentsFilters } from "./components/DocumentsFilters";
import { DocumentsTableShell, type SortKey } from "./components/DocumentsTableShell"; import { DocumentsTableShell, type SortKey } from "./components/DocumentsTableShell";
import { PaginationControls } from "./components/PaginationControls"; import { PaginationControls } from "./components/PaginationControls";
import { ProcessingIndicator } from "./components/ProcessingIndicator"; import { ProcessingIndicator } from "./components/ProcessingIndicator";
import type { ColumnVisibility, Document } from "./components/types"; import type { ColumnVisibility } from "./components/types";
function useDebounced<T>(value: T, delay = 250) { function useDebounced<T>(value: T, delay = 250) {
const [debounced, setDebounced] = useState(value); const [debounced, setDebounced] = useState(value);
@ -60,39 +60,30 @@ export default function DocumentsTable() {
const { data: rawTypeCounts } = useAtomValue(documentTypeCountsAtom); const { data: rawTypeCounts } = useAtomValue(documentTypeCountsAtom);
const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom); const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom);
// Filter out SURFSENSE_DOCS from active types for regular documents API // Build query parameters for fetching documents
const regularDocumentTypes = useMemo(
() => activeTypes.filter((t) => t !== "SURFSENSE_DOCS"),
[activeTypes]
);
// Check if only SURFSENSE_DOCS is selected (skip regular docs query)
const onlySurfsenseDocsSelected = activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS";
// Build query parameters for fetching documents (excluding SURFSENSE_DOCS type)
const queryParams = useMemo( const queryParams = useMemo(
() => ({ () => ({
search_space_id: searchSpaceId, search_space_id: searchSpaceId,
page: pageIndex, page: pageIndex,
page_size: pageSize, page_size: pageSize,
...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }), ...(activeTypes.length > 0 && { document_types: activeTypes }),
}), }),
[searchSpaceId, pageIndex, pageSize, regularDocumentTypes] [searchSpaceId, pageIndex, pageSize, activeTypes]
); );
// Build search query parameters (excluding SURFSENSE_DOCS type) // Build search query parameters
const searchQueryParams = useMemo( const searchQueryParams = useMemo(
() => ({ () => ({
search_space_id: searchSpaceId, search_space_id: searchSpaceId,
page: pageIndex, page: pageIndex,
page_size: pageSize, page_size: pageSize,
title: debouncedSearch.trim(), title: debouncedSearch.trim(),
...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }), ...(activeTypes.length > 0 && { document_types: activeTypes }),
}), }),
[searchSpaceId, pageIndex, pageSize, regularDocumentTypes, debouncedSearch] [searchSpaceId, pageIndex, pageSize, activeTypes, debouncedSearch]
); );
// Use query for fetching documents (disabled when only SURFSENSE_DOCS is selected) // Use query for fetching documents
const { const {
data: documentsResponse, data: documentsResponse,
isLoading: isDocumentsLoading, isLoading: isDocumentsLoading,
@ -102,10 +93,10 @@ export default function DocumentsTable() {
queryKey: cacheKeys.documents.globalQueryParams(queryParams), queryKey: cacheKeys.documents.globalQueryParams(queryParams),
queryFn: () => documentsApiService.getDocuments({ queryParams }), queryFn: () => documentsApiService.getDocuments({ queryParams }),
staleTime: 3 * 60 * 1000, // 3 minutes staleTime: 3 * 60 * 1000, // 3 minutes
enabled: !!searchSpaceId && !debouncedSearch.trim() && !onlySurfsenseDocsSelected, enabled: !!searchSpaceId && !debouncedSearch.trim(),
}); });
// Use query for searching documents (disabled when only SURFSENSE_DOCS is selected) // Use query for searching documents
const { const {
data: searchResponse, data: searchResponse,
isLoading: isSearchLoading, isLoading: isSearchLoading,
@ -115,114 +106,20 @@ export default function DocumentsTable() {
queryKey: cacheKeys.documents.globalQueryParams(searchQueryParams), queryKey: cacheKeys.documents.globalQueryParams(searchQueryParams),
queryFn: () => documentsApiService.searchDocuments({ queryParams: searchQueryParams }), queryFn: () => documentsApiService.searchDocuments({ queryParams: searchQueryParams }),
staleTime: 3 * 60 * 1000, // 3 minutes staleTime: 3 * 60 * 1000, // 3 minutes
enabled: !!searchSpaceId && !!debouncedSearch.trim() && !onlySurfsenseDocsSelected, enabled: !!searchSpaceId && !!debouncedSearch.trim(),
}); });
// Determine if we should show SurfSense docs (when no type filter or SURFSENSE_DOCS is selected)
const showSurfsenseDocs =
activeTypes.length === 0 || activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum);
// Use query for fetching SurfSense docs
const {
data: surfsenseDocsResponse,
isLoading: isSurfsenseDocsLoading,
refetch: refetchSurfsenseDocs,
} = useQuery({
queryKey: ["surfsense-docs", debouncedSearch, pageIndex, pageSize],
queryFn: () =>
documentsApiService.getSurfsenseDocs({
queryParams: {
page: pageIndex,
page_size: pageSize,
title: debouncedSearch.trim() || undefined,
},
}),
staleTime: 3 * 60 * 1000, // 3 minutes
enabled: showSurfsenseDocs,
});
// Transform SurfSense docs to match the Document type
const surfsenseDocsAsDocuments: Document[] = useMemo(() => {
if (!surfsenseDocsResponse?.items) return [];
return surfsenseDocsResponse.items.map((doc) => ({
id: doc.id,
title: doc.title,
document_type: "SURFSENSE_DOCS",
document_metadata: { source: doc.source },
content: doc.content,
created_at: doc.created_at || doc.updated_at || new Date().toISOString(),
search_space_id: -1, // Special value for global docs
}));
}, [surfsenseDocsResponse]);
// Merge type counts with SURFSENSE_DOCS count
const typeCounts = useMemo(() => {
const counts = { ...(rawTypeCounts || {}) };
if (surfsenseDocsResponse?.total) {
counts.SURFSENSE_DOCS = surfsenseDocsResponse.total;
}
return counts;
}, [rawTypeCounts, surfsenseDocsResponse?.total]);
// Extract documents and total based on search state // Extract documents and total based on search state
const regularDocuments = debouncedSearch.trim() const documents = debouncedSearch.trim()
? searchResponse?.items || [] ? searchResponse?.items || []
: documentsResponse?.items || []; : documentsResponse?.items || [];
const regularTotal = debouncedSearch.trim() const total = debouncedSearch.trim() ? searchResponse?.total || 0 : documentsResponse?.total || 0;
? searchResponse?.total || 0
: documentsResponse?.total || 0;
// Merge regular documents with SurfSense docs const loading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading;
const documents = useMemo(() => { const error = debouncedSearch.trim() ? searchError : documentsError;
// If filtering by type and not including SURFSENSE_DOCS, only show regular docs
if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) {
return regularDocuments;
}
// If filtering only by SURFSENSE_DOCS, only show surfsense docs
if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") {
return surfsenseDocsAsDocuments;
}
// Otherwise, merge both (surfsense docs first)
return [...surfsenseDocsAsDocuments, ...regularDocuments];
}, [regularDocuments, surfsenseDocsAsDocuments, activeTypes]);
const total = useMemo(() => { // Display results directly
if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) { const displayDocs = documents;
return regularTotal;
}
if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") {
return surfsenseDocsResponse?.total || 0;
}
return regularTotal + (surfsenseDocsResponse?.total || 0);
}, [regularTotal, surfsenseDocsResponse?.total, activeTypes]);
const loading = useMemo(() => {
// If only SURFSENSE_DOCS selected, only check surfsense loading
if (onlySurfsenseDocsSelected) {
return isSurfsenseDocsLoading;
}
// Otherwise check both regular docs and surfsense docs loading
const regularLoading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading;
return regularLoading || (showSurfsenseDocs && isSurfsenseDocsLoading);
}, [
onlySurfsenseDocsSelected,
isSurfsenseDocsLoading,
debouncedSearch,
isSearchLoading,
isDocumentsLoading,
showSurfsenseDocs,
]);
const error = useMemo(() => {
// If only SURFSENSE_DOCS selected, no regular docs errors
if (onlySurfsenseDocsSelected) {
return null;
}
return debouncedSearch.trim() ? searchError : documentsError;
}, [onlySurfsenseDocsSelected, debouncedSearch, searchError, documentsError]);
// Display server-filtered results directly
const displayDocs = documents || [];
const displayTotal = total; const displayTotal = total;
const pageStart = pageIndex * pageSize; const pageStart = pageIndex * pageSize;
const pageEnd = Math.min(pageStart + pageSize, displayTotal); const pageEnd = Math.min(pageStart + pageSize, displayTotal);
@ -242,33 +139,16 @@ export default function DocumentsTable() {
if (isRefreshing) return; if (isRefreshing) return;
setIsRefreshing(true); setIsRefreshing(true);
try { try {
const refetchPromises: Promise<unknown>[] = []; if (debouncedSearch.trim()) {
// Only refetch regular documents if not in "only surfsense docs" mode await refetchSearch();
if (!onlySurfsenseDocsSelected) { } else {
if (debouncedSearch.trim()) { await refetchDocuments();
refetchPromises.push(refetchSearch());
} else {
refetchPromises.push(refetchDocuments());
}
} }
if (showSurfsenseDocs) {
refetchPromises.push(refetchSurfsenseDocs());
}
await Promise.all(refetchPromises);
toast.success(t("refresh_success") || "Documents refreshed"); toast.success(t("refresh_success") || "Documents refreshed");
} finally { } finally {
setIsRefreshing(false); setIsRefreshing(false);
} }
}, [ }, [debouncedSearch, refetchSearch, refetchDocuments, t, isRefreshing]);
debouncedSearch,
refetchSearch,
refetchDocuments,
refetchSurfsenseDocs,
showSurfsenseDocs,
onlySurfsenseDocsSelected,
t,
isRefreshing,
]);
// Set up smart polling for active tasks - only polls when tasks are in progress // Set up smart polling for active tasks - only polls when tasks are in progress
const { summary } = useLogsSummary(searchSpaceId, 24, { const { summary } = useLogsSummary(searchSpaceId, 24, {
@ -385,7 +265,7 @@ export default function DocumentsTable() {
<ProcessingIndicator documentProcessorTasksCount={documentProcessorTasksCount} /> <ProcessingIndicator documentProcessorTasksCount={documentProcessorTasksCount} />
<DocumentsFilters <DocumentsFilters
typeCounts={typeCounts ?? {}} typeCounts={rawTypeCounts ?? {}}
selectedIds={selectedIds} selectedIds={selectedIds}
onSearch={setSearch} onSearch={setSearch}
searchValue={search} searchValue={search}

View file

@ -23,6 +23,7 @@ import {
// extractWriteTodosFromContent, // extractWriteTodosFromContent,
hydratePlanStateAtom, hydratePlanStateAtom,
} from "@/atoms/chat/plan-state.atom"; } from "@/atoms/chat/plan-state.atom";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { Thread } from "@/components/assistant-ui/thread"; import { Thread } from "@/components/assistant-ui/thread";
import { ChatHeader } from "@/components/new-chat/chat-header"; import { ChatHeader } from "@/components/new-chat/chat-header";
import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking"; import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
@ -185,12 +186,25 @@ function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
} }
} }
// Build metadata.custom for author display in shared chats
const metadata = msg.author_id
? {
custom: {
author: {
displayName: msg.author_display_name ?? null,
avatarUrl: msg.author_avatar_url ?? null,
},
},
}
: undefined;
return { return {
id: `msg-${msg.id}`, id: `msg-${msg.id}`,
role: msg.role, role: msg.role,
content, content,
createdAt: new Date(msg.created_at), createdAt: new Date(msg.created_at),
attachments, attachments,
metadata,
}; };
} }
@ -238,6 +252,9 @@ export default function NewChatPage() {
const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom); const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom);
const hydratePlanState = useSetAtom(hydratePlanStateAtom); const hydratePlanState = useSetAtom(hydratePlanStateAtom);
// Get current user for author info in shared chats
const { data: currentUser } = useAtomValue(currentUserAtom);
// Create the attachment adapter for file processing // Create the attachment adapter for file processing
const attachmentAdapter = useMemo(() => createAttachmentAdapter(), []); const attachmentAdapter = useMemo(() => createAttachmentAdapter(), []);
@ -306,12 +323,6 @@ export default function NewChatPage() {
if (steps.length > 0) { if (steps.length > 0) {
restoredThinkingSteps.set(`msg-${msg.id}`, steps); restoredThinkingSteps.set(`msg-${msg.id}`, steps);
} }
// Hydrate write_todos plan state from persisted tool calls
// Disabled for now
// const writeTodosCalls = extractWriteTodosFromContent(msg.content);
// for (const todoData of writeTodosCalls) {
// hydratePlanState(todoData);
// }
} }
if (msg.role === "user") { if (msg.role === "user") {
const docs = extractMentionedDocuments(msg.content); const docs = extractMentionedDocuments(msg.content);
@ -448,13 +459,27 @@ export default function NewChatPage() {
// Add user message to state // Add user message to state
const userMsgId = `msg-user-${Date.now()}`; const userMsgId = `msg-user-${Date.now()}`;
// Include author metadata for shared chats
const authorMetadata =
currentThread?.visibility === "SEARCH_SPACE" && currentUser
? {
custom: {
author: {
displayName: currentUser.display_name ?? null,
avatarUrl: currentUser.avatar_url ?? null,
},
},
}
: undefined;
const userMessage: ThreadMessageLike = { const userMessage: ThreadMessageLike = {
id: userMsgId, id: userMsgId,
role: "user", role: "user",
content: message.content, content: message.content,
createdAt: new Date(), createdAt: new Date(),
// Include attachments so they can be displayed
attachments: message.attachments || [], attachments: message.attachments || [],
metadata: authorMetadata,
}; };
setMessages((prev) => [...prev, userMessage]); setMessages((prev) => [...prev, userMessage]);
@ -884,6 +909,8 @@ export default function NewChatPage() {
setMentionedDocuments, setMentionedDocuments,
setMessageDocumentsMap, setMessageDocumentsMap,
queryClient, queryClient,
currentThread,
currentUser,
] ]
); );

View file

@ -0,0 +1,122 @@
"use client";
import { Check, Copy, Key, Menu, Shield } from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useTranslations } from "next-intl";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import { Button } from "@/components/ui/button";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import { useApiKey } from "@/hooks/use-api-key";
interface ApiKeyContentProps {
onMenuClick: () => void;
}
export function ApiKeyContent({ onMenuClick }: ApiKeyContentProps) {
const t = useTranslations("userSettings");
const { apiKey, isLoading, copied, copyToClipboard } = useApiKey();
return (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.2, duration: 0.4 }}
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
>
<div className="h-full overflow-y-auto">
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
<AnimatePresence mode="wait">
<motion.div
key="api-key-header"
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
transition={{ duration: 0.3 }}
className="mb-6 md:mb-8"
>
<div className="flex items-center gap-3 md:gap-4">
<Button
variant="outline"
size="icon"
onClick={onMenuClick}
className="h-10 w-10 shrink-0 md:hidden"
>
<Menu className="h-5 w-5" />
</Button>
<motion.div
initial={{ scale: 0.8, opacity: 0 }}
animate={{ scale: 1, opacity: 1 }}
transition={{ delay: 0.1, duration: 0.3 }}
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
>
<Key className="h-5 w-5 text-primary md:h-7 md:w-7" />
</motion.div>
<div className="min-w-0">
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
{t("api_key_title")}
</h1>
<p className="text-sm text-muted-foreground">{t("api_key_description")}</p>
</div>
</div>
</motion.div>
</AnimatePresence>
<AnimatePresence mode="wait">
<motion.div
key="api-key-content"
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -20 }}
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
className="space-y-6"
>
<Alert>
<Shield className="h-4 w-4" />
<AlertTitle>{t("api_key_warning_title")}</AlertTitle>
<AlertDescription>{t("api_key_warning_description")}</AlertDescription>
</Alert>
<div className="rounded-lg border bg-card p-6">
<h3 className="mb-4 font-medium">{t("your_api_key")}</h3>
{isLoading ? (
<div className="h-12 w-full animate-pulse rounded-md bg-muted" />
) : apiKey ? (
<div className="flex items-center gap-2">
<div className="flex-1 overflow-x-auto rounded-md bg-muted p-3 font-mono text-sm">
{apiKey}
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
onClick={copyToClipboard}
className="shrink-0"
>
{copied ? <Check className="h-4 w-4" /> : <Copy className="h-4 w-4" />}
</Button>
</TooltipTrigger>
<TooltipContent>{copied ? t("copied") : t("copy")}</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
) : (
<p className="text-center text-muted-foreground">{t("no_api_key")}</p>
)}
</div>
<div className="rounded-lg border bg-card p-6">
<h3 className="mb-2 font-medium">{t("usage_title")}</h3>
<p className="mb-4 text-sm text-muted-foreground">{t("usage_description")}</p>
<pre className="overflow-x-auto rounded-md bg-muted p-3 text-sm">
<code>Authorization: Bearer {apiKey || "YOUR_API_KEY"}</code>
</pre>
</div>
</motion.div>
</AnimatePresence>
</div>
</div>
</motion.div>
);
}

View file

@ -0,0 +1,181 @@
"use client";
import { useAtomValue } from "jotai";
import { Loader2, Menu, User } from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useTranslations } from "next-intl";
import { useEffect, useState } from "react";
import { toast } from "sonner";
import { updateUserMutationAtom } from "@/atoms/user/user-mutation.atoms";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
interface ProfileContentProps {
onMenuClick: () => void;
}
function AvatarDisplay({ url, fallback }: { url?: string; fallback: string }) {
const [hasError, setHasError] = useState(false);
useEffect(() => {
setHasError(false);
}, [url]);
if (url && !hasError) {
return (
<img
src={url}
alt="Avatar"
className="h-16 w-16 rounded-xl object-cover"
onError={() => setHasError(true)}
/>
);
}
return (
<div className="flex h-16 w-16 items-center justify-center rounded-xl bg-muted text-xl font-semibold text-muted-foreground">
{fallback}
</div>
);
}
export function ProfileContent({ onMenuClick }: ProfileContentProps) {
const t = useTranslations("userSettings");
const { data: user, isLoading: isUserLoading } = useAtomValue(currentUserAtom);
const { mutateAsync: updateUser, isPending } = useAtomValue(updateUserMutationAtom);
const [displayName, setDisplayName] = useState("");
useEffect(() => {
if (user) {
setDisplayName(user.display_name || "");
}
}, [user]);
const getInitials = (email: string) => {
const name = email.split("@")[0];
return name.slice(0, 2).toUpperCase();
};
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
try {
await updateUser({
display_name: displayName || null,
});
toast.success(t("profile_saved"));
} catch {
toast.error(t("profile_save_error"));
}
};
const hasChanges = displayName !== (user?.display_name || "");
return (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.2, duration: 0.4 }}
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
>
<div className="h-full overflow-y-auto">
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
<AnimatePresence mode="wait">
<motion.div
key="profile-header"
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
transition={{ duration: 0.3 }}
className="mb-6 md:mb-8"
>
<div className="flex items-center gap-3 md:gap-4">
<Button
variant="outline"
size="icon"
onClick={onMenuClick}
className="h-10 w-10 shrink-0 md:hidden"
>
<Menu className="h-5 w-5" />
</Button>
<motion.div
initial={{ scale: 0.8, opacity: 0 }}
animate={{ scale: 1, opacity: 1 }}
transition={{ delay: 0.1, duration: 0.3 }}
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
>
<User className="h-5 w-5 text-primary md:h-7 md:w-7" />
</motion.div>
<div className="min-w-0">
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
{t("profile_title")}
</h1>
<p className="text-sm text-muted-foreground">{t("profile_description")}</p>
</div>
</div>
</motion.div>
</AnimatePresence>
<AnimatePresence mode="wait">
<motion.div
key="profile-content"
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -20 }}
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
>
{isUserLoading ? (
<div className="flex items-center justify-center py-12">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : (
<form onSubmit={handleSubmit} className="space-y-6">
<div className="rounded-lg border bg-card p-6">
<div className="flex flex-col gap-6">
<div className="space-y-2">
<Label>{t("profile_avatar")}</Label>
<AvatarDisplay
url={user?.avatar_url || undefined}
fallback={getInitials(user?.email || "")}
/>
</div>
<div className="space-y-2">
<Label htmlFor="display-name">{t("profile_display_name")}</Label>
<Input
id="display-name"
type="text"
placeholder={user?.email?.split("@")[0]}
value={displayName}
onChange={(e) => setDisplayName(e.target.value)}
/>
<p className="text-xs text-muted-foreground">
{t("profile_display_name_hint")}
</p>
</div>
<div className="space-y-2">
<Label>{t("profile_email")}</Label>
<Input type="email" value={user?.email || ""} disabled />
</div>
</div>
</div>
<div className="flex justify-end">
<Button type="submit" disabled={isPending || !hasChanges}>
{isPending && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
{t("profile_save")}
</Button>
</div>
</form>
)}
</motion.div>
</AnimatePresence>
</div>
</div>
</motion.div>
);
}

View file

@ -0,0 +1,154 @@
"use client";
import type { LucideIcon } from "lucide-react";
import { ArrowLeft, ChevronRight, X } from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useTranslations } from "next-intl";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
export interface SettingsNavItem {
id: string;
label: string;
description: string;
icon: LucideIcon;
}
interface UserSettingsSidebarProps {
activeSection: string;
onSectionChange: (section: string) => void;
onBackToApp: () => void;
isOpen: boolean;
onClose: () => void;
navItems: SettingsNavItem[];
}
export function UserSettingsSidebar({
activeSection,
onSectionChange,
onBackToApp,
isOpen,
onClose,
navItems,
}: UserSettingsSidebarProps) {
const t = useTranslations("userSettings");
const handleNavClick = (sectionId: string) => {
onSectionChange(sectionId);
onClose();
};
return (
<>
<AnimatePresence>
{isOpen && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 z-40 bg-background/80 backdrop-blur-sm md:hidden"
onClick={onClose}
/>
)}
</AnimatePresence>
<aside
className={cn(
"fixed left-0 top-0 z-50 md:relative md:z-auto",
"flex h-full w-72 shrink-0 flex-col bg-background md:bg-muted/30",
"md:border-r",
"transition-transform duration-300 ease-out",
"md:translate-x-0",
isOpen ? "translate-x-0" : "-translate-x-full md:translate-x-0"
)}
>
{/* Header with title */}
<div className="space-y-3 p-4">
<div className="flex items-center justify-between">
<Button
variant="ghost"
onClick={onBackToApp}
className="group h-11 justify-start gap-3 px-3 hover:bg-muted"
>
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 transition-colors group-hover:bg-primary/20">
<ArrowLeft className="h-4 w-4 text-primary" />
</div>
<span className="font-medium">{t("back_to_app")}</span>
</Button>
<Button variant="ghost" size="icon" onClick={onClose} className="h-9 w-9 md:hidden">
<X className="h-5 w-5" />
</Button>
</div>
{/* Settings Title */}
<div className="px-3">
<h2 className="text-lg font-semibold text-foreground">{t("title")}</h2>
</div>
</div>
<nav className="flex-1 space-y-1 overflow-y-auto px-3 py-2">
{navItems.map((item, index) => {
const isActive = activeSection === item.id;
const Icon = item.icon;
return (
<motion.button
key={item.id}
initial={{ opacity: 0, x: -10 }}
animate={{ opacity: 1, x: 0 }}
transition={{ delay: 0.1 + index * 0.05, duration: 0.3 }}
onClick={() => handleNavClick(item.id)}
whileHover={{ scale: 1.01 }}
whileTap={{ scale: 0.99 }}
className={cn(
"relative flex w-full items-center gap-3 rounded-xl px-3 py-3 text-left transition-all duration-200",
isActive ? "border border-border bg-muted shadow-sm" : "hover:bg-muted/60"
)}
>
{isActive && (
<motion.div
layoutId="userSettingsActiveIndicator"
className="absolute left-0 top-1/2 h-8 w-1 -translate-y-1/2 rounded-r-full bg-primary"
initial={false}
transition={{
type: "spring",
stiffness: 500,
damping: 35,
}}
/>
)}
<div
className={cn(
"flex h-9 w-9 items-center justify-center rounded-lg transition-colors",
isActive ? "bg-primary/10 text-primary" : "bg-muted text-muted-foreground"
)}
>
<Icon className="h-4 w-4" />
</div>
<div className="min-w-0 flex-1">
<p
className={cn(
"truncate text-sm font-medium transition-colors",
isActive ? "text-foreground" : "text-muted-foreground"
)}
>
{item.label}
</p>
<p className="truncate text-xs text-muted-foreground/70">{item.description}</p>
</div>
<ChevronRight
className={cn(
"h-4 w-4 shrink-0 transition-all",
isActive
? "translate-x-0 text-primary opacity-100"
: "-translate-x-1 text-muted-foreground/40 opacity-0"
)}
/>
</motion.button>
);
})}
</nav>
</aside>
</>
);
}

View file

@ -1,286 +1,27 @@
"use client"; "use client";
import { import { Key, User } from "lucide-react";
ArrowLeft, import { motion } from "motion/react";
Check,
ChevronRight,
Copy,
Key,
type LucideIcon,
Menu,
Shield,
X,
} from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { useCallback, useState } from "react"; import { useCallback, useState } from "react";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { ApiKeyContent } from "./components/ApiKeyContent";
import { Button } from "@/components/ui/button"; import { ProfileContent } from "./components/ProfileContent";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip"; import { type SettingsNavItem, UserSettingsSidebar } from "./components/UserSettingsSidebar";
import { useApiKey } from "@/hooks/use-api-key";
import { cn } from "@/lib/utils";
interface SettingsNavItem {
id: string;
label: string;
description: string;
icon: LucideIcon;
}
function UserSettingsSidebar({
activeSection,
onSectionChange,
onBackToApp,
isOpen,
onClose,
navItems,
}: {
activeSection: string;
onSectionChange: (section: string) => void;
onBackToApp: () => void;
isOpen: boolean;
onClose: () => void;
navItems: SettingsNavItem[];
}) {
const t = useTranslations("userSettings");
const handleNavClick = (sectionId: string) => {
onSectionChange(sectionId);
onClose();
};
return (
<>
<AnimatePresence>
{isOpen && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 z-40 bg-background/80 backdrop-blur-sm md:hidden"
onClick={onClose}
/>
)}
</AnimatePresence>
<aside
className={cn(
"fixed left-0 top-0 z-50 md:relative md:z-auto",
"flex h-full w-72 shrink-0 flex-col bg-background md:bg-muted/30",
"md:border-r",
"transition-transform duration-300 ease-out",
"md:translate-x-0",
isOpen ? "translate-x-0" : "-translate-x-full md:translate-x-0"
)}
>
{/* Header with title */}
<div className="space-y-3 p-4">
<div className="flex items-center justify-between">
<Button
variant="ghost"
onClick={onBackToApp}
className="group h-11 justify-start gap-3 px-3 hover:bg-muted"
>
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 transition-colors group-hover:bg-primary/20">
<ArrowLeft className="h-4 w-4 text-primary" />
</div>
<span className="font-medium">{t("back_to_app")}</span>
</Button>
<Button variant="ghost" size="icon" onClick={onClose} className="h-9 w-9 md:hidden">
<X className="h-5 w-5" />
</Button>
</div>
{/* Settings Title */}
<div className="px-3">
<h2 className="text-lg font-semibold text-foreground">{t("title")}</h2>
</div>
</div>
<nav className="flex-1 space-y-1 overflow-y-auto px-3 py-2">
{navItems.map((item, index) => {
const isActive = activeSection === item.id;
const Icon = item.icon;
return (
<motion.button
key={item.id}
initial={{ opacity: 0, x: -10 }}
animate={{ opacity: 1, x: 0 }}
transition={{ delay: 0.1 + index * 0.05, duration: 0.3 }}
onClick={() => handleNavClick(item.id)}
whileHover={{ scale: 1.01 }}
whileTap={{ scale: 0.99 }}
className={cn(
"relative flex w-full items-center gap-3 rounded-xl px-3 py-3 text-left transition-all duration-200",
isActive ? "border border-border bg-muted shadow-sm" : "hover:bg-muted/60"
)}
>
{isActive && (
<motion.div
layoutId="userSettingsActiveIndicator"
className="absolute left-0 top-1/2 h-8 w-1 -translate-y-1/2 rounded-r-full bg-primary"
initial={false}
transition={{
type: "spring",
stiffness: 500,
damping: 35,
}}
/>
)}
<div
className={cn(
"flex h-9 w-9 items-center justify-center rounded-lg transition-colors",
isActive ? "bg-primary/10 text-primary" : "bg-muted text-muted-foreground"
)}
>
<Icon className="h-4 w-4" />
</div>
<div className="min-w-0 flex-1">
<p
className={cn(
"truncate text-sm font-medium transition-colors",
isActive ? "text-foreground" : "text-muted-foreground"
)}
>
{item.label}
</p>
<p className="truncate text-xs text-muted-foreground/70">{item.description}</p>
</div>
<ChevronRight
className={cn(
"h-4 w-4 shrink-0 transition-all",
isActive
? "translate-x-0 text-primary opacity-100"
: "-translate-x-1 text-muted-foreground/40 opacity-0"
)}
/>
</motion.button>
);
})}
</nav>
</aside>
</>
);
}
function ApiKeyContent({ onMenuClick }: { onMenuClick: () => void }) {
const t = useTranslations("userSettings");
const { apiKey, isLoading, copied, copyToClipboard } = useApiKey();
return (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.2, duration: 0.4 }}
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
>
<div className="h-full overflow-y-auto">
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
<AnimatePresence mode="wait">
<motion.div
key="api-key-header"
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
transition={{ duration: 0.3 }}
className="mb-6 md:mb-8"
>
<div className="flex items-center gap-3 md:gap-4">
<Button
variant="outline"
size="icon"
onClick={onMenuClick}
className="h-10 w-10 shrink-0 md:hidden"
>
<Menu className="h-5 w-5" />
</Button>
<motion.div
initial={{ scale: 0.8, opacity: 0 }}
animate={{ scale: 1, opacity: 1 }}
transition={{ delay: 0.1, duration: 0.3 }}
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
>
<Key className="h-5 w-5 text-primary md:h-7 md:w-7" />
</motion.div>
<div className="min-w-0">
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
{t("api_key_title")}
</h1>
<p className="text-sm text-muted-foreground">{t("api_key_description")}</p>
</div>
</div>
</motion.div>
</AnimatePresence>
<AnimatePresence mode="wait">
<motion.div
key="api-key-content"
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -20 }}
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
className="space-y-6"
>
<Alert>
<Shield className="h-4 w-4" />
<AlertTitle>{t("api_key_warning_title")}</AlertTitle>
<AlertDescription>{t("api_key_warning_description")}</AlertDescription>
</Alert>
<div className="rounded-lg border bg-card p-6">
<h3 className="mb-4 font-medium">{t("your_api_key")}</h3>
{isLoading ? (
<div className="h-12 w-full animate-pulse rounded-md bg-muted" />
) : apiKey ? (
<div className="flex items-center gap-2">
<div className="flex-1 overflow-x-auto rounded-md bg-muted p-3 font-mono text-sm">
{apiKey}
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
onClick={copyToClipboard}
className="shrink-0"
>
{copied ? <Check className="h-4 w-4" /> : <Copy className="h-4 w-4" />}
</Button>
</TooltipTrigger>
<TooltipContent>{copied ? t("copied") : t("copy")}</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
) : (
<p className="text-center text-muted-foreground">{t("no_api_key")}</p>
)}
</div>
<div className="rounded-lg border bg-card p-6">
<h3 className="mb-2 font-medium">{t("usage_title")}</h3>
<p className="mb-4 text-sm text-muted-foreground">{t("usage_description")}</p>
<pre className="overflow-x-auto rounded-md bg-muted p-3 text-sm">
<code>Authorization: Bearer {apiKey || "YOUR_API_KEY"}</code>
</pre>
</div>
</motion.div>
</AnimatePresence>
</div>
</div>
</motion.div>
);
}
export default function UserSettingsPage() { export default function UserSettingsPage() {
const t = useTranslations("userSettings"); const t = useTranslations("userSettings");
const router = useRouter(); const router = useRouter();
const [activeSection, setActiveSection] = useState("api-key"); const [activeSection, setActiveSection] = useState("profile");
const [isSidebarOpen, setIsSidebarOpen] = useState(false); const [isSidebarOpen, setIsSidebarOpen] = useState(false);
const navItems: SettingsNavItem[] = [ const navItems: SettingsNavItem[] = [
{
id: "profile",
label: t("profile_nav_label"),
description: t("profile_nav_description"),
icon: User,
},
{ {
id: "api-key", id: "api-key",
label: t("api_key_nav_label"), label: t("api_key_nav_label"),
@ -310,6 +51,9 @@ export default function UserSettingsPage() {
onClose={() => setIsSidebarOpen(false)} onClose={() => setIsSidebarOpen(false)}
navItems={navItems} navItems={navItems}
/> />
{activeSection === "profile" && (
<ProfileContent onMenuClick={() => setIsSidebarOpen(true)} />
)}
{activeSection === "api-key" && ( {activeSection === "api-key" && (
<ApiKeyContent onMenuClick={() => setIsSidebarOpen(true)} /> <ApiKeyContent onMenuClick={() => setIsSidebarOpen(true)} />
)} )}

View file

@ -0,0 +1,18 @@
import { atomWithMutation, queryClientAtom } from "jotai-tanstack-query";
import type { UpdateUserRequest } from "@/contracts/types/user.types";
import { userApiService } from "@/lib/apis/user-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
export const updateUserMutationAtom = atomWithMutation((get) => {
const queryClient = get(queryClientAtom);
return {
mutationKey: cacheKeys.user.current(),
mutationFn: async (request: UpdateUserRequest) => {
return userApiService.updateMe(request);
},
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: cacheKeys.user.current() });
},
};
});

View file

@ -96,38 +96,37 @@ const DocumentUploadPopupContent: FC<{
return ( return (
<Dialog open={isOpen} onOpenChange={onOpenChange}> <Dialog open={isOpen} onOpenChange={onOpenChange}>
<DialogContent className="max-w-4xl w-[95vw] sm:w-full max-h-[calc(100vh-2rem)] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border bg-muted text-foreground [&>button]:right-3 sm:[&>button]:right-12 [&>button]:top-4 sm:[&>button]:top-10 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button]:z-[100] [&>button_svg]:size-4 sm:[&>button_svg]:size-5"> <DialogContent className="max-w-4xl w-[95vw] sm:w-full h-[calc(100dvh-2rem)] sm:h-[85vh] flex flex-col p-0 gap-0 overflow-hidden border border-border bg-muted text-foreground [&>button]:right-3 sm:[&>button]:right-12 [&>button]:top-3 sm:[&>button]:top-10 [&>button]:opacity-80 hover:[&>button]:opacity-100 [&>button]:z-[100] [&>button_svg]:size-4 sm:[&>button_svg]:size-5">
<DialogTitle className="sr-only">Upload Document</DialogTitle> <DialogTitle className="sr-only">Upload Document</DialogTitle>
{/* Fixed Header */} {/* Scrollable container for mobile */}
<div className="flex-shrink-0 px-4 sm:px-12 pt-6 sm:pt-10 transition-shadow duration-200 relative z-10"> <div className="flex-1 min-h-0 overflow-y-auto overscroll-contain">
{/* Upload header */} {/* Header - scrolls with content on mobile */}
<div className="flex items-center gap-2 sm:gap-4 mb-2 sm:mb-6"> <div className="sticky top-0 z-20 bg-muted px-4 sm:px-12 pt-4 sm:pt-10 pb-2 sm:pb-0">
<div className="flex h-10 w-10 sm:h-14 sm:w-14 items-center justify-center rounded-lg sm:rounded-xl bg-primary/10 border border-primary/20 flex-shrink-0"> {/* Upload header */}
<Upload className="size-5 sm:size-7 text-primary" /> <div className="flex items-center gap-2 sm:gap-4 mb-2 sm:mb-6">
</div> <div className="flex h-9 w-9 sm:h-14 sm:w-14 items-center justify-center rounded-lg sm:rounded-xl bg-primary/10 border border-primary/20 flex-shrink-0">
<div className="flex-1 min-w-0"> <Upload className="size-4 sm:size-7 text-primary" />
<h2 className="text-lg sm:text-2xl font-semibold tracking-tight">Upload Documents</h2> </div>
<p className="text-xs sm:text-base text-muted-foreground mt-0.5 sm:mt-1"> <div className="flex-1 min-w-0 pr-8 sm:pr-0">
Upload and sync your documents to your search space <h2 className="text-base sm:text-2xl font-semibold tracking-tight">
</p> Upload Documents
</h2>
<p className="text-xs sm:text-base text-muted-foreground mt-0.5 sm:mt-1 line-clamp-1 sm:line-clamp-none">
Upload and sync your documents to your search space
</p>
</div>
</div> </div>
</div> </div>
{/* Content */}
<div className="px-4 sm:px-12 pb-4 sm:pb-16">
<DocumentUploadTab searchSpaceId={searchSpaceId} onSuccess={handleSuccess} />
</div>
</div> </div>
{/* Scrollable Content */} {/* Bottom fade shadow - hidden on very small screens */}
<div className="flex-1 min-h-0 relative overflow-hidden"> <div className="hidden sm:block absolute bottom-0 left-0 right-0 h-7 bg-gradient-to-t from-muted via-muted/80 to-transparent pointer-events-none z-10" />
<div className="h-full overflow-y-auto">
<div className="px-6 sm:px-12 pb-5 sm:pb-16">
<DocumentUploadTab
searchSpaceId={searchSpaceId}
onSuccess={handleSuccess}
/>
</div>
</div>
{/* Bottom fade shadow */}
<div className="absolute bottom-0 left-0 right-0 h-2 sm:h-7 bg-gradient-to-t from-muted via-muted/80 to-transparent pointer-events-none z-10" />
</div>
</DialogContent> </DialogContent>
</Dialog> </Dialog>
); );

View file

@ -19,9 +19,7 @@ import {
ChevronRightIcon, ChevronRightIcon,
CopyIcon, CopyIcon,
DownloadIcon, DownloadIcon,
FileText,
Loader2, Loader2,
PencilIcon,
RefreshCwIcon, RefreshCwIcon,
SquareIcon, SquareIcon,
} from "lucide-react"; } from "lucide-react";
@ -31,7 +29,6 @@ import { createPortal } from "react-dom";
import { import {
mentionedDocumentIdsAtom, mentionedDocumentIdsAtom,
mentionedDocumentsAtom, mentionedDocumentsAtom,
messageDocumentsMapAtom,
} from "@/atoms/chat/mentioned-documents.atom"; } from "@/atoms/chat/mentioned-documents.atom";
import { import {
globalNewLLMConfigsAtom, globalNewLLMConfigsAtom,
@ -39,11 +36,7 @@ import {
newLLMConfigsAtom, newLLMConfigsAtom,
} from "@/atoms/new-llm-config/new-llm-config-query.atoms"; } from "@/atoms/new-llm-config/new-llm-config-query.atoms";
import { currentUserAtom } from "@/atoms/user/user-query.atoms"; import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { import { ComposerAddAttachment, ComposerAttachments } from "@/components/assistant-ui/attachment";
ComposerAddAttachment,
ComposerAttachments,
UserMessageAttachments,
} from "@/components/assistant-ui/attachment";
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup"; import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
import { import {
InlineMentionEditor, InlineMentionEditor,
@ -56,6 +49,7 @@ import {
} from "@/components/assistant-ui/thinking-steps"; } from "@/components/assistant-ui/thinking-steps";
import { ToolFallback } from "@/components/assistant-ui/tool-fallback"; import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { UserMessage } from "@/components/assistant-ui/user-message";
import { import {
DocumentMentionPicker, DocumentMentionPicker,
type DocumentMentionPickerRef, type DocumentMentionPickerRef,
@ -639,70 +633,6 @@ const AssistantActionBar: FC = () => {
); );
}; };
const UserMessage: FC = () => {
const messageId = useAssistantState(({ message }) => message?.id);
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
const hasAttachments = useAssistantState(
({ message }) => message?.attachments && message.attachments.length > 0
);
return (
<MessagePrimitive.Root
className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2"
data-role="user"
>
<div className="aui-user-message-content-wrapper col-start-2 min-w-0">
{/* Display attachments and mentioned documents */}
{(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && (
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
{/* Attachments (images show as thumbnails, documents as chips) */}
<UserMessageAttachments />
{/* Mentioned documents as chips */}
{mentionedDocs?.map((doc) => (
<span
key={`${doc.document_type}:${doc.id}`}
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
title={doc.title}
>
<FileText className="size-3" />
<span className="max-w-[150px] truncate">{doc.title}</span>
</span>
))}
</div>
)}
{/* Message bubble with action bar positioned relative to it */}
<div className="relative">
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
<MessagePrimitive.Parts />
</div>
<div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1">
<UserActionBar />
</div>
</div>
</div>
<BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" />
</MessagePrimitive.Root>
);
};
const UserActionBar: FC = () => {
return (
<ActionBarPrimitive.Root
hideWhenRunning
autohide="not-last"
className="aui-user-action-bar-root flex flex-col items-end"
>
<ActionBarPrimitive.Edit asChild>
<TooltipIconButton tooltip="Edit" className="aui-user-action-edit p-4">
<PencilIcon />
</TooltipIconButton>
</ActionBarPrimitive.Edit>
</ActionBarPrimitive.Root>
);
};
const EditComposer: FC = () => { const EditComposer: FC = () => {
return ( return (
<MessagePrimitive.Root className="aui-edit-composer-wrapper mx-auto flex w-full max-w-(--thread-max-width) flex-col px-2 py-3"> <MessagePrimitive.Root className="aui-edit-composer-wrapper mx-auto flex w-full max-w-(--thread-max-width) flex-col px-2 py-3">

View file

@ -1,16 +1,54 @@
import { ActionBarPrimitive, MessagePrimitive, useAssistantState } from "@assistant-ui/react"; import { ActionBarPrimitive, MessagePrimitive, useAssistantState } from "@assistant-ui/react";
import { useAtomValue } from "jotai"; import { useAtomValue } from "jotai";
import { FileText, PencilIcon } from "lucide-react"; import { FileText, PencilIcon } from "lucide-react";
import type { FC } from "react"; import { type FC, useState } from "react";
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
import { UserMessageAttachments } from "@/components/assistant-ui/attachment"; import { UserMessageAttachments } from "@/components/assistant-ui/attachment";
import { BranchPicker } from "@/components/assistant-ui/branch-picker"; import { BranchPicker } from "@/components/assistant-ui/branch-picker";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
interface AuthorMetadata {
displayName: string | null;
avatarUrl: string | null;
}
const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
const [hasError, setHasError] = useState(false);
const initials = displayName
? displayName
.split(" ")
.map((n) => n[0])
.join("")
.toUpperCase()
.slice(0, 2)
: "U";
if (avatarUrl && !hasError) {
return (
<img
src={avatarUrl}
alt={displayName || "User"}
className="size-8 rounded-full object-cover"
referrerPolicy="no-referrer"
onError={() => setHasError(true)}
/>
);
}
return (
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10 text-xs font-medium text-primary">
{initials}
</div>
);
};
export const UserMessage: FC = () => { export const UserMessage: FC = () => {
const messageId = useAssistantState(({ message }) => message?.id); const messageId = useAssistantState(({ message }) => message?.id);
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom); const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined; const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
const metadata = useAssistantState(({ message }) => message?.metadata);
const author = metadata?.custom?.author as AuthorMetadata | undefined;
const hasAttachments = useAssistantState( const hasAttachments = useAssistantState(
({ message }) => message?.attachments && message.attachments.length > 0 ({ message }) => message?.attachments && message.attachments.length > 0
); );
@ -20,34 +58,42 @@ export const UserMessage: FC = () => {
className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2" className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2"
data-role="user" data-role="user"
> >
<div className="aui-user-message-content-wrapper col-start-2 min-w-0"> <div className="aui-user-message-content-wrapper col-start-2 min-w-0 flex items-end gap-2">
{/* Display attachments and mentioned documents */} <div className="flex-1 min-w-0">
{(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && ( {/* Display attachments and mentioned documents */}
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end"> {(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && (
{/* Attachments (images show as thumbnails, documents as chips) */} <div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
<UserMessageAttachments /> {/* Attachments (images show as thumbnails, documents as chips) */}
{/* Mentioned documents as chips */} <UserMessageAttachments />
{mentionedDocs?.map((doc) => ( {/* Mentioned documents as chips */}
<span {mentionedDocs?.map((doc) => (
key={`${doc.document_type}:${doc.id}`} <span
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20" key={`${doc.document_type}:${doc.id}`}
title={doc.title} className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
> title={doc.title}
<FileText className="size-3" /> >
<span className="max-w-[150px] truncate">{doc.title}</span> <FileText className="size-3" />
</span> <span className="max-w-[150px] truncate">{doc.title}</span>
))} </span>
</div> ))}
)} </div>
{/* Message bubble with action bar positioned relative to it */} )}
<div className="relative"> {/* Message bubble with action bar positioned relative to it */}
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground"> <div className="relative">
<MessagePrimitive.Parts /> <div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
</div> <MessagePrimitive.Parts />
<div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1"> </div>
<UserActionBar /> <div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1">
<UserActionBar />
</div>
</div> </div>
</div> </div>
{/* User avatar - only shown in shared chats */}
{author && (
<div className="shrink-0">
<UserAvatar displayName={author.displayName} avatarUrl={author.avatarUrl} />
</div>
)}
</div> </div>
<BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" /> <BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" />

View file

@ -354,7 +354,11 @@ export function LayoutDataProvider({
onChatDelete={handleChatDelete} onChatDelete={handleChatDelete}
onViewAllSharedChats={handleViewAllSharedChats} onViewAllSharedChats={handleViewAllSharedChats}
onViewAllPrivateChats={handleViewAllPrivateChats} onViewAllPrivateChats={handleViewAllPrivateChats}
user={{ email: user?.email || "", name: user?.email?.split("@")[0] }} user={{
email: user?.email || "",
name: user?.display_name || user?.email?.split("@")[0],
avatarUrl: user?.avatar_url || undefined,
}}
onSettings={handleSettings} onSettings={handleSettings}
onManageMembers={handleManageMembers} onManageMembers={handleManageMembers}
onUserSettings={handleUserSettings} onUserSettings={handleUserSettings}

View file

@ -12,6 +12,7 @@ export interface SearchSpace {
export interface User { export interface User {
email: string; email: string;
name?: string; name?: string;
avatarUrl?: string;
} }
export interface NavItem { export interface NavItem {

View file

@ -61,6 +61,39 @@ function getInitials(email: string): string {
return name.slice(0, 2).toUpperCase(); return name.slice(0, 2).toUpperCase();
} }
/**
* User avatar component - shows image if available, otherwise falls back to initials
*/
function UserAvatar({
avatarUrl,
initials,
bgColor,
}: {
avatarUrl?: string;
initials: string;
bgColor: string;
}) {
if (avatarUrl) {
return (
<img
src={avatarUrl}
alt="User avatar"
className="h-8 w-8 shrink-0 rounded-lg object-cover"
referrerPolicy="no-referrer"
/>
);
}
return (
<div
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
);
}
export function SidebarUserProfile({ export function SidebarUserProfile({
user, user,
onUserSettings, onUserSettings,
@ -88,12 +121,7 @@ export function SidebarUserProfile({
"focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" "focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
)} )}
> >
<div <UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
className="flex h-8 w-8 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
<span className="sr-only">{displayName}</span> <span className="sr-only">{displayName}</span>
</button> </button>
</DropdownMenuTrigger> </DropdownMenuTrigger>
@ -104,12 +132,7 @@ export function SidebarUserProfile({
<DropdownMenuContent className="w-56" side="right" align="end" sideOffset={8}> <DropdownMenuContent className="w-56" side="right" align="end" sideOffset={8}>
<DropdownMenuLabel className="font-normal"> <DropdownMenuLabel className="font-normal">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<div <UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
<div className="flex-1 min-w-0"> <div className="flex-1 min-w-0">
<p className="truncate text-sm font-medium">{displayName}</p> <p className="truncate text-sm font-medium">{displayName}</p>
<p className="truncate text-xs text-muted-foreground">{user.email}</p> <p className="truncate text-xs text-muted-foreground">{user.email}</p>
@ -149,13 +172,7 @@ export function SidebarUserProfile({
"focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring" "focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
)} )}
> >
{/* Avatar */} <UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
<div
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
{/* Name and email */} {/* Name and email */}
<div className="flex-1 min-w-0"> <div className="flex-1 min-w-0">
@ -171,12 +188,7 @@ export function SidebarUserProfile({
<DropdownMenuContent className="w-56" side="top" align="start" sideOffset={4}> <DropdownMenuContent className="w-56" side="top" align="start" sideOffset={4}>
<DropdownMenuLabel className="font-normal"> <DropdownMenuLabel className="font-normal">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<div <UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
<div className="flex-1 min-w-0"> <div className="flex-1 min-w-0">
<p className="truncate text-sm font-medium">{displayName}</p> <p className="truncate text-sm font-medium">{displayName}</p>
<p className="truncate text-xs text-muted-foreground">{user.email}</p> <p className="truncate text-xs text-muted-foreground">{user.email}</p>

View file

@ -215,6 +215,16 @@ export const DocumentMentionPicker = forwardRef<
isSurfsenseDocsLoading) && isSurfsenseDocsLoading) &&
currentPage === 0; currentPage === 0;
// Split documents into SurfSense docs and user docs for grouped rendering
const surfsenseDocsList = useMemo(
() => actualDocuments.filter((doc) => doc.document_type === "SURFSENSE_DOCS"),
[actualDocuments]
);
const userDocsList = useMemo(
() => actualDocuments.filter((doc) => doc.document_type !== "SURFSENSE_DOCS"),
[actualDocuments]
);
// Track already selected documents using unique key (document_type:id) to avoid ID collisions // Track already selected documents using unique key (document_type:id) to avoid ID collisions
const selectedKeys = useMemo( const selectedKeys = useMemo(
() => new Set(initialSelectedDocuments.map((d) => `${d.document_type}:${d.id}`)), () => new Set(initialSelectedDocuments.map((d) => `${d.document_type}:${d.id}`)),
@ -324,47 +334,102 @@ export const DocumentMentionPicker = forwardRef<
</div> </div>
) : ( ) : (
<div className="py-1"> <div className="py-1">
{actualDocuments.map((doc) => { {/* SurfSense Documentation Section */}
const docKey = `${doc.document_type}:${doc.id}`; {surfsenseDocsList.length > 0 && (
const isAlreadySelected = selectedKeys.has(docKey); <>
const selectableIndex = selectableDocuments.findIndex( <div className="sticky top-0 z-10 px-3 py-2 text-xs font-bold uppercase tracking-wider bg-muted text-foreground/80 border-b border-border">
(d) => d.document_type === doc.document_type && d.id === doc.id SurfSense Docs
); </div>
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex; {surfsenseDocsList.map((doc) => {
const docKey = `${doc.document_type}:${doc.id}`;
const isAlreadySelected = selectedKeys.has(docKey);
const selectableIndex = selectableDocuments.findIndex(
(d) => d.document_type === doc.document_type && d.id === doc.id
);
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex;
return (
<button
key={docKey}
ref={(el) => {
if (el && selectableIndex >= 0) {
itemRefs.current.set(selectableIndex, el);
}
}}
type="button"
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
onMouseEnter={() => {
if (!isAlreadySelected && selectableIndex >= 0) {
setHighlightedIndex(selectableIndex);
}
}}
disabled={isAlreadySelected}
className={cn(
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
isHighlighted && "bg-accent"
)}
>
<span className="shrink-0 text-muted-foreground text-sm">
{getConnectorIcon(doc.document_type)}
</span>
<span className="flex-1 text-sm truncate" title={doc.title}>
{doc.title}
</span>
</button>
);
})}
</>
)}
{/* User Documents Section */}
{userDocsList.length > 0 && (
<>
<div className="sticky top-0 z-10 px-3 py-2 text-xs font-bold uppercase tracking-wider bg-muted text-foreground/80 border-b border-border">
Your Documents
</div>
{userDocsList.map((doc) => {
const docKey = `${doc.document_type}:${doc.id}`;
const isAlreadySelected = selectedKeys.has(docKey);
const selectableIndex = selectableDocuments.findIndex(
(d) => d.document_type === doc.document_type && d.id === doc.id
);
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex;
return (
<button
key={docKey}
ref={(el) => {
if (el && selectableIndex >= 0) {
itemRefs.current.set(selectableIndex, el);
}
}}
type="button"
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
onMouseEnter={() => {
if (!isAlreadySelected && selectableIndex >= 0) {
setHighlightedIndex(selectableIndex);
}
}}
disabled={isAlreadySelected}
className={cn(
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
isHighlighted && "bg-accent"
)}
>
<span className="shrink-0 text-muted-foreground text-sm">
{getConnectorIcon(doc.document_type)}
</span>
<span className="flex-1 text-sm truncate" title={doc.title}>
{doc.title}
</span>
</button>
);
})}
</>
)}
return (
<button
key={docKey}
ref={(el) => {
if (el && selectableIndex >= 0) {
itemRefs.current.set(selectableIndex, el);
}
}}
type="button"
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
onMouseEnter={() => {
if (!isAlreadySelected && selectableIndex >= 0) {
setHighlightedIndex(selectableIndex);
}
}}
disabled={isAlreadySelected}
className={cn(
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
isHighlighted && "bg-accent"
)}
>
{/* Type icon */}
<span className="flex-shrink-0 text-muted-foreground text-sm">
{getConnectorIcon(doc.document_type)}
</span>
{/* Title */}
<span className="flex-1 text-sm truncate" title={doc.title}>
{doc.title}
</span>
</button>
);
})}
{/* Loading indicator for additional pages */} {/* Loading indicator for additional pages */}
{isLoadingMore && ( {isLoadingMore && (
<div className="flex items-center justify-center py-2"> <div className="flex items-center justify-center py-2">

View file

@ -110,6 +110,11 @@ const FILE_TYPE_CONFIG: Record<string, Record<string, string[]>> = {
const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5"; const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5";
// Upload limits
const MAX_FILES = 10;
const MAX_TOTAL_SIZE_MB = 200;
const MAX_TOTAL_SIZE_BYTES = MAX_TOTAL_SIZE_MB * 1024 * 1024;
export function DocumentUploadTab({ export function DocumentUploadTab({
searchSpaceId, searchSpaceId,
onSuccess, onSuccess,
@ -134,15 +139,40 @@ export function DocumentUploadTab({
[acceptedFileTypes] [acceptedFileTypes]
); );
const onDrop = useCallback((acceptedFiles: File[]) => { const onDrop = useCallback(
setFiles((prev) => [...prev, ...acceptedFiles]); (acceptedFiles: File[]) => {
}, []); setFiles((prev) => {
const newFiles = [...prev, ...acceptedFiles];
// Check file count limit
if (newFiles.length > MAX_FILES) {
toast.error(t("max_files_exceeded"), {
description: t("max_files_exceeded_desc", { max: MAX_FILES }),
});
return prev;
}
// Check total size limit
const newTotalSize = newFiles.reduce((sum, file) => sum + file.size, 0);
if (newTotalSize > MAX_TOTAL_SIZE_BYTES) {
toast.error(t("max_size_exceeded"), {
description: t("max_size_exceeded_desc", { max: MAX_TOTAL_SIZE_MB }),
});
return prev;
}
return newFiles;
});
},
[t]
);
const { getRootProps, getInputProps, isDragActive } = useDropzone({ const { getRootProps, getInputProps, isDragActive } = useDropzone({
onDrop, onDrop,
accept: acceptedFileTypes, accept: acceptedFileTypes,
maxSize: 50 * 1024 * 1024, maxSize: 50 * 1024 * 1024, // 50MB per file
noClick: false, noClick: false,
disabled: files.length >= MAX_FILES,
}); });
// Handle file input click to prevent event bubbling that might reopen dialog // Handle file input click to prevent event bubbling that might reopen dialog
@ -160,6 +190,15 @@ export function DocumentUploadTab({
const totalFileSize = files.reduce((total, file) => total + file.size, 0); const totalFileSize = files.reduce((total, file) => total + file.size, 0);
// Check if limits are reached
const isFileCountLimitReached = files.length >= MAX_FILES;
const isSizeLimitReached = totalFileSize >= MAX_TOTAL_SIZE_BYTES;
const remainingFiles = MAX_FILES - files.length;
const remainingSizeMB = Math.max(
0,
(MAX_TOTAL_SIZE_BYTES - totalFileSize) / (1024 * 1024)
).toFixed(1);
// Track accordion state changes // Track accordion state changes
const handleAccordionChange = useCallback( const handleAccordionChange = useCallback(
(value: string) => { (value: string) => {
@ -210,7 +249,8 @@ export function DocumentUploadTab({
<Alert className="border border-border bg-slate-400/5 dark:bg-white/5 flex items-start gap-3 [&>svg]:relative [&>svg]:left-0 [&>svg]:top-0 [&>svg~*]:pl-0"> <Alert className="border border-border bg-slate-400/5 dark:bg-white/5 flex items-start gap-3 [&>svg]:relative [&>svg]:left-0 [&>svg]:top-0 [&>svg~*]:pl-0">
<Info className="h-4 w-4 shrink-0 mt-0.5" /> <Info className="h-4 w-4 shrink-0 mt-0.5" />
<AlertDescription className="text-xs sm:text-sm leading-relaxed pt-0.5"> <AlertDescription className="text-xs sm:text-sm leading-relaxed pt-0.5">
{t("file_size_limit")} {t("file_size_limit")}{" "}
{t("upload_limits", { maxFiles: MAX_FILES, maxSizeMB: MAX_TOTAL_SIZE_MB })}
</AlertDescription> </AlertDescription>
</Alert> </Alert>
@ -221,7 +261,11 @@ export function DocumentUploadTab({
<CardContent className="p-4 sm:p-10 relative z-10"> <CardContent className="p-4 sm:p-10 relative z-10">
<div <div
{...getRootProps()} {...getRootProps()}
className="flex flex-col items-center justify-center min-h-[200px] sm:min-h-[300px] border-2 border-dashed border-border rounded-lg hover:border-primary/50 transition-colors cursor-pointer" className={`flex flex-col items-center justify-center min-h-[200px] sm:min-h-[300px] border-2 border-dashed rounded-lg transition-colors ${
isFileCountLimitReached || isSizeLimitReached
? "border-destructive/50 bg-destructive/5 cursor-not-allowed"
: "border-border hover:border-primary/50 cursor-pointer"
}`}
> >
<input <input
{...getInputProps()} {...getInputProps()}
@ -229,7 +273,19 @@ export function DocumentUploadTab({
className="hidden" className="hidden"
onClick={handleFileInputClick} onClick={handleFileInputClick}
/> />
{isDragActive ? ( {isFileCountLimitReached ? (
<div className="flex flex-col items-center gap-2 sm:gap-4 text-center px-4">
<Upload className="h-8 w-8 sm:h-12 sm:w-12 text-destructive/70" />
<div>
<p className="text-sm sm:text-lg font-medium text-destructive">
{t("file_limit_reached")}
</p>
<p className="text-xs sm:text-sm text-muted-foreground mt-1">
{t("file_limit_reached_desc", { max: MAX_FILES })}
</p>
</div>
</div>
) : isDragActive ? (
<motion.div <motion.div
initial={{ opacity: 0, scale: 0.8 }} initial={{ opacity: 0, scale: 0.8 }}
animate={{ opacity: 1, scale: 1 }} animate={{ opacity: 1, scale: 1 }}
@ -245,22 +301,29 @@ export function DocumentUploadTab({
<p className="text-sm sm:text-lg font-medium">{t("drag_drop")}</p> <p className="text-sm sm:text-lg font-medium">{t("drag_drop")}</p>
<p className="text-xs sm:text-sm text-muted-foreground mt-1">{t("or_browse")}</p> <p className="text-xs sm:text-sm text-muted-foreground mt-1">{t("or_browse")}</p>
</div> </div>
{files.length > 0 && (
<p className="text-xs text-muted-foreground">
{t("remaining_capacity", { files: remainingFiles, sizeMB: remainingSizeMB })}
</p>
)}
</div>
)}
{!isFileCountLimitReached && (
<div className="mt-2 sm:mt-4">
<Button
variant="outline"
size="sm"
className="text-xs sm:text-sm"
onClick={(e) => {
e.stopPropagation();
e.preventDefault();
fileInputRef.current?.click();
}}
>
{t("browse_files")}
</Button>
</div> </div>
)} )}
<div className="mt-2 sm:mt-4">
<Button
variant="outline"
size="sm"
className="text-xs sm:text-sm"
onClick={(e) => {
e.stopPropagation();
e.preventDefault();
fileInputRef.current?.click();
}}
>
{t("browse_files")}
</Button>
</div>
</div> </div>
</CardContent> </CardContent>
</Card> </Card>

View file

@ -8,6 +8,8 @@ export const user = z.object({
is_verified: z.boolean(), is_verified: z.boolean(),
pages_limit: z.number(), pages_limit: z.number(),
pages_used: z.number(), pages_used: z.number(),
display_name: z.string().nullish(),
avatar_url: z.string().nullish(),
}); });
/** /**
@ -15,5 +17,20 @@ export const user = z.object({
*/ */
export const getMeResponse = user; export const getMeResponse = user;
/**
* Update current user request
*/
export const updateUserRequest = z.object({
display_name: z.string().nullish(),
avatar_url: z.string().nullish(),
});
/**
* Update current user response
*/
export const updateUserResponse = user;
export type User = z.infer<typeof user>; export type User = z.infer<typeof user>;
export type GetMeResponse = z.infer<typeof getMeResponse>; export type GetMeResponse = z.infer<typeof getMeResponse>;
export type UpdateUserRequest = z.infer<typeof updateUserRequest>;
export type UpdateUserResponse = z.infer<typeof updateUserResponse>;

View file

@ -1,4 +1,8 @@
import { getMeResponse } from "@/contracts/types/user.types"; import {
getMeResponse,
type UpdateUserRequest,
updateUserResponse,
} from "@/contracts/types/user.types";
import { baseApiService } from "./base-api.service"; import { baseApiService } from "./base-api.service";
class UserApiService { class UserApiService {
@ -8,6 +12,15 @@ class UserApiService {
getMe = async () => { getMe = async () => {
return baseApiService.get(`/users/me`, getMeResponse); return baseApiService.get(`/users/me`, getMeResponse);
}; };
/**
* Update current authenticated user
*/
updateMe = async (request: UpdateUserRequest) => {
return baseApiService.patch(`/users/me`, updateUserResponse, {
body: request,
});
};
} }
export const userApiService = new UserApiService(); export const userApiService = new UserApiService();

View file

@ -31,6 +31,9 @@ export interface MessageRecord {
role: "user" | "assistant" | "system"; role: "user" | "assistant" | "system";
content: unknown; content: unknown;
created_at: string; created_at: string;
author_id?: string | null;
author_display_name?: string | null;
author_avatar_url?: string | null;
} }
export interface ThreadListResponse { export interface ThreadListResponse {

View file

@ -109,6 +109,17 @@
"title": "User Settings", "title": "User Settings",
"description": "Manage your account settings and API access", "description": "Manage your account settings and API access",
"back_to_app": "Back to app", "back_to_app": "Back to app",
"profile_nav_label": "Profile",
"profile_nav_description": "Manage your display name and avatar",
"profile_title": "Profile",
"profile_description": "Update your personal information",
"profile_avatar": "Profile Picture",
"profile_display_name": "Display Name",
"profile_display_name_hint": "This is how your name appears across the app",
"profile_email": "Email",
"profile_save": "Save Changes",
"profile_saved": "Profile updated successfully",
"profile_save_error": "Failed to update profile",
"api_key_nav_label": "API Key", "api_key_nav_label": "API Key",
"api_key_nav_description": "Manage your API access token", "api_key_nav_description": "Manage your API access token",
"api_key_title": "API Key", "api_key_title": "API Key",
@ -367,6 +378,7 @@
"title": "Upload Documents", "title": "Upload Documents",
"subtitle": "Upload your files to make them searchable and accessible through AI-powered conversations.", "subtitle": "Upload your files to make them searchable and accessible through AI-powered conversations.",
"file_size_limit": "Maximum file size: 50MB per file. Supported formats vary based on your ETL service configuration.", "file_size_limit": "Maximum file size: 50MB per file. Supported formats vary based on your ETL service configuration.",
"upload_limits": "Upload limit: {maxFiles} files, {maxSizeMB}MB total.",
"drop_files": "Drop files here", "drop_files": "Drop files here",
"drag_drop": "Drag & drop files here", "drag_drop": "Drag & drop files here",
"or_browse": "or click to browse", "or_browse": "or click to browse",
@ -382,7 +394,14 @@
"upload_error": "Upload Error", "upload_error": "Upload Error",
"upload_error_desc": "Error uploading files", "upload_error_desc": "Error uploading files",
"supported_file_types": "Supported File Types", "supported_file_types": "Supported File Types",
"file_types_desc": "These file types are supported based on your current ETL service configuration." "file_types_desc": "These file types are supported based on your current ETL service configuration.",
"max_files_exceeded": "File Limit Exceeded",
"max_files_exceeded_desc": "You can upload a maximum of {max} files at a time.",
"max_size_exceeded": "Size Limit Exceeded",
"max_size_exceeded_desc": "Total file size cannot exceed {max}MB.",
"file_limit_reached": "Maximum Files Reached",
"file_limit_reached_desc": "Remove some files to add more (max {max} files).",
"remaining_capacity": "{files} files remaining • {sizeMB}MB available"
}, },
"add_webpage": { "add_webpage": {
"title": "Add Webpages for Crawling", "title": "Add Webpages for Crawling",

View file

@ -363,6 +363,7 @@
"title": "上传文档", "title": "上传文档",
"subtitle": "上传您的文件,使其可通过 AI 对话进行搜索和访问。", "subtitle": "上传您的文件,使其可通过 AI 对话进行搜索和访问。",
"file_size_limit": "最大文件大小:每个文件 50MB。支持的格式因您的 ETL 服务配置而异。", "file_size_limit": "最大文件大小:每个文件 50MB。支持的格式因您的 ETL 服务配置而异。",
"upload_limits": "上传限制:最多 {maxFiles} 个文件,总大小不超过 {maxSizeMB}MB。",
"drop_files": "放下文件到这里", "drop_files": "放下文件到这里",
"drag_drop": "拖放文件到这里", "drag_drop": "拖放文件到这里",
"or_browse": "或点击浏览", "or_browse": "或点击浏览",
@ -378,7 +379,14 @@
"upload_error": "上传错误", "upload_error": "上传错误",
"upload_error_desc": "上传文件时出错", "upload_error_desc": "上传文件时出错",
"supported_file_types": "支持的文件类型", "supported_file_types": "支持的文件类型",
"file_types_desc": "根据您当前的 ETL 服务配置支持这些文件类型。" "file_types_desc": "根据您当前的 ETL 服务配置支持这些文件类型。",
"max_files_exceeded": "超过文件数量限制",
"max_files_exceeded_desc": "一次最多只能上传 {max} 个文件。",
"max_size_exceeded": "超过文件大小限制",
"max_size_exceeded_desc": "文件总大小不能超过 {max}MB。",
"file_limit_reached": "已达到最大文件数量",
"file_limit_reached_desc": "移除一些文件以添加更多(最多 {max} 个文件)。",
"remaining_capacity": "剩余 {files} 个文件名额 • 可用 {sizeMB}MB"
}, },
"add_webpage": { "add_webpage": {
"title": "添加网页爬取", "title": "添加网页爬取",