mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-05 22:02:39 +02:00
Merge remote-tracking branch 'upstream/dev' into feat/replace-logs
This commit is contained in:
commit
2e0f742000
47 changed files with 2365 additions and 700 deletions
54
surfsense_backend/alembic/versions/0_initial_schema.py
Normal file
54
surfsense_backend/alembic/versions/0_initial_schema.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""Initial schema setup

Revision ID: 0
Revises: None

Creates all tables from SQLAlchemy models. Idempotent - safe to run on existing databases.
"""

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

revision: str = "0"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

# Hybrid-search index DDL (HNSW for embeddings, GIN for full text).
# Every statement is idempotent via IF NOT EXISTS.
# NOTE(review): "chucks_*" looks like a typo for "chunks_*", but the names are
# kept as-is so they match indexes already present on existing databases.
_INDEX_DDL = (
    "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)",
    "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))",
    "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)",
    "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))",
)


def upgrade() -> None:
    """Create every table on the SQLAlchemy metadata plus the search indexes."""
    # Imported lazily so alembic can load this module even when the app
    # package is not importable at revision-collection time.
    from app.db import Base

    bind = op.get_bind()

    # pgvector must exist before any vector columns or HNSW indexes.
    op.execute(sa.text("CREATE EXTENSION IF NOT EXISTS vector"))
    Base.metadata.create_all(bind=bind)

    for statement in _INDEX_DDL:
        op.execute(sa.text(statement))


def downgrade() -> None:
    """No-op: the initial schema is never rolled back automatically."""
|
||||
|
|
@ -6,6 +6,8 @@ Revises: 9
|
|||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
|
@ -18,9 +20,37 @@ depends_on: str | Sequence[str] | None = None
|
|||
CHAT_TYPE_ENUM = "chattype"
|
||||
|
||||
|
||||
def enum_exists(enum_name: str) -> bool:
    """Return True when a Postgres enum type named *enum_name* is present."""
    query = sa.text(
        "SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
    )
    bind = op.get_bind()
    return bind.execute(query, {"enum_name": enum_name}).scalar()
|
||||
|
||||
|
||||
def table_exists(table_name: str) -> bool:
    """Return True when a table named *table_name* is present."""
    query = sa.text(
        "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
    )
    bind = op.get_bind()
    return bind.execute(query, {"table_name": table_name}).scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema - replace ChatType enum values with new QNA/REPORT structure."""
|
||||
|
||||
# Skip if chats table or chattype enum doesn't exist (fresh database)
|
||||
if not table_exists("chats") or not enum_exists(CHAT_TYPE_ENUM):
|
||||
return
|
||||
|
||||
# Old enum name for temporary storage
|
||||
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
|
||||
|
||||
|
|
@ -72,6 +102,10 @@ def upgrade() -> None:
|
|||
def downgrade() -> None:
|
||||
"""Downgrade schema - revert ChatType enum to old GENERAL/DEEP/DEEPER/DEEPEST structure."""
|
||||
|
||||
# Skip if chats table or chattype enum doesn't exist
|
||||
if not table_exists("chats") or not enum_exists(CHAT_TYPE_ENUM):
|
||||
return
|
||||
|
||||
# Old enum name for temporary storage
|
||||
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
|
||||
|
||||
|
|
|
|||
|
|
@ -7,22 +7,36 @@ Revises:
|
|||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# Import pgvector if needed for other types, though not for this ENUM change
|
||||
# import pgvector
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "1"
|
||||
down_revision: str | None = None
|
||||
down_revision: str | None = "0"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def enum_exists(enum_name: str) -> bool:
|
||||
"""Check if an enum type exists in the database."""
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
|
||||
),
|
||||
{"enum_name": enum_name},
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# Skip if the enum doesn't exist (fresh DB after downgrade - create_db_and_tables will handle it)
|
||||
if not enum_exists("searchsourceconnectortype"):
|
||||
return
|
||||
|
||||
# Manually add the command to add the enum value
|
||||
# Note: It's generally better to let autogenerate handle this, but we're bypassing it
|
||||
op.execute(
|
||||
|
|
@ -51,6 +65,10 @@ END$$;
|
|||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
# Skip if the enum doesn't exist
|
||||
if not enum_exists("searchsourceconnectortype"):
|
||||
return
|
||||
|
||||
# Downgrading removal of an enum value is complex and potentially dangerous
|
||||
# if the value is in use. Often omitted or requires manual SQL based on context.
|
||||
# For now, we'll just pass. If you needed to reverse this, you'd likely
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ Revises: 23
|
|||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
|
@ -16,11 +18,27 @@ branch_labels: str | Sequence[str] | None = None
|
|||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def table_exists(table_name: str) -> bool:
|
||||
"""Check if a table exists in the database."""
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
|
||||
),
|
||||
{"table_name": table_name},
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Fix any chats with NULL type values by setting them to QNA.
|
||||
This handles edge cases from previous migrations where type values were not properly migrated.
|
||||
"""
|
||||
# Skip if chats table doesn't exist (fresh database)
|
||||
if not table_exists("chats"):
|
||||
return
|
||||
|
||||
# Update any NULL type values to QNA (the default chat type)
|
||||
op.execute(
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ Revises: 33
|
|||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers
|
||||
|
|
@ -19,42 +21,59 @@ branch_labels: str | Sequence[str] | None = None
|
|||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def table_exists(table_name: str) -> bool:
|
||||
"""Check if a table exists in the database."""
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
|
||||
),
|
||||
{"table_name": table_name},
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add columns only if they don't already exist (safe for re-runs)."""
|
||||
|
||||
# Add 'state_version' column to chats table (default 1)
|
||||
op.execute("""
|
||||
ALTER TABLE chats
|
||||
ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL
|
||||
""")
|
||||
# Skip if chats table doesn't exist (fresh database)
|
||||
if table_exists("chats"):
|
||||
op.execute("""
|
||||
ALTER TABLE chats
|
||||
ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL
|
||||
""")
|
||||
|
||||
# Add 'chat_state_version' column to podcasts table
|
||||
op.execute("""
|
||||
ALTER TABLE podcasts
|
||||
ADD COLUMN IF NOT EXISTS chat_state_version BIGINT
|
||||
""")
|
||||
if table_exists("podcasts"):
|
||||
op.execute("""
|
||||
ALTER TABLE podcasts
|
||||
ADD COLUMN IF NOT EXISTS chat_state_version BIGINT
|
||||
""")
|
||||
|
||||
# Add 'chat_id' column to podcasts table
|
||||
op.execute("""
|
||||
ALTER TABLE podcasts
|
||||
ADD COLUMN IF NOT EXISTS chat_id INTEGER
|
||||
""")
|
||||
# Add 'chat_id' column to podcasts table
|
||||
op.execute("""
|
||||
ALTER TABLE podcasts
|
||||
ADD COLUMN IF NOT EXISTS chat_id INTEGER
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove columns only if they exist."""
|
||||
|
||||
op.execute("""
|
||||
ALTER TABLE podcasts
|
||||
DROP COLUMN IF EXISTS chat_state_version
|
||||
""")
|
||||
if table_exists("podcasts"):
|
||||
op.execute("""
|
||||
ALTER TABLE podcasts
|
||||
DROP COLUMN IF EXISTS chat_state_version
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
ALTER TABLE podcasts
|
||||
DROP COLUMN IF EXISTS chat_id
|
||||
""")
|
||||
op.execute("""
|
||||
ALTER TABLE podcasts
|
||||
DROP COLUMN IF EXISTS chat_id
|
||||
""")
|
||||
|
||||
op.execute("""
|
||||
ALTER TABLE chats
|
||||
DROP COLUMN IF EXISTS state_version
|
||||
""")
|
||||
if table_exists("chats"):
|
||||
op.execute("""
|
||||
ALTER TABLE chats
|
||||
DROP COLUMN IF EXISTS state_version
|
||||
""")
|
||||
|
|
|
|||
|
|
@ -62,8 +62,25 @@ def parse_timestamp(ts, fallback):
|
|||
return fallback
|
||||
|
||||
|
||||
def table_exists(table_name: str) -> bool:
|
||||
"""Check if a table exists in the database."""
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
|
||||
),
|
||||
{"table_name": table_name},
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Migrate old chats to new_chat_threads and remove old tables."""
|
||||
# Skip if chats table doesn't exist (fresh database)
|
||||
if not table_exists("chats"):
|
||||
print("[Migration 49] Chats table does not exist, skipping migration")
|
||||
return
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
# Get all old chats
|
||||
|
|
@ -176,36 +193,49 @@ def upgrade() -> None:
|
|||
print("[Migration 49] Migration complete!")
|
||||
|
||||
|
||||
def enum_exists(enum_name: str) -> bool:
|
||||
"""Check if an enum type exists in the database."""
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
|
||||
),
|
||||
{"enum_name": enum_name},
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Recreate old chats table (data cannot be restored)."""
|
||||
# Recreate chattype enum
|
||||
# Skip if chats table already exists
|
||||
if table_exists("chats"):
|
||||
print("[Migration 49 Downgrade] Chats table already exists, skipping")
|
||||
return
|
||||
|
||||
# Recreate chattype enum if it doesn't exist
|
||||
if not enum_exists("chattype"):
|
||||
op.execute(
|
||||
sa.text("""
|
||||
CREATE TYPE chattype AS ENUM ('QNA')
|
||||
""")
|
||||
)
|
||||
|
||||
# Recreate chats table using raw SQL to avoid SQLAlchemy trying to create the enum
|
||||
op.execute(
|
||||
sa.text("""
|
||||
CREATE TYPE chattype AS ENUM ('QNA')
|
||||
CREATE TABLE chats (
|
||||
id SERIAL PRIMARY KEY,
|
||||
type chattype NOT NULL,
|
||||
title VARCHAR NOT NULL,
|
||||
initial_connectors VARCHAR[],
|
||||
messages JSON NOT NULL,
|
||||
state_version BIGINT NOT NULL DEFAULT 1,
|
||||
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
# Recreate chats table
|
||||
op.create_table(
|
||||
"chats",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, index=True),
|
||||
sa.Column("type", sa.Enum("QNA", name="chattype"), nullable=False),
|
||||
sa.Column("title", sa.String(), nullable=False, index=True),
|
||||
sa.Column("initial_connectors", sa.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column("messages", sa.JSON(), nullable=False),
|
||||
sa.Column("state_version", sa.BigInteger(), nullable=False, default=1),
|
||||
sa.Column(
|
||||
"search_space_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.execute(sa.text("CREATE INDEX ix_chats_id ON chats (id)"))
|
||||
op.execute(sa.text("CREATE INDEX ix_chats_title ON chats (title)"))
|
||||
|
||||
print("[Migration 49 Downgrade] Chats table recreated (data not restored)")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
"""allow_multiple_connectors_with_unique_names

Revision ID: 5263aa4e7f94
Revises: a1b2c3d4e5f6
Create Date: 2026-01-13 12:23:31.481643

"""
from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = '5263aa4e7f94'
down_revision: str | None = 'a1b2c3d4e5f6'
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

# Names and column lists shared by upgrade() and downgrade().
_TABLE = 'search_source_connectors'
_OLD_CONSTRAINT = 'uq_searchspace_user_connector_type'
_OLD_COLUMNS = ['search_space_id', 'user_id', 'connector_type']
_NEW_CONSTRAINT = 'uq_searchspace_user_connector_type_name'
_NEW_COLUMNS = [*_OLD_COLUMNS, 'name']


def upgrade() -> None:
    """Upgrade schema: allow several connectors of one type, unique by name."""
    # Swap type-level uniqueness for (type, name) uniqueness.
    op.drop_constraint(_OLD_CONSTRAINT, _TABLE, type_='unique')
    op.create_unique_constraint(_NEW_CONSTRAINT, _TABLE, _NEW_COLUMNS)


def downgrade() -> None:
    """Downgrade schema: restore one-connector-per-type uniqueness."""
    op.drop_constraint(_NEW_CONSTRAINT, _TABLE, type_='unique')
    op.create_unique_constraint(_OLD_CONSTRAINT, _TABLE, _OLD_COLUMNS)
|
||||
|
|
@ -39,7 +39,7 @@ def upgrade():
|
|||
"""
|
||||
)
|
||||
|
||||
# Rename columns (only if they exist with old names)
|
||||
# Rename columns (only if source exists and target doesn't already exist)
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
|
|
@ -47,6 +47,9 @@ def upgrade():
|
|||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id'
|
||||
) THEN
|
||||
ALTER TABLE searchspaces RENAME COLUMN fast_llm_id TO agent_llm_id;
|
||||
END IF;
|
||||
|
|
@ -61,6 +64,9 @@ def upgrade():
|
|||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id'
|
||||
) THEN
|
||||
ALTER TABLE searchspaces RENAME COLUMN long_context_llm_id TO document_summary_llm_id;
|
||||
END IF;
|
||||
|
|
@ -100,7 +106,7 @@ def downgrade():
|
|||
"""
|
||||
)
|
||||
|
||||
# Rename columns back
|
||||
# Rename columns back (only if source exists and target doesn't already exist)
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
|
|
@ -108,6 +114,9 @@ def downgrade():
|
|||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id'
|
||||
) THEN
|
||||
ALTER TABLE searchspaces RENAME COLUMN agent_llm_id TO fast_llm_id;
|
||||
END IF;
|
||||
|
|
@ -122,6 +131,9 @@ def downgrade():
|
|||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id'
|
||||
) THEN
|
||||
ALTER TABLE searchspaces RENAME COLUMN document_summary_llm_id TO long_context_llm_id;
|
||||
END IF;
|
||||
|
|
|
|||
|
|
@ -60,14 +60,28 @@ def downgrade() -> None:
|
|||
|
||||
connection = op.get_bind()
|
||||
|
||||
connection.execute(
|
||||
# Only update if the target enum value exists (it won't on fresh databases)
|
||||
result = connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE documents
|
||||
SET document_type = 'GOOGLE_DRIVE_CONNECTOR'
|
||||
WHERE document_type = 'GOOGLE_DRIVE_FILE';
|
||||
SELECT EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||
WHERE t.typname = 'documenttype' AND e.enumlabel = 'GOOGLE_DRIVE_CONNECTOR'
|
||||
);
|
||||
"""
|
||||
)
|
||||
)
|
||||
enum_exists = result.scalar()
|
||||
|
||||
connection.commit()
|
||||
if enum_exists:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE documents
|
||||
SET document_type = 'GOOGLE_DRIVE_CONNECTOR'
|
||||
WHERE document_type = 'GOOGLE_DRIVE_FILE';
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.commit()
|
||||
|
|
|
|||
|
|
@ -18,59 +18,77 @@ branch_labels: str | Sequence[str] | None = None
|
|||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Alter Chat table
|
||||
op.alter_column(
|
||||
"chats",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
def table_exists(table_name: str) -> bool:
|
||||
"""Check if a table exists in the database."""
|
||||
conn = op.get_bind()
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
|
||||
),
|
||||
{"table_name": table_name},
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Alter Chat table (may not exist on fresh databases, removed in migration 49)
|
||||
if table_exists("chats"):
|
||||
op.alter_column(
|
||||
"chats",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Alter Document table
|
||||
op.alter_column(
|
||||
"documents",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
if table_exists("documents"):
|
||||
op.alter_column(
|
||||
"documents",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Alter Podcast table
|
||||
op.alter_column(
|
||||
"podcasts",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
if table_exists("podcasts"):
|
||||
op.alter_column(
|
||||
"podcasts",
|
||||
"title",
|
||||
existing_type=sa.String(200),
|
||||
type_=sa.String(),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert Chat table
|
||||
op.alter_column(
|
||||
"chats",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
if table_exists("chats"):
|
||||
op.alter_column(
|
||||
"chats",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Revert Document table
|
||||
op.alter_column(
|
||||
"documents",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
if table_exists("documents"):
|
||||
op.alter_column(
|
||||
"documents",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# Revert Podcast table
|
||||
op.alter_column(
|
||||
"podcasts",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
if table_exists("podcasts"):
|
||||
op.alter_column(
|
||||
"podcasts",
|
||||
"title",
|
||||
existing_type=sa.String(),
|
||||
type_=sa.String(200),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,72 @@
|
|||
"""Add display_name and avatar_url columns to user table

This migration adds:
- display_name column for user's full name from OAuth
- avatar_url column for user's profile picture URL from OAuth

Revision ID: 62
Revises: 61
"""

from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "62"
down_revision: str | None = "61"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Add display_name and avatar_url columns to the user table.

    Uses native ``ADD COLUMN IF NOT EXISTS`` (the pattern used by the other
    migrations in this repo) instead of a DO $$ / information_schema probe,
    so the migration stays idempotent with far less SQL.
    """
    # "user" must stay quoted: it is a reserved word in PostgreSQL.
    # Both columns are nullable so existing rows need no backfill.
    op.execute('ALTER TABLE "user" ADD COLUMN IF NOT EXISTS display_name VARCHAR')
    op.execute('ALTER TABLE "user" ADD COLUMN IF NOT EXISTS avatar_url VARCHAR')


def downgrade() -> None:
    """Remove display_name and avatar_url columns from the user table."""
    op.execute('ALTER TABLE "user" DROP COLUMN IF EXISTS avatar_url')
    op.execute('ALTER TABLE "user" DROP COLUMN IF EXISTS display_name')
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
"""Add author_id column to new_chat_messages table

Revision ID: 63
Revises: 62
"""

from collections.abc import Sequence

from alembic import op

revision: str = "63"
down_revision: str | None = "62"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Add author_id column (FK to "user") and its index to new_chat_messages.

    Both statements are idempotent. The index creation is deliberately OUTSIDE
    the column-existence guard: the previous DO-block skipped the index when
    the column already existed, so a database with the column but a missing
    index never got it on re-run. CREATE INDEX IF NOT EXISTS is safe either way.
    """
    op.execute(
        """
        ALTER TABLE new_chat_messages
        ADD COLUMN IF NOT EXISTS author_id UUID REFERENCES "user"(id) ON DELETE SET NULL;

        CREATE INDEX IF NOT EXISTS ix_new_chat_messages_author_id
        ON new_chat_messages(author_id);
        """
    )


def downgrade() -> None:
    """Remove author_id column and its index from new_chat_messages table."""
    # Drop the index first, then the column; both are IF EXISTS so a partial
    # prior downgrade cannot make this fail.
    op.execute(
        """
        DROP INDEX IF EXISTS ix_new_chat_messages_author_id;
        ALTER TABLE new_chat_messages
        DROP COLUMN IF EXISTS author_id;
        """
    )
|
||||
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
"""Add MCP connector type

Revision ID: a1b2c3d4e5f6
Revises: 61
Create Date: 2026-01-09 15:19:51.827647

"""
from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = 'a1b2c3d4e5f6'
down_revision: str | None = '61'
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Extend the searchsourceconnectortype enum with MCP_CONNECTOR."""
    # IF NOT EXISTS makes this safe to re-run on a database that already
    # carries the value.
    add_value_sql = """
        ALTER TYPE searchsourceconnectortype ADD VALUE IF NOT EXISTS 'MCP_CONNECTOR';
        """
    op.execute(add_value_sql)


def downgrade() -> None:
    """Intentionally a no-op.

    PostgreSQL cannot drop an enum value in place; removing MCP_CONNECTOR
    would require rebuilding the type (create a replacement enum without the
    value, retarget the column, drop the old type). Left as a manual
    operation if ever needed.
    """
    pass
|
||||
|
|
@ -20,7 +20,7 @@ from app.agents.new_chat.system_prompt import (
|
|||
build_configurable_system_prompt,
|
||||
build_surfsense_system_prompt,
|
||||
)
|
||||
from app.agents.new_chat.tools import build_tools
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.services.connector_service import ConnectorService
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -28,7 +28,7 @@ from app.services.connector_service import ConnectorService
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def create_surfsense_deep_agent(
|
||||
async def create_surfsense_deep_agent(
|
||||
llm: ChatLiteLLM,
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
|
|
@ -120,8 +120,8 @@ def create_surfsense_deep_agent(
|
|||
"firecrawl_api_key": firecrawl_api_key,
|
||||
}
|
||||
|
||||
# Build tools using the registry
|
||||
tools = build_tools(
|
||||
# Build tools using the async registry (includes MCP tools)
|
||||
tools = await build_tools_async(
|
||||
dependencies=dependencies,
|
||||
enabled_tools=enabled_tools,
|
||||
disabled_tools=disabled_tools,
|
||||
|
|
|
|||
188
surfsense_backend/app/agents/new_chat/tools/mcp_client.py
Normal file
188
surfsense_backend/app/agents/new_chat/tools/mcp_client.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
"""MCP Client Wrapper.
|
||||
|
||||
This module provides a client for communicating with MCP servers via stdio transport.
|
||||
It handles server lifecycle management, tool discovery, and tool execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Client for communicating with an MCP server."""
|
||||
|
||||
def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None):
|
||||
"""Initialize MCP client.
|
||||
|
||||
Args:
|
||||
command: Command to spawn the MCP server (e.g., "uvx", "node")
|
||||
args: Arguments for the command (e.g., ["mcp-server-git"])
|
||||
env: Optional environment variables for the server process
|
||||
|
||||
"""
|
||||
self.command = command
|
||||
self.args = args
|
||||
self.env = env or {}
|
||||
self.session: ClientSession | None = None
|
||||
|
||||
@asynccontextmanager
|
||||
async def connect(self):
|
||||
"""Connect to the MCP server and manage its lifecycle.
|
||||
|
||||
Yields:
|
||||
ClientSession: Active MCP session for making requests
|
||||
|
||||
"""
|
||||
try:
|
||||
# Merge env vars with current environment
|
||||
server_env = os.environ.copy()
|
||||
server_env.update(self.env)
|
||||
|
||||
# Create server parameters with env
|
||||
server_params = StdioServerParameters(
|
||||
command=self.command,
|
||||
args=self.args,
|
||||
env=server_env
|
||||
)
|
||||
|
||||
# Spawn server process and create session
|
||||
# Note: Cannot combine these context managers because ClientSession
|
||||
# needs the read/write streams from stdio_client
|
||||
async with stdio_client(server=server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
self.session = session
|
||||
logger.info(
|
||||
"Connected to MCP server: %s %s",
|
||||
self.command,
|
||||
" ".join(self.args),
|
||||
)
|
||||
yield session
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect to MCP server: %s", e, exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
self.session = None
|
||||
logger.info("Disconnected from MCP server: %s", self.command)
|
||||
|
||||
async def list_tools(self) -> list[dict[str, Any]]:
|
||||
"""List all tools available from the MCP server.
|
||||
|
||||
Returns:
|
||||
List of tool definitions with name, description, and input schema
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not connected to server
|
||||
|
||||
"""
|
||||
if not self.session:
|
||||
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
|
||||
|
||||
try:
|
||||
# Call tools/list RPC method
|
||||
response = await self.session.list_tools()
|
||||
|
||||
tools = []
|
||||
for tool in response.tools:
|
||||
tools.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||
})
|
||||
|
||||
logger.info("Listed %d tools from MCP server", len(tools))
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
|
||||
raise
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
||||
"""Call a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to call
|
||||
arguments: Arguments to pass to the tool
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not connected to server
|
||||
|
||||
"""
|
||||
if not self.session:
|
||||
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
|
||||
|
||||
try:
|
||||
logger.info("Calling MCP tool '%s' with arguments: %s", tool_name, arguments)
|
||||
|
||||
# Call tools/call RPC method
|
||||
response = await self.session.call_tool(tool_name, arguments=arguments)
|
||||
|
||||
# Extract content from response
|
||||
result = []
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
result.append(content.text)
|
||||
elif hasattr(content, "data"):
|
||||
result.append(str(content.data))
|
||||
else:
|
||||
result.append(str(content))
|
||||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
|
||||
return result_str
|
||||
|
||||
except RuntimeError as e:
|
||||
# Handle validation errors from MCP server responses
|
||||
# Some MCP servers (like server-memory) return extra fields not in their schema
|
||||
if "Invalid structured content" in str(e):
|
||||
logger.warning("MCP server returned data not matching its schema, but continuing: %s", e)
|
||||
# Try to extract result from error message or return a success message
|
||||
return "Operation completed (server returned unexpected format)"
|
||||
raise
|
||||
except (ValueError, TypeError, AttributeError, KeyError) as e:
|
||||
logger.error("Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True)
|
||||
return f"Error calling tool: {e!s}"
|
||||
|
||||
|
||||
async def test_mcp_connection(
    command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
    """Probe an MCP server: spawn it, enumerate its tools, report the outcome.

    Args:
        command: Command to spawn the MCP server
        args: Arguments for the command
        env: Optional environment variables

    Returns:
        Dict with connection status and available tools

    """
    probe = MCPClient(command, args, env)

    try:
        async with probe.connect():
            discovered = await probe.list_tools()
            # Build the success payload while still connected, mirroring
            # the error payload's shape below.
            return {
                "status": "success",
                "message": f"Connected successfully. Found {len(discovered)} tools.",
                "tools": discovered,
            }
    except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
        return {
            "status": "error",
            "message": f"Failed to connect: {e!s}",
            "tools": [],
        }
|
||||
189
surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
Normal file
189
surfsense_backend/app/agents/new_chat/tools/mcp_tool.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
"""MCP Tool Factory.
|
||||
|
||||
This module creates LangChain tools from MCP servers using the Model Context Protocol.
|
||||
Tools are dynamically discovered from MCP servers - no manual configuration needed.
|
||||
|
||||
This implements real MCP protocol support similar to Cursor's implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, create_model
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_dynamic_input_model_from_schema(
    tool_name: str, input_schema: dict[str, Any],
) -> type[BaseModel]:
    """Create a Pydantic model from MCP tool's JSON schema.

    Args:
        tool_name: Name of the tool (used for model class name)
        input_schema: JSON schema from MCP server

    Returns:
        Pydantic model class for tool input validation

    """
    # Hoisted out of the per-parameter loop: previously this import ran once
    # for every schema property.
    from pydantic import Field

    properties = input_schema.get("properties", {})
    required_fields = input_schema.get("required", [])

    # Build Pydantic field definitions. `Any` is used for all parameter types
    # so complex schemas keep their structure; the MCP server performs its own
    # validation of the actual values.
    field_definitions = {}
    for param_name, param_schema in properties.items():
        param_description = param_schema.get("description", "")
        if param_name in required_fields:
            field_definitions[param_name] = (Any, Field(..., description=param_description))
        else:
            field_definitions[param_name] = (
                Any | None,
                Field(None, description=param_description),
            )

    # Sanitize the tool name into a valid Python class name for the model.
    model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
    return create_model(model_name, **field_definitions)
|
||||
|
||||
|
||||
async def _create_mcp_tool_from_definition(
    tool_def: dict[str, Any],
    mcp_client: MCPClient,
) -> StructuredTool:
    """Build a LangChain StructuredTool from one MCP tool definition.

    Args:
        tool_def: Tool definition from MCP server with name, description, input_schema
        mcp_client: MCP client instance for calling the tool

    Returns:
        LangChain StructuredTool instance

    """
    tool_name = tool_def.get("name", "unnamed_tool")
    tool_description = tool_def.get("description", "No description provided")
    input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})

    # Log the actual schema for debugging
    logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")

    # Derive a Pydantic args model from the server-provided JSON schema.
    input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)

    async def mcp_tool_call(**kwargs) -> str:
        """Execute the MCP tool call via the client."""
        logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
        try:
            # Each invocation opens a fresh connection to the server.
            async with mcp_client.connect():
                return str(await mcp_client.call_tool(tool_name, kwargs))
        except Exception as e:
            error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
            logger.exception(error_msg)
            return f"Error: {error_msg}"

    structured = StructuredTool(
        name=tool_name,
        description=tool_description,
        coroutine=mcp_tool_call,
        args_schema=input_model,
        # Store the original MCP schema as metadata so we can access it later
        metadata={"mcp_input_schema": input_schema},
    )

    logger.info(f"Created MCP tool: '{tool_name}'")
    return structured
|
||||
|
||||
|
||||
async def load_mcp_tools(
    session: AsyncSession, search_space_id: int,
) -> list[StructuredTool]:
    """Load all MCP tools from user's active MCP server connectors.

    Tools are discovered dynamically from each configured MCP server using
    the protocol itself - no manual tool configuration is needed.

    Args:
        session: Database session
        search_space_id: User's search space ID

    Returns:
        List of LangChain StructuredTool instances

    """
    loaded: list[StructuredTool] = []
    try:
        # All MCP connectors registered for this search space.
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.connector_type
                == SearchSourceConnectorType.MCP_CONNECTOR,
                SearchSourceConnector.search_space_id == search_space_id,
            ),
        )

        for connector in result.scalars():
            try:
                # Pull the spawn configuration out of the connector record.
                server_config = (connector.config or {}).get("server_config", {})
                command = server_config.get("command")
                args = server_config.get("args", [])
                env = server_config.get("env", {})

                if not command:
                    logger.warning(f"MCP connector {connector.id} missing command, skipping")
                    continue

                client = MCPClient(command, args, env)

                # Connect once per connector to enumerate its tools.
                async with client.connect():
                    tool_definitions = await client.list_tools()

                logger.info(
                    f"Discovered {len(tool_definitions)} tools from MCP server "
                    f"'{command}' (connector {connector.id})"
                )

                # Wrap every discovered definition as a LangChain tool;
                # a bad definition skips just that tool, not the connector.
                for tool_def in tool_definitions:
                    try:
                        loaded.append(
                            await _create_mcp_tool_from_definition(tool_def, client)
                        )
                    except Exception as e:
                        logger.exception(
                            f"Failed to create tool '{tool_def.get('name')}' "
                            f"from connector {connector.id}: {e!s}",
                        )

            except Exception as e:
                # One broken connector must not prevent loading the others.
                logger.exception(
                    f"Failed to load tools from MCP connector {connector.id}: {e!s}",
                )

        logger.info(f"Loaded {len(loaded)} MCP tools for search space {search_space_id}")
        return loaded

    except Exception as e:
        # Best-effort: MCP tools are optional, so fall back to none on failure.
        logger.exception(f"Failed to load MCP tools: {e!s}")
        return []
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
"""
|
||||
Tools registry for SurfSense deep agent.
|
||||
"""Tools registry for SurfSense deep agent.
|
||||
|
||||
This module provides a registry pattern for managing tools in the SurfSense agent.
|
||||
It makes it easy for OSS contributors to add new tools by:
|
||||
|
|
@ -37,6 +36,7 @@ Example of adding a new tool:
|
|||
),
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
|
@ -46,6 +46,7 @@ from langchain_core.tools import BaseTool
|
|||
from .display_image import create_display_image_tool
|
||||
from .knowledge_base import create_search_knowledge_base_tool
|
||||
from .link_preview import create_link_preview_tool
|
||||
from .mcp_tool import load_mcp_tools
|
||||
from .podcast import create_generate_podcast_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
|
|
@ -57,8 +58,7 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
|
|||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
"""
|
||||
Definition of a tool that can be added to the agent.
|
||||
"""Definition of a tool that can be added to the agent.
|
||||
|
||||
Attributes:
|
||||
name: Unique identifier for the tool
|
||||
|
|
@ -66,6 +66,7 @@ class ToolDefinition:
|
|||
factory: Callable that creates the tool. Receives a dict of dependencies.
|
||||
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
|
||||
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
|
@ -178,8 +179,7 @@ def build_tools(
|
|||
disabled_tools: list[str] | None = None,
|
||||
additional_tools: list[BaseTool] | None = None,
|
||||
) -> list[BaseTool]:
|
||||
"""
|
||||
Build the list of tools for the agent.
|
||||
"""Build the list of tools for the agent.
|
||||
|
||||
Args:
|
||||
dependencies: Dict containing all possible dependencies:
|
||||
|
|
@ -206,6 +206,7 @@ def build_tools(
|
|||
|
||||
# Add custom tools
|
||||
tools = build_tools(deps, additional_tools=[my_custom_tool])
|
||||
|
||||
"""
|
||||
# Determine which tools to enable
|
||||
if enabled_tools is not None:
|
||||
|
|
@ -226,8 +227,9 @@ def build_tools(
|
|||
# Check that all required dependencies are provided
|
||||
missing_deps = [dep for dep in tool_def.requires if dep not in dependencies]
|
||||
if missing_deps:
|
||||
msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
|
||||
raise ValueError(
|
||||
f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
|
||||
msg,
|
||||
)
|
||||
|
||||
# Create the tool
|
||||
|
|
@ -239,3 +241,61 @@ def build_tools(
|
|||
tools.extend(additional_tools)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
async def build_tools_async(
    dependencies: dict[str, Any],
    enabled_tools: list[str] | None = None,
    disabled_tools: list[str] | None = None,
    additional_tools: list[BaseTool] | None = None,
    include_mcp_tools: bool = True,
) -> list[BaseTool]:
    """Async version of build_tools that also loads MCP tools from database.

    Design Note:
        This function exists because MCP tools require database queries to load user configs,
        while built-in tools are created synchronously from static code.

        Alternative: We could make build_tools() itself async and always query the database,
        but that would force async everywhere even when only using built-in tools. The current
        design keeps the simple case (static tools only) synchronous while supporting dynamic
        database-loaded tools through this async wrapper.

    Args:
        dependencies: Dict containing all possible dependencies
        enabled_tools: Explicit list of tool names to enable. If None, uses defaults.
        disabled_tools: List of tool names to disable (applied after enabled_tools).
        additional_tools: Extra tools to add (e.g., custom tools not in registry).
        include_mcp_tools: Whether to load user's MCP tools from database.

    Returns:
        List of configured tool instances ready for the agent, including MCP tools.

    """
    # Build standard tools
    tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)

    # Load MCP tools if requested and dependencies are available.
    # Both "db_session" and "search_space_id" are needed to query the
    # user's MCP connector configuration.
    if (
        include_mcp_tools
        and "db_session" in dependencies
        and "search_space_id" in dependencies
    ):
        try:
            mcp_tools = await load_mcp_tools(
                dependencies["db_session"], dependencies["search_space_id"],
            )
            tools.extend(mcp_tools)
            logging.info(
                f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
            )
        except Exception as e:
            # Log error but don't fail - just continue without MCP tools
            logging.exception(f"Failed to load MCP tools: {e!s}")

    # Log all tools being returned to agent
    logging.info(
        f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
    )

    return tools
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ class SearchSourceConnectorType(str, Enum):
|
|||
WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR"
|
||||
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
|
||||
CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR"
|
||||
MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools
|
||||
|
||||
|
||||
class LiteLLMProvider(str, Enum):
|
||||
|
|
@ -412,8 +413,17 @@ class NewChatMessage(BaseModel, TimestampMixin):
|
|||
index=True,
|
||||
)
|
||||
|
||||
# Relationship
|
||||
# Track who sent this message (for shared chats)
|
||||
author_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
thread = relationship("NewChatThread", back_populates="messages")
|
||||
author = relationship("User")
|
||||
|
||||
|
||||
class Document(BaseModel, TimestampMixin):
|
||||
|
|
@ -611,7 +621,8 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
|
|||
"search_space_id",
|
||||
"user_id",
|
||||
"connector_type",
|
||||
name="uq_searchspace_user_connector_type",
|
||||
"name",
|
||||
name="uq_searchspace_user_connector_type_name",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -919,6 +930,10 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
# User profile from OAuth
|
||||
display_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
||||
else:
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
|
@ -958,6 +973,10 @@ else:
|
|||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
# User profile (can be set manually for non-OAuth users)
|
||||
display_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
|
|
|||
|
|
@ -411,11 +411,9 @@ async def get_thread_messages(
|
|||
Requires CHATS_READ permission.
|
||||
"""
|
||||
try:
|
||||
# Get thread with messages
|
||||
# Get thread first
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
.filter(NewChatThread.id == thread_id)
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
|
|
@ -434,6 +432,15 @@ async def get_thread_messages(
|
|||
# Check thread-level access based on visibility
|
||||
await check_thread_access(session, thread, user)
|
||||
|
||||
# Get messages with their authors loaded
|
||||
messages_result = await session.execute(
|
||||
select(NewChatMessage)
|
||||
.options(selectinload(NewChatMessage.author))
|
||||
.filter(NewChatMessage.thread_id == thread_id)
|
||||
.order_by(NewChatMessage.created_at)
|
||||
)
|
||||
db_messages = messages_result.scalars().all()
|
||||
|
||||
# Return messages in the format expected by assistant-ui
|
||||
messages = [
|
||||
NewChatMessageRead(
|
||||
|
|
@ -442,8 +449,11 @@ async def get_thread_messages(
|
|||
role=msg.role,
|
||||
content=msg.content,
|
||||
created_at=msg.created_at,
|
||||
author_id=msg.author_id,
|
||||
author_display_name=msg.author.display_name if msg.author else None,
|
||||
author_avatar_url=msg.author.avatar_url if msg.author else None,
|
||||
)
|
||||
for msg in thread.messages
|
||||
for msg in db_messages
|
||||
]
|
||||
|
||||
return ThreadHistoryLoadResponse(messages=messages)
|
||||
|
|
@ -782,6 +792,7 @@ async def append_message(
|
|||
thread_id=thread_id,
|
||||
role=message_role,
|
||||
content=message.content,
|
||||
author_id=user.id,
|
||||
)
|
||||
session.add(db_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,13 @@ PUT /search-source-connectors/{connector_id} - Update a specific connector
|
|||
DELETE /search-source-connectors/{connector_id} - Delete a specific connector
|
||||
POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space
|
||||
|
||||
MCP (Model Context Protocol) Connector routes:
|
||||
POST /connectors/mcp - Create a new MCP connector with custom API tools
|
||||
GET /connectors/mcp - List all MCP connectors for the current user's search space
|
||||
GET /connectors/mcp/{connector_id} - Get a specific MCP connector with tools config
|
||||
PUT /connectors/mcp/{connector_id} - Update an MCP connector's tools config
|
||||
DELETE /connectors/mcp/{connector_id} - Delete an MCP connector
|
||||
|
||||
Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space.
|
||||
Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space.
|
||||
"""
|
||||
|
|
@ -32,6 +39,9 @@ from app.db import (
|
|||
)
|
||||
from app.schemas import (
|
||||
GoogleDriveIndexRequest,
|
||||
MCPConnectorCreate,
|
||||
MCPConnectorRead,
|
||||
MCPConnectorUpdate,
|
||||
SearchSourceConnectorBase,
|
||||
SearchSourceConnectorCreate,
|
||||
SearchSourceConnectorRead,
|
||||
|
|
@ -128,18 +138,20 @@ async def create_search_source_connector(
|
|||
|
||||
# Check if a connector with the same type already exists for this search space
|
||||
# (for non-OAuth connectors that don't support multiple accounts)
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.connector_type == connector.connector_type,
|
||||
)
|
||||
)
|
||||
existing_connector = result.scalars().first()
|
||||
if existing_connector:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A connector with type {connector.connector_type} already exists in this search space.",
|
||||
# Exception: MCP_CONNECTOR can have multiple instances with different names
|
||||
if connector.connector_type != SearchSourceConnectorType.MCP_CONNECTOR:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.connector_type == connector.connector_type,
|
||||
)
|
||||
)
|
||||
existing_connector = result.scalars().first()
|
||||
if existing_connector:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A connector with type {connector.connector_type} already exists in this search space.",
|
||||
)
|
||||
|
||||
# Prepare connector data
|
||||
connector_data = connector.model_dump()
|
||||
|
|
@ -2054,3 +2066,348 @@ async def run_bookstack_indexing(
|
|||
indexing_function=index_bookstack_pages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MCP Connector Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/connectors/mcp", response_model=MCPConnectorRead, status_code=201)
|
||||
async def create_mcp_connector(
|
||||
connector_data: MCPConnectorCreate,
|
||||
search_space_id: int = Query(..., description="Search space ID"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Create a new MCP (Model Context Protocol) connector.
|
||||
|
||||
MCP connectors allow users to connect to MCP servers (like in Cursor).
|
||||
Tools are auto-discovered from the server - no manual configuration needed.
|
||||
|
||||
Args:
|
||||
connector_data: MCP server configuration (command, args, env)
|
||||
search_space_id: ID of the search space to attach the connector to
|
||||
session: Database session
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
Created MCP connector with server configuration
|
||||
|
||||
Raises:
|
||||
HTTPException: If search space not found or permission denied
|
||||
"""
|
||||
try:
|
||||
# Check user has permission to create connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.CONNECTORS_CREATE.value,
|
||||
"You don't have permission to create connectors in this search space",
|
||||
)
|
||||
|
||||
# Create the connector with server config
|
||||
db_connector = SearchSourceConnector(
|
||||
name=connector_data.name,
|
||||
connector_type=SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
is_indexable=False, # MCP connectors are not indexable
|
||||
config={"server_config": connector_data.server_config.model_dump()},
|
||||
periodic_indexing_enabled=False,
|
||||
indexing_frequency_minutes=None,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user.id,
|
||||
)
|
||||
|
||||
session.add(db_connector)
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
f"Created MCP connector {db_connector.id} for server '{connector_data.server_config.command}' "
|
||||
f"for user {user.id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Convert to read schema
|
||||
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
|
||||
return MCPConnectorRead.from_connector(connector_read)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create MCP connector: {e!s}", exc_info=True)
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create MCP connector: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/connectors/mcp", response_model=list[MCPConnectorRead])
|
||||
async def list_mcp_connectors(
|
||||
search_space_id: int = Query(..., description="Search space ID"),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List all MCP connectors for a search space.
|
||||
|
||||
Args:
|
||||
search_space_id: ID of the search space
|
||||
session: Database session
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
List of MCP connectors with their tool configurations
|
||||
"""
|
||||
try:
|
||||
# Check user has permission to read connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
"You don't have permission to view connectors in this search space",
|
||||
)
|
||||
|
||||
# Fetch MCP connectors
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
|
||||
connectors = result.scalars().all()
|
||||
return [
|
||||
MCPConnectorRead.from_connector(SearchSourceConnectorRead.model_validate(c))
|
||||
for c in connectors
|
||||
]
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list MCP connectors: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to list MCP connectors: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
|
||||
async def get_mcp_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a specific MCP connector by ID.
|
||||
|
||||
Args:
|
||||
connector_id: ID of the connector
|
||||
session: Database session
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
MCP connector with tool configurations
|
||||
"""
|
||||
try:
|
||||
# Fetch connector
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
raise HTTPException(status_code=404, detail="MCP connector not found")
|
||||
|
||||
# Check user has permission to read connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_READ.value,
|
||||
"You don't have permission to view this connector",
|
||||
)
|
||||
|
||||
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||
return MCPConnectorRead.from_connector(connector_read)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get MCP connector: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get MCP connector: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
|
||||
async def update_mcp_connector(
|
||||
connector_id: int,
|
||||
connector_update: MCPConnectorUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update an MCP connector.
|
||||
|
||||
Args:
|
||||
connector_id: ID of the connector to update
|
||||
connector_update: Updated connector data
|
||||
session: Database session
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
Updated MCP connector
|
||||
"""
|
||||
try:
|
||||
# Fetch connector
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
raise HTTPException(status_code=404, detail="MCP connector not found")
|
||||
|
||||
# Check user has permission to update connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_UPDATE.value,
|
||||
"You don't have permission to update this connector",
|
||||
)
|
||||
|
||||
# Update fields
|
||||
if connector_update.name is not None:
|
||||
connector.name = connector_update.name
|
||||
|
||||
if connector_update.server_config is not None:
|
||||
connector.config = {
|
||||
"server_config": connector_update.server_config.model_dump()
|
||||
}
|
||||
|
||||
connector.updated_at = datetime.now(UTC)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info(f"Updated MCP connector {connector_id}")
|
||||
|
||||
connector_read = SearchSourceConnectorRead.model_validate(connector)
|
||||
return MCPConnectorRead.from_connector(connector_read)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update MCP connector: {e!s}", exc_info=True)
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update MCP connector: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/connectors/mcp/{connector_id}", status_code=204)
|
||||
async def delete_mcp_connector(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Delete an MCP connector.
|
||||
|
||||
Args:
|
||||
connector_id: ID of the connector to delete
|
||||
session: Database session
|
||||
user: Current authenticated user
|
||||
"""
|
||||
try:
|
||||
# Fetch connector
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
raise HTTPException(status_code=404, detail="MCP connector not found")
|
||||
|
||||
# Check user has permission to delete connectors
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
connector.search_space_id,
|
||||
Permission.CONNECTORS_DELETE.value,
|
||||
"You don't have permission to delete this connector",
|
||||
)
|
||||
|
||||
await session.delete(connector)
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"Deleted MCP connector {connector_id}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete MCP connector: {e!s}", exc_info=True)
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to delete MCP connector: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/connectors/mcp/test")
|
||||
async def test_mcp_server_connection(
|
||||
server_config: dict = Body(...),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Test connection to an MCP server and fetch available tools.
|
||||
|
||||
This endpoint allows users to test their MCP server configuration
|
||||
before saving it, similar to Cursor's flow.
|
||||
|
||||
Args:
|
||||
server_config: Server configuration with command, args, env
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
Connection status and list of available tools
|
||||
"""
|
||||
try:
|
||||
from app.agents.new_chat.tools.mcp_client import test_mcp_connection
|
||||
|
||||
command = server_config.get("command")
|
||||
args = server_config.get("args", [])
|
||||
env = server_config.get("env", {})
|
||||
|
||||
if not command:
|
||||
raise HTTPException(status_code=400, detail="Server command is required")
|
||||
|
||||
# Test the connection
|
||||
result = await test_mcp_connection(command, args, env)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test MCP connection: {e!s}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to test connection: {e!s}",
|
||||
"tools": [],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,6 +55,10 @@ from .rbac_schemas import (
|
|||
UserSearchSpaceAccess,
|
||||
)
|
||||
from .search_source_connector import (
|
||||
MCPConnectorCreate,
|
||||
MCPConnectorRead,
|
||||
MCPConnectorUpdate,
|
||||
MCPServerConfig,
|
||||
SearchSourceConnectorBase,
|
||||
SearchSourceConnectorCreate,
|
||||
SearchSourceConnectorRead,
|
||||
|
|
@ -108,6 +112,11 @@ __all__ = [
|
|||
"LogFilter",
|
||||
"LogRead",
|
||||
"LogUpdate",
|
||||
# Search source connector schemas
|
||||
"MCPConnectorCreate",
|
||||
"MCPConnectorRead",
|
||||
"MCPConnectorUpdate",
|
||||
"MCPServerConfig",
|
||||
"MembershipRead",
|
||||
"MembershipReadWithUser",
|
||||
"MembershipUpdate",
|
||||
|
|
@ -135,7 +144,6 @@ __all__ = [
|
|||
"RoleCreate",
|
||||
"RoleRead",
|
||||
"RoleUpdate",
|
||||
# Search source connector schemas
|
||||
"SearchSourceConnectorBase",
|
||||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorRead",
|
||||
|
|
|
|||
|
|
@ -38,6 +38,9 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
|
|||
"""Schema for reading a message."""
|
||||
|
||||
thread_id: int
|
||||
author_id: UUID | None = None
|
||||
author_display_name: str | None = None
|
||||
author_avatar_url: str | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class SearchSourceConnectorBase(BaseModel):
|
|||
@field_validator("config")
|
||||
@classmethod
|
||||
def validate_config_for_connector_type(
|
||||
cls, config: dict[str, Any], values: dict[str, Any]
|
||||
cls, config: dict[str, Any], values: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
connector_type = values.data.get("connector_type")
|
||||
return validate_connector_config(connector_type, config)
|
||||
|
|
@ -38,15 +38,18 @@ class SearchSourceConnectorBase(BaseModel):
|
|||
"""
|
||||
if self.periodic_indexing_enabled:
|
||||
if not self.is_indexable:
|
||||
msg = "periodic_indexing_enabled can only be True for indexable connectors"
|
||||
raise ValueError(
|
||||
"periodic_indexing_enabled can only be True for indexable connectors"
|
||||
msg,
|
||||
)
|
||||
if self.indexing_frequency_minutes is None:
|
||||
msg = "indexing_frequency_minutes is required when periodic_indexing_enabled is True"
|
||||
raise ValueError(
|
||||
"indexing_frequency_minutes is required when periodic_indexing_enabled is True"
|
||||
msg,
|
||||
)
|
||||
if self.indexing_frequency_minutes <= 0:
|
||||
raise ValueError("indexing_frequency_minutes must be greater than 0")
|
||||
msg = "indexing_frequency_minutes must be greater than 0"
|
||||
raise ValueError(msg)
|
||||
return self
|
||||
|
||||
|
||||
|
|
@ -70,3 +73,63 @@ class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampMod
|
|||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MCP-specific schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""Configuration for an MCP server connection (similar to Cursor's config)."""
|
||||
|
||||
command: str # e.g., "uvx", "node", "python"
|
||||
args: list[str] = [] # e.g., ["mcp-server-git", "--repository", "/path"]
|
||||
env: dict[str, str] = {} # Environment variables for the server process
|
||||
transport: str = "stdio" # "stdio" | "sse" | "http" (stdio is most common)
|
||||
|
||||
|
||||
class MCPConnectorCreate(BaseModel):
|
||||
"""Schema for creating an MCP connector."""
|
||||
|
||||
name: str
|
||||
server_config: MCPServerConfig
|
||||
|
||||
|
||||
class MCPConnectorUpdate(BaseModel):
|
||||
"""Schema for updating an MCP connector."""
|
||||
|
||||
name: str | None = None
|
||||
server_config: MCPServerConfig | None = None
|
||||
|
||||
|
||||
class MCPConnectorRead(BaseModel):
|
||||
"""Schema for reading an MCP connector with server config."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
connector_type: SearchSourceConnectorType
|
||||
server_config: MCPServerConfig
|
||||
search_space_id: int
|
||||
user_id: uuid.UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@classmethod
|
||||
def from_connector(cls, connector: SearchSourceConnectorRead) -> "MCPConnectorRead":
|
||||
"""Convert from base SearchSourceConnectorRead."""
|
||||
config = connector.config or {}
|
||||
server_config = MCPServerConfig(**config.get("server_config", {}))
|
||||
|
||||
return cls(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
connector_type=connector.connector_type,
|
||||
server_config=server_config,
|
||||
search_space_id=connector.search_space_id,
|
||||
user_id=connector.user_id,
|
||||
created_at=connector.created_at,
|
||||
updated_at=connector.updated_at,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ from fastapi_users import schemas
|
|||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
pages_limit: int
|
||||
pages_used: int
|
||||
display_name: str | None = None
|
||||
avatar_url: str | None = None
|
||||
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
|
|
@ -13,4 +15,5 @@ class UserCreate(schemas.BaseUserCreate):
|
|||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
pass
|
||||
display_name: str | None = None
|
||||
avatar_url: str | None = None
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ async def stream_new_chat(
|
|||
checkpointer = await get_checkpointer()
|
||||
|
||||
# Create the deep agent with checkpointer and configurable prompts
|
||||
agent = create_surfsense_deep_agent(
|
||||
agent = await create_surfsense_deep_agent(
|
||||
llm=llm,
|
||||
search_space_id=search_space_id,
|
||||
db_session=session,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
from fastapi import Depends, Request, Response
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||
|
|
@ -46,6 +47,71 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
|||
reset_password_token_secret = SECRET
|
||||
verification_token_secret = SECRET
|
||||
|
||||
async def oauth_callback(
|
||||
self,
|
||||
oauth_name: str,
|
||||
access_token: str,
|
||||
account_id: str,
|
||||
account_email: str,
|
||||
expires_at: int | None = None,
|
||||
refresh_token: str | None = None,
|
||||
request: Request | None = None,
|
||||
*,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> User:
|
||||
"""
|
||||
Override OAuth callback to capture Google profile data (name, avatar).
|
||||
"""
|
||||
# Call parent implementation to create/get user
|
||||
user = await super().oauth_callback(
|
||||
oauth_name,
|
||||
access_token,
|
||||
account_id,
|
||||
account_email,
|
||||
expires_at,
|
||||
refresh_token,
|
||||
request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
|
||||
# Fetch and store Google profile data if not already set
|
||||
if oauth_name == "google" and (not user.display_name or not user.avatar_url):
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
"https://people.googleapis.com/v1/people/me",
|
||||
params={"personFields": "names,photos"},
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
profile = response.json()
|
||||
|
||||
update_dict = {}
|
||||
|
||||
# Extract name from names array
|
||||
names = profile.get("names", [])
|
||||
if not user.display_name and names:
|
||||
display_name = names[0].get("displayName")
|
||||
if display_name:
|
||||
update_dict["display_name"] = display_name
|
||||
|
||||
# Extract photo URL from photos array
|
||||
photos = profile.get("photos", [])
|
||||
if not user.avatar_url and photos:
|
||||
photo_url = photos[0].get("url")
|
||||
if photo_url:
|
||||
update_dict["avatar_url"] = photo_url
|
||||
|
||||
if update_dict:
|
||||
user = await self.user_db.update(user, update_dict)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch Google profile: {e}")
|
||||
|
||||
return user
|
||||
|
||||
async def on_after_register(self, user: User, request: Request | None = None):
|
||||
"""
|
||||
Called after a user registers. Creates a default search space for the user
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ from typing import Any
|
|||
from urllib.parse import urlparse
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
|
|
@ -27,6 +27,7 @@ BASE_NAME_FOR_TYPE = {
|
|||
SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord",
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence",
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable",
|
||||
SearchSourceConnectorType.MCP_CONNECTOR: "Model Context Protocol (MCP)",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -75,7 +76,7 @@ def extract_identifier_from_credentials(
|
|||
if ".atlassian.net" in hostname:
|
||||
return hostname.replace(".atlassian.net", "")
|
||||
return hostname
|
||||
except Exception:
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,9 @@ dependencies = [
|
|||
"chonkie[all]>=1.5.0",
|
||||
"langgraph-checkpoint-postgres>=3.0.2",
|
||||
"psycopg[binary,pool]>=3.3.2",
|
||||
"mcp>=1.25.0",
|
||||
"starlette>=0.40.0,<0.51.0",
|
||||
"sse-starlette>=3.1.1,<3.1.2",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
|
|
|||
|
|
@ -79,25 +79,17 @@ export function DocumentsTableShell({
|
|||
[documents, sortKey, sortDesc]
|
||||
);
|
||||
|
||||
// Filter out SURFSENSE_DOCS for selection purposes
|
||||
const selectableDocs = React.useMemo(
|
||||
() => sorted.filter((d) => d.document_type !== "SURFSENSE_DOCS"),
|
||||
[sorted]
|
||||
);
|
||||
|
||||
const allSelectedOnPage =
|
||||
selectableDocs.length > 0 && selectableDocs.every((d) => selectedIds.has(d.id));
|
||||
const someSelectedOnPage =
|
||||
selectableDocs.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage;
|
||||
const allSelectedOnPage = sorted.length > 0 && sorted.every((d) => selectedIds.has(d.id));
|
||||
const someSelectedOnPage = sorted.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage;
|
||||
|
||||
const toggleAll = (checked: boolean) => {
|
||||
const next = new Set(selectedIds);
|
||||
if (checked)
|
||||
selectableDocs.forEach((d) => {
|
||||
sorted.forEach((d) => {
|
||||
next.add(d.id);
|
||||
});
|
||||
else
|
||||
selectableDocs.forEach((d) => {
|
||||
sorted.forEach((d) => {
|
||||
next.delete(d.id);
|
||||
});
|
||||
setSelectedIds(next);
|
||||
|
|
@ -238,10 +230,9 @@ export function DocumentsTableShell({
|
|||
const icon = getDocumentTypeIcon(doc.document_type);
|
||||
const title = doc.title;
|
||||
const truncatedTitle = title.length > 30 ? `${title.slice(0, 30)}...` : title;
|
||||
const isSurfsenseDoc = doc.document_type === "SURFSENSE_DOCS";
|
||||
return (
|
||||
<motion.tr
|
||||
key={`${doc.document_type}-${doc.id}`}
|
||||
key={doc.id}
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
|
|
@ -258,9 +249,8 @@ export function DocumentsTableShell({
|
|||
>
|
||||
<TableCell className="px-4 py-3">
|
||||
<Checkbox
|
||||
checked={selectedIds.has(doc.id) && !isSurfsenseDoc}
|
||||
onCheckedChange={(v) => !isSurfsenseDoc && toggleOne(doc.id, !!v)}
|
||||
disabled={isSurfsenseDoc}
|
||||
checked={selectedIds.has(doc.id)}
|
||||
onCheckedChange={(v) => toggleOne(doc.id, !!v)}
|
||||
aria-label="Select row"
|
||||
/>
|
||||
</TableCell>
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import { cacheKeys } from "@/lib/query-client/cache-keys";
|
|||
import { DocumentsFilters } from "./components/DocumentsFilters";
|
||||
import { DocumentsTableShell, type SortKey } from "./components/DocumentsTableShell";
|
||||
import { PaginationControls } from "./components/PaginationControls";
|
||||
import type { ColumnVisibility, Document } from "./components/types";
|
||||
import type { ColumnVisibility } from "./components/types";
|
||||
|
||||
function useDebounced<T>(value: T, delay = 250) {
|
||||
const [debounced, setDebounced] = useState(value);
|
||||
|
|
@ -58,39 +58,30 @@ export default function DocumentsTable() {
|
|||
const { data: rawTypeCounts } = useAtomValue(documentTypeCountsAtom);
|
||||
const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom);
|
||||
|
||||
// Filter out SURFSENSE_DOCS from active types for regular documents API
|
||||
const regularDocumentTypes = useMemo(
|
||||
() => activeTypes.filter((t) => t !== "SURFSENSE_DOCS"),
|
||||
[activeTypes]
|
||||
);
|
||||
|
||||
// Check if only SURFSENSE_DOCS is selected (skip regular docs query)
|
||||
const onlySurfsenseDocsSelected = activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS";
|
||||
|
||||
// Build query parameters for fetching documents (excluding SURFSENSE_DOCS type)
|
||||
// Build query parameters for fetching documents
|
||||
const queryParams = useMemo(
|
||||
() => ({
|
||||
search_space_id: searchSpaceId,
|
||||
page: pageIndex,
|
||||
page_size: pageSize,
|
||||
...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }),
|
||||
...(activeTypes.length > 0 && { document_types: activeTypes }),
|
||||
}),
|
||||
[searchSpaceId, pageIndex, pageSize, regularDocumentTypes]
|
||||
[searchSpaceId, pageIndex, pageSize, activeTypes]
|
||||
);
|
||||
|
||||
// Build search query parameters (excluding SURFSENSE_DOCS type)
|
||||
// Build search query parameters
|
||||
const searchQueryParams = useMemo(
|
||||
() => ({
|
||||
search_space_id: searchSpaceId,
|
||||
page: pageIndex,
|
||||
page_size: pageSize,
|
||||
title: debouncedSearch.trim(),
|
||||
...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }),
|
||||
...(activeTypes.length > 0 && { document_types: activeTypes }),
|
||||
}),
|
||||
[searchSpaceId, pageIndex, pageSize, regularDocumentTypes, debouncedSearch]
|
||||
[searchSpaceId, pageIndex, pageSize, activeTypes, debouncedSearch]
|
||||
);
|
||||
|
||||
// Use query for fetching documents (disabled when only SURFSENSE_DOCS is selected)
|
||||
// Use query for fetching documents
|
||||
const {
|
||||
data: documentsResponse,
|
||||
isLoading: isDocumentsLoading,
|
||||
|
|
@ -100,10 +91,10 @@ export default function DocumentsTable() {
|
|||
queryKey: cacheKeys.documents.globalQueryParams(queryParams),
|
||||
queryFn: () => documentsApiService.getDocuments({ queryParams }),
|
||||
staleTime: 3 * 60 * 1000, // 3 minutes
|
||||
enabled: !!searchSpaceId && !debouncedSearch.trim() && !onlySurfsenseDocsSelected,
|
||||
enabled: !!searchSpaceId && !debouncedSearch.trim(),
|
||||
});
|
||||
|
||||
// Use query for searching documents (disabled when only SURFSENSE_DOCS is selected)
|
||||
// Use query for searching documents
|
||||
const {
|
||||
data: searchResponse,
|
||||
isLoading: isSearchLoading,
|
||||
|
|
@ -113,7 +104,7 @@ export default function DocumentsTable() {
|
|||
queryKey: cacheKeys.documents.globalQueryParams(searchQueryParams),
|
||||
queryFn: () => documentsApiService.searchDocuments({ queryParams: searchQueryParams }),
|
||||
staleTime: 3 * 60 * 1000, // 3 minutes
|
||||
enabled: !!searchSpaceId && !!debouncedSearch.trim() && !onlySurfsenseDocsSelected,
|
||||
enabled: !!searchSpaceId && !!debouncedSearch.trim(),
|
||||
});
|
||||
|
||||
// Determine if we should show SurfSense docs (when no type filter or SURFSENSE_DOCS is selected)
|
||||
|
|
@ -163,64 +154,16 @@ export default function DocumentsTable() {
|
|||
}, [rawTypeCounts, surfsenseDocsResponse?.total]);
|
||||
|
||||
// Extract documents and total based on search state
|
||||
const regularDocuments = debouncedSearch.trim()
|
||||
const documents = debouncedSearch.trim()
|
||||
? searchResponse?.items || []
|
||||
: documentsResponse?.items || [];
|
||||
const regularTotal = debouncedSearch.trim()
|
||||
? searchResponse?.total || 0
|
||||
: documentsResponse?.total || 0;
|
||||
const total = debouncedSearch.trim() ? searchResponse?.total || 0 : documentsResponse?.total || 0;
|
||||
|
||||
// Merge regular documents with SurfSense docs
|
||||
const documents = useMemo(() => {
|
||||
// If filtering by type and not including SURFSENSE_DOCS, only show regular docs
|
||||
if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) {
|
||||
return regularDocuments;
|
||||
}
|
||||
// If filtering only by SURFSENSE_DOCS, only show surfsense docs
|
||||
if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") {
|
||||
return surfsenseDocsAsDocuments;
|
||||
}
|
||||
// Otherwise, merge both (surfsense docs first)
|
||||
return [...surfsenseDocsAsDocuments, ...regularDocuments];
|
||||
}, [regularDocuments, surfsenseDocsAsDocuments, activeTypes]);
|
||||
const loading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading;
|
||||
const error = debouncedSearch.trim() ? searchError : documentsError;
|
||||
|
||||
const total = useMemo(() => {
|
||||
if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) {
|
||||
return regularTotal;
|
||||
}
|
||||
if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") {
|
||||
return surfsenseDocsResponse?.total || 0;
|
||||
}
|
||||
return regularTotal + (surfsenseDocsResponse?.total || 0);
|
||||
}, [regularTotal, surfsenseDocsResponse?.total, activeTypes]);
|
||||
|
||||
const loading = useMemo(() => {
|
||||
// If only SURFSENSE_DOCS selected, only check surfsense loading
|
||||
if (onlySurfsenseDocsSelected) {
|
||||
return isSurfsenseDocsLoading;
|
||||
}
|
||||
// Otherwise check both regular docs and surfsense docs loading
|
||||
const regularLoading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading;
|
||||
return regularLoading || (showSurfsenseDocs && isSurfsenseDocsLoading);
|
||||
}, [
|
||||
onlySurfsenseDocsSelected,
|
||||
isSurfsenseDocsLoading,
|
||||
debouncedSearch,
|
||||
isSearchLoading,
|
||||
isDocumentsLoading,
|
||||
showSurfsenseDocs,
|
||||
]);
|
||||
|
||||
const error = useMemo(() => {
|
||||
// If only SURFSENSE_DOCS selected, no regular docs errors
|
||||
if (onlySurfsenseDocsSelected) {
|
||||
return null;
|
||||
}
|
||||
return debouncedSearch.trim() ? searchError : documentsError;
|
||||
}, [onlySurfsenseDocsSelected, debouncedSearch, searchError, documentsError]);
|
||||
|
||||
// Display server-filtered results directly
|
||||
const displayDocs = documents || [];
|
||||
// Display results directly
|
||||
const displayDocs = documents;
|
||||
const displayTotal = total;
|
||||
const pageStart = pageIndex * pageSize;
|
||||
const pageEnd = Math.min(pageStart + pageSize, displayTotal);
|
||||
|
|
@ -240,33 +183,16 @@ export default function DocumentsTable() {
|
|||
if (isRefreshing) return;
|
||||
setIsRefreshing(true);
|
||||
try {
|
||||
const refetchPromises: Promise<unknown>[] = [];
|
||||
// Only refetch regular documents if not in "only surfsense docs" mode
|
||||
if (!onlySurfsenseDocsSelected) {
|
||||
if (debouncedSearch.trim()) {
|
||||
refetchPromises.push(refetchSearch());
|
||||
} else {
|
||||
refetchPromises.push(refetchDocuments());
|
||||
}
|
||||
if (debouncedSearch.trim()) {
|
||||
await refetchSearch();
|
||||
} else {
|
||||
await refetchDocuments();
|
||||
}
|
||||
if (showSurfsenseDocs) {
|
||||
refetchPromises.push(refetchSurfsenseDocs());
|
||||
}
|
||||
await Promise.all(refetchPromises);
|
||||
toast.success(t("refresh_success") || "Documents refreshed");
|
||||
} finally {
|
||||
setIsRefreshing(false);
|
||||
}
|
||||
}, [
|
||||
debouncedSearch,
|
||||
refetchSearch,
|
||||
refetchDocuments,
|
||||
refetchSurfsenseDocs,
|
||||
showSurfsenseDocs,
|
||||
onlySurfsenseDocsSelected,
|
||||
t,
|
||||
isRefreshing,
|
||||
]);
|
||||
}, [debouncedSearch, refetchSearch, refetchDocuments, t, isRefreshing]);
|
||||
|
||||
// Create a delete function for single document deletion
|
||||
const deleteDocument = useCallback(
|
||||
|
|
@ -357,7 +283,7 @@ export default function DocumentsTable() {
|
|||
</motion.div>
|
||||
|
||||
<DocumentsFilters
|
||||
typeCounts={typeCounts ?? {}}
|
||||
typeCounts={rawTypeCounts ?? {}}
|
||||
selectedIds={selectedIds}
|
||||
onSearch={setSearch}
|
||||
searchValue={search}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ import {
|
|||
// extractWriteTodosFromContent,
|
||||
hydratePlanStateAtom,
|
||||
} from "@/atoms/chat/plan-state.atom";
|
||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { Thread } from "@/components/assistant-ui/thread";
|
||||
import { ChatHeader } from "@/components/new-chat/chat-header";
|
||||
import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
|
||||
|
|
@ -185,12 +186,25 @@ function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
|
|||
}
|
||||
}
|
||||
|
||||
// Build metadata.custom for author display in shared chats
|
||||
const metadata = msg.author_id
|
||||
? {
|
||||
custom: {
|
||||
author: {
|
||||
displayName: msg.author_display_name ?? null,
|
||||
avatarUrl: msg.author_avatar_url ?? null,
|
||||
},
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
|
||||
return {
|
||||
id: `msg-${msg.id}`,
|
||||
role: msg.role,
|
||||
content,
|
||||
createdAt: new Date(msg.created_at),
|
||||
attachments,
|
||||
metadata,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -238,6 +252,9 @@ export default function NewChatPage() {
|
|||
const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom);
|
||||
const hydratePlanState = useSetAtom(hydratePlanStateAtom);
|
||||
|
||||
// Get current user for author info in shared chats
|
||||
const { data: currentUser } = useAtomValue(currentUserAtom);
|
||||
|
||||
// Create the attachment adapter for file processing
|
||||
const attachmentAdapter = useMemo(() => createAttachmentAdapter(), []);
|
||||
|
||||
|
|
@ -306,12 +323,6 @@ export default function NewChatPage() {
|
|||
if (steps.length > 0) {
|
||||
restoredThinkingSteps.set(`msg-${msg.id}`, steps);
|
||||
}
|
||||
// Hydrate write_todos plan state from persisted tool calls
|
||||
// Disabled for now
|
||||
// const writeTodosCalls = extractWriteTodosFromContent(msg.content);
|
||||
// for (const todoData of writeTodosCalls) {
|
||||
// hydratePlanState(todoData);
|
||||
// }
|
||||
}
|
||||
if (msg.role === "user") {
|
||||
const docs = extractMentionedDocuments(msg.content);
|
||||
|
|
@ -448,13 +459,27 @@ export default function NewChatPage() {
|
|||
|
||||
// Add user message to state
|
||||
const userMsgId = `msg-user-${Date.now()}`;
|
||||
|
||||
// Include author metadata for shared chats
|
||||
const authorMetadata =
|
||||
currentThread?.visibility === "SEARCH_SPACE" && currentUser
|
||||
? {
|
||||
custom: {
|
||||
author: {
|
||||
displayName: currentUser.display_name ?? null,
|
||||
avatarUrl: currentUser.avatar_url ?? null,
|
||||
},
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
|
||||
const userMessage: ThreadMessageLike = {
|
||||
id: userMsgId,
|
||||
role: "user",
|
||||
content: message.content,
|
||||
createdAt: new Date(),
|
||||
// Include attachments so they can be displayed
|
||||
attachments: message.attachments || [],
|
||||
metadata: authorMetadata,
|
||||
};
|
||||
setMessages((prev) => [...prev, userMessage]);
|
||||
|
||||
|
|
@ -884,6 +909,8 @@ export default function NewChatPage() {
|
|||
setMentionedDocuments,
|
||||
setMessageDocumentsMap,
|
||||
queryClient,
|
||||
currentThread,
|
||||
currentUser,
|
||||
]
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,123 @@
|
|||
"use client";
|
||||
|
||||
import { Check, Copy, Key, Menu, Shield } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { useApiKey } from "@/hooks/use-api-key";
|
||||
|
||||
interface ApiKeyContentProps {
|
||||
onMenuClick: () => void;
|
||||
}
|
||||
|
||||
export function ApiKeyContent({ onMenuClick }: ApiKeyContentProps) {
|
||||
const t = useTranslations("userSettings");
|
||||
const { apiKey, isLoading, copied, copyToClipboard } = useApiKey();
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ delay: 0.2, duration: 0.4 }}
|
||||
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
|
||||
>
|
||||
<div className="h-full overflow-y-auto">
|
||||
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
|
||||
<AnimatePresence mode="wait">
|
||||
<motion.div
|
||||
key="api-key-header"
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -10 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
className="mb-6 md:mb-8"
|
||||
>
|
||||
<div className="flex items-center gap-3 md:gap-4">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={onMenuClick}
|
||||
className="h-10 w-10 shrink-0 md:hidden"
|
||||
>
|
||||
<Menu className="h-5 w-5" />
|
||||
</Button>
|
||||
<motion.div
|
||||
initial={{ scale: 0.8, opacity: 0 }}
|
||||
animate={{ scale: 1, opacity: 1 }}
|
||||
transition={{ delay: 0.1, duration: 0.3 }}
|
||||
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
|
||||
>
|
||||
<Key className="h-5 w-5 text-primary md:h-7 md:w-7" />
|
||||
</motion.div>
|
||||
<div className="min-w-0">
|
||||
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
|
||||
{t("api_key_title")}
|
||||
</h1>
|
||||
<p className="text-sm text-muted-foreground">{t("api_key_description")}</p>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
|
||||
<AnimatePresence mode="wait">
|
||||
<motion.div
|
||||
key="api-key-content"
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -20 }}
|
||||
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
|
||||
className="space-y-6"
|
||||
>
|
||||
<Alert>
|
||||
<Shield className="h-4 w-4" />
|
||||
<AlertTitle>{t("api_key_warning_title")}</AlertTitle>
|
||||
<AlertDescription>{t("api_key_warning_description")}</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className="rounded-lg border bg-card p-6">
|
||||
<h3 className="mb-4 font-medium">{t("your_api_key")}</h3>
|
||||
{isLoading ? (
|
||||
<div className="h-12 w-full animate-pulse rounded-md bg-muted" />
|
||||
) : apiKey ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex-1 overflow-x-auto rounded-md bg-muted p-3 font-mono text-sm">
|
||||
{apiKey}
|
||||
</div>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={copyToClipboard}
|
||||
className="shrink-0"
|
||||
>
|
||||
{copied ? <Check className="h-4 w-4" /> : <Copy className="h-4 w-4" />}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{copied ? t("copied") : t("copy")}</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-center text-muted-foreground">{t("no_api_key")}</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="rounded-lg border bg-card p-6">
|
||||
<h3 className="mb-2 font-medium">{t("usage_title")}</h3>
|
||||
<p className="mb-4 text-sm text-muted-foreground">{t("usage_description")}</p>
|
||||
<pre className="overflow-x-auto rounded-md bg-muted p-3 text-sm">
|
||||
<code>Authorization: Bearer {apiKey || "YOUR_API_KEY"}</code>
|
||||
</pre>
|
||||
</div>
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,181 @@
|
|||
"use client";
|
||||
|
||||
import { useAtomValue } from "jotai";
|
||||
import { Loader2, Menu, User } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useEffect, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
||||
import { updateUserMutationAtom } from "@/atoms/user/user-mutation.atoms";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
|
||||
interface ProfileContentProps {
|
||||
onMenuClick: () => void;
|
||||
}
|
||||
|
||||
function AvatarDisplay({ url, fallback }: { url?: string; fallback: string }) {
|
||||
const [hasError, setHasError] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
setHasError(false);
|
||||
}, [url]);
|
||||
|
||||
if (url && !hasError) {
|
||||
return (
|
||||
<img
|
||||
src={url}
|
||||
alt="Avatar"
|
||||
className="h-16 w-16 rounded-xl object-cover"
|
||||
onError={() => setHasError(true)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-16 w-16 items-center justify-center rounded-xl bg-muted text-xl font-semibold text-muted-foreground">
|
||||
{fallback}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function ProfileContent({ onMenuClick }: ProfileContentProps) {
|
||||
const t = useTranslations("userSettings");
|
||||
const { data: user, isLoading: isUserLoading } = useAtomValue(currentUserAtom);
|
||||
const { mutateAsync: updateUser, isPending } = useAtomValue(updateUserMutationAtom);
|
||||
|
||||
const [displayName, setDisplayName] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
if (user) {
|
||||
setDisplayName(user.display_name || "");
|
||||
}
|
||||
}, [user]);
|
||||
|
||||
const getInitials = (email: string) => {
|
||||
const name = email.split("@")[0];
|
||||
return name.slice(0, 2).toUpperCase();
|
||||
};
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
|
||||
try {
|
||||
await updateUser({
|
||||
display_name: displayName || null,
|
||||
});
|
||||
toast.success(t("profile_saved"));
|
||||
} catch {
|
||||
toast.error(t("profile_save_error"));
|
||||
}
|
||||
};
|
||||
|
||||
const hasChanges = displayName !== (user?.display_name || "");
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ delay: 0.2, duration: 0.4 }}
|
||||
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
|
||||
>
|
||||
<div className="h-full overflow-y-auto">
|
||||
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
|
||||
<AnimatePresence mode="wait">
|
||||
<motion.div
|
||||
key="profile-header"
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -10 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
className="mb-6 md:mb-8"
|
||||
>
|
||||
<div className="flex items-center gap-3 md:gap-4">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={onMenuClick}
|
||||
className="h-10 w-10 shrink-0 md:hidden"
|
||||
>
|
||||
<Menu className="h-5 w-5" />
|
||||
</Button>
|
||||
<motion.div
|
||||
initial={{ scale: 0.8, opacity: 0 }}
|
||||
animate={{ scale: 1, opacity: 1 }}
|
||||
transition={{ delay: 0.1, duration: 0.3 }}
|
||||
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
|
||||
>
|
||||
<User className="h-5 w-5 text-primary md:h-7 md:w-7" />
|
||||
</motion.div>
|
||||
<div className="min-w-0">
|
||||
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
|
||||
{t("profile_title")}
|
||||
</h1>
|
||||
<p className="text-sm text-muted-foreground">{t("profile_description")}</p>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
|
||||
<AnimatePresence mode="wait">
|
||||
<motion.div
|
||||
key="profile-content"
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -20 }}
|
||||
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
|
||||
>
|
||||
{isUserLoading ? (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
) : (
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
<div className="rounded-lg border bg-card p-6">
|
||||
<div className="flex flex-col gap-6">
|
||||
<div className="space-y-2">
|
||||
<Label>{t("profile_avatar")}</Label>
|
||||
<AvatarDisplay
|
||||
url={user?.avatar_url || undefined}
|
||||
fallback={getInitials(user?.email || "")}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="display-name">{t("profile_display_name")}</Label>
|
||||
<Input
|
||||
id="display-name"
|
||||
type="text"
|
||||
placeholder={user?.email?.split("@")[0]}
|
||||
value={displayName}
|
||||
onChange={(e) => setDisplayName(e.target.value)}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
{t("profile_display_name_hint")}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>{t("profile_email")}</Label>
|
||||
<Input type="email" value={user?.email || ""} disabled />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end">
|
||||
<Button type="submit" disabled={isPending || !hasChanges}>
|
||||
{isPending && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
{t("profile_save")}
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
)}
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
"use client";
|
||||
|
||||
import { ArrowLeft, ChevronRight, X } from "lucide-react";
|
||||
import type { LucideIcon } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface SettingsNavItem {
|
||||
id: string;
|
||||
label: string;
|
||||
description: string;
|
||||
icon: LucideIcon;
|
||||
}
|
||||
|
||||
interface UserSettingsSidebarProps {
|
||||
activeSection: string;
|
||||
onSectionChange: (section: string) => void;
|
||||
onBackToApp: () => void;
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
navItems: SettingsNavItem[];
|
||||
}
|
||||
|
||||
export function UserSettingsSidebar({
|
||||
activeSection,
|
||||
onSectionChange,
|
||||
onBackToApp,
|
||||
isOpen,
|
||||
onClose,
|
||||
navItems,
|
||||
}: UserSettingsSidebarProps) {
|
||||
const t = useTranslations("userSettings");
|
||||
|
||||
const handleNavClick = (sectionId: string) => {
|
||||
onSectionChange(sectionId);
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<AnimatePresence>
|
||||
{isOpen && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="fixed inset-0 z-40 bg-background/80 backdrop-blur-sm md:hidden"
|
||||
onClick={onClose}
|
||||
/>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
<aside
|
||||
className={cn(
|
||||
"fixed left-0 top-0 z-50 md:relative md:z-auto",
|
||||
"flex h-full w-72 shrink-0 flex-col bg-background md:bg-muted/30",
|
||||
"md:border-r",
|
||||
"transition-transform duration-300 ease-out",
|
||||
"md:translate-x-0",
|
||||
isOpen ? "translate-x-0" : "-translate-x-full md:translate-x-0"
|
||||
)}
|
||||
>
|
||||
{/* Header with title */}
|
||||
<div className="space-y-3 p-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<Button
|
||||
variant="ghost"
|
||||
onClick={onBackToApp}
|
||||
className="group h-11 justify-start gap-3 px-3 hover:bg-muted"
|
||||
>
|
||||
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 transition-colors group-hover:bg-primary/20">
|
||||
<ArrowLeft className="h-4 w-4 text-primary" />
|
||||
</div>
|
||||
<span className="font-medium">{t("back_to_app")}</span>
|
||||
</Button>
|
||||
<Button variant="ghost" size="icon" onClick={onClose} className="h-9 w-9 md:hidden">
|
||||
<X className="h-5 w-5" />
|
||||
</Button>
|
||||
</div>
|
||||
{/* Settings Title */}
|
||||
<div className="px-3">
|
||||
<h2 className="text-lg font-semibold text-foreground">{t("title")}</h2>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<nav className="flex-1 space-y-1 overflow-y-auto px-3 py-2">
|
||||
{navItems.map((item, index) => {
|
||||
const isActive = activeSection === item.id;
|
||||
const Icon = item.icon;
|
||||
|
||||
return (
|
||||
<motion.button
|
||||
key={item.id}
|
||||
initial={{ opacity: 0, x: -10 }}
|
||||
animate={{ opacity: 1, x: 0 }}
|
||||
transition={{ delay: 0.1 + index * 0.05, duration: 0.3 }}
|
||||
onClick={() => handleNavClick(item.id)}
|
||||
whileHover={{ scale: 1.01 }}
|
||||
whileTap={{ scale: 0.99 }}
|
||||
className={cn(
|
||||
"relative flex w-full items-center gap-3 rounded-xl px-3 py-3 text-left transition-all duration-200",
|
||||
isActive ? "border border-border bg-muted shadow-sm" : "hover:bg-muted/60"
|
||||
)}
|
||||
>
|
||||
{isActive && (
|
||||
<motion.div
|
||||
layoutId="userSettingsActiveIndicator"
|
||||
className="absolute left-0 top-1/2 h-8 w-1 -translate-y-1/2 rounded-r-full bg-primary"
|
||||
initial={false}
|
||||
transition={{
|
||||
type: "spring",
|
||||
stiffness: 500,
|
||||
damping: 35,
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-9 w-9 items-center justify-center rounded-lg transition-colors",
|
||||
isActive ? "bg-primary/10 text-primary" : "bg-muted text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
<Icon className="h-4 w-4" />
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<p
|
||||
className={cn(
|
||||
"truncate text-sm font-medium transition-colors",
|
||||
isActive ? "text-foreground" : "text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{item.label}
|
||||
</p>
|
||||
<p className="truncate text-xs text-muted-foreground/70">{item.description}</p>
|
||||
</div>
|
||||
<ChevronRight
|
||||
className={cn(
|
||||
"h-4 w-4 shrink-0 transition-all",
|
||||
isActive
|
||||
? "translate-x-0 text-primary opacity-100"
|
||||
: "-translate-x-1 text-muted-foreground/40 opacity-0"
|
||||
)}
|
||||
/>
|
||||
</motion.button>
|
||||
);
|
||||
})}
|
||||
</nav>
|
||||
</aside>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -1,286 +1,27 @@
|
|||
"use client";
|
||||
|
||||
import {
|
||||
ArrowLeft,
|
||||
Check,
|
||||
ChevronRight,
|
||||
Copy,
|
||||
Key,
|
||||
type LucideIcon,
|
||||
Menu,
|
||||
Shield,
|
||||
X,
|
||||
} from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { Key, User } from "lucide-react";
|
||||
import { motion } from "motion/react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useTranslations } from "next-intl";
|
||||
import { useCallback, useState } from "react";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { useApiKey } from "@/hooks/use-api-key";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface SettingsNavItem {
|
||||
id: string;
|
||||
label: string;
|
||||
description: string;
|
||||
icon: LucideIcon;
|
||||
}
|
||||
|
||||
function UserSettingsSidebar({
|
||||
activeSection,
|
||||
onSectionChange,
|
||||
onBackToApp,
|
||||
isOpen,
|
||||
onClose,
|
||||
navItems,
|
||||
}: {
|
||||
activeSection: string;
|
||||
onSectionChange: (section: string) => void;
|
||||
onBackToApp: () => void;
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
navItems: SettingsNavItem[];
|
||||
}) {
|
||||
const t = useTranslations("userSettings");
|
||||
|
||||
const handleNavClick = (sectionId: string) => {
|
||||
onSectionChange(sectionId);
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<AnimatePresence>
|
||||
{isOpen && (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.2 }}
|
||||
className="fixed inset-0 z-40 bg-background/80 backdrop-blur-sm md:hidden"
|
||||
onClick={onClose}
|
||||
/>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
<aside
|
||||
className={cn(
|
||||
"fixed left-0 top-0 z-50 md:relative md:z-auto",
|
||||
"flex h-full w-72 shrink-0 flex-col bg-background md:bg-muted/30",
|
||||
"md:border-r",
|
||||
"transition-transform duration-300 ease-out",
|
||||
"md:translate-x-0",
|
||||
isOpen ? "translate-x-0" : "-translate-x-full md:translate-x-0"
|
||||
)}
|
||||
>
|
||||
{/* Header with title */}
|
||||
<div className="space-y-3 p-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<Button
|
||||
variant="ghost"
|
||||
onClick={onBackToApp}
|
||||
className="group h-11 justify-start gap-3 px-3 hover:bg-muted"
|
||||
>
|
||||
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 transition-colors group-hover:bg-primary/20">
|
||||
<ArrowLeft className="h-4 w-4 text-primary" />
|
||||
</div>
|
||||
<span className="font-medium">{t("back_to_app")}</span>
|
||||
</Button>
|
||||
<Button variant="ghost" size="icon" onClick={onClose} className="h-9 w-9 md:hidden">
|
||||
<X className="h-5 w-5" />
|
||||
</Button>
|
||||
</div>
|
||||
{/* Settings Title */}
|
||||
<div className="px-3">
|
||||
<h2 className="text-lg font-semibold text-foreground">{t("title")}</h2>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<nav className="flex-1 space-y-1 overflow-y-auto px-3 py-2">
|
||||
{navItems.map((item, index) => {
|
||||
const isActive = activeSection === item.id;
|
||||
const Icon = item.icon;
|
||||
|
||||
return (
|
||||
<motion.button
|
||||
key={item.id}
|
||||
initial={{ opacity: 0, x: -10 }}
|
||||
animate={{ opacity: 1, x: 0 }}
|
||||
transition={{ delay: 0.1 + index * 0.05, duration: 0.3 }}
|
||||
onClick={() => handleNavClick(item.id)}
|
||||
whileHover={{ scale: 1.01 }}
|
||||
whileTap={{ scale: 0.99 }}
|
||||
className={cn(
|
||||
"relative flex w-full items-center gap-3 rounded-xl px-3 py-3 text-left transition-all duration-200",
|
||||
isActive ? "border border-border bg-muted shadow-sm" : "hover:bg-muted/60"
|
||||
)}
|
||||
>
|
||||
{isActive && (
|
||||
<motion.div
|
||||
layoutId="userSettingsActiveIndicator"
|
||||
className="absolute left-0 top-1/2 h-8 w-1 -translate-y-1/2 rounded-r-full bg-primary"
|
||||
initial={false}
|
||||
transition={{
|
||||
type: "spring",
|
||||
stiffness: 500,
|
||||
damping: 35,
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-9 w-9 items-center justify-center rounded-lg transition-colors",
|
||||
isActive ? "bg-primary/10 text-primary" : "bg-muted text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
<Icon className="h-4 w-4" />
|
||||
</div>
|
||||
<div className="min-w-0 flex-1">
|
||||
<p
|
||||
className={cn(
|
||||
"truncate text-sm font-medium transition-colors",
|
||||
isActive ? "text-foreground" : "text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{item.label}
|
||||
</p>
|
||||
<p className="truncate text-xs text-muted-foreground/70">{item.description}</p>
|
||||
</div>
|
||||
<ChevronRight
|
||||
className={cn(
|
||||
"h-4 w-4 shrink-0 transition-all",
|
||||
isActive
|
||||
? "translate-x-0 text-primary opacity-100"
|
||||
: "-translate-x-1 text-muted-foreground/40 opacity-0"
|
||||
)}
|
||||
/>
|
||||
</motion.button>
|
||||
);
|
||||
})}
|
||||
</nav>
|
||||
</aside>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function ApiKeyContent({ onMenuClick }: { onMenuClick: () => void }) {
|
||||
const t = useTranslations("userSettings");
|
||||
const { apiKey, isLoading, copied, copyToClipboard } = useApiKey();
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
transition={{ delay: 0.2, duration: 0.4 }}
|
||||
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
|
||||
>
|
||||
<div className="h-full overflow-y-auto">
|
||||
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
|
||||
<AnimatePresence mode="wait">
|
||||
<motion.div
|
||||
key="api-key-header"
|
||||
initial={{ opacity: 0, y: 10 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -10 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
className="mb-6 md:mb-8"
|
||||
>
|
||||
<div className="flex items-center gap-3 md:gap-4">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={onMenuClick}
|
||||
className="h-10 w-10 shrink-0 md:hidden"
|
||||
>
|
||||
<Menu className="h-5 w-5" />
|
||||
</Button>
|
||||
<motion.div
|
||||
initial={{ scale: 0.8, opacity: 0 }}
|
||||
animate={{ scale: 1, opacity: 1 }}
|
||||
transition={{ delay: 0.1, duration: 0.3 }}
|
||||
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
|
||||
>
|
||||
<Key className="h-5 w-5 text-primary md:h-7 md:w-7" />
|
||||
</motion.div>
|
||||
<div className="min-w-0">
|
||||
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
|
||||
{t("api_key_title")}
|
||||
</h1>
|
||||
<p className="text-sm text-muted-foreground">{t("api_key_description")}</p>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
|
||||
<AnimatePresence mode="wait">
|
||||
<motion.div
|
||||
key="api-key-content"
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -20 }}
|
||||
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
|
||||
className="space-y-6"
|
||||
>
|
||||
<Alert>
|
||||
<Shield className="h-4 w-4" />
|
||||
<AlertTitle>{t("api_key_warning_title")}</AlertTitle>
|
||||
<AlertDescription>{t("api_key_warning_description")}</AlertDescription>
|
||||
</Alert>
|
||||
|
||||
<div className="rounded-lg border bg-card p-6">
|
||||
<h3 className="mb-4 font-medium">{t("your_api_key")}</h3>
|
||||
{isLoading ? (
|
||||
<div className="h-12 w-full animate-pulse rounded-md bg-muted" />
|
||||
) : apiKey ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex-1 overflow-x-auto rounded-md bg-muted p-3 font-mono text-sm">
|
||||
{apiKey}
|
||||
</div>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={copyToClipboard}
|
||||
className="shrink-0"
|
||||
>
|
||||
{copied ? <Check className="h-4 w-4" /> : <Copy className="h-4 w-4" />}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{copied ? t("copied") : t("copy")}</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-center text-muted-foreground">{t("no_api_key")}</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="rounded-lg border bg-card p-6">
|
||||
<h3 className="mb-2 font-medium">{t("usage_title")}</h3>
|
||||
<p className="mb-4 text-sm text-muted-foreground">{t("usage_description")}</p>
|
||||
<pre className="overflow-x-auto rounded-md bg-muted p-3 text-sm">
|
||||
<code>Authorization: Bearer {apiKey || "YOUR_API_KEY"}</code>
|
||||
</pre>
|
||||
</div>
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
</div>
|
||||
</div>
|
||||
</motion.div>
|
||||
);
|
||||
}
|
||||
import { ApiKeyContent } from "./components/ApiKeyContent";
|
||||
import { ProfileContent } from "./components/ProfileContent";
|
||||
import { UserSettingsSidebar, type SettingsNavItem } from "./components/UserSettingsSidebar";
|
||||
|
||||
export default function UserSettingsPage() {
|
||||
const t = useTranslations("userSettings");
|
||||
const router = useRouter();
|
||||
const [activeSection, setActiveSection] = useState("api-key");
|
||||
const [activeSection, setActiveSection] = useState("profile");
|
||||
const [isSidebarOpen, setIsSidebarOpen] = useState(false);
|
||||
|
||||
const navItems: SettingsNavItem[] = [
|
||||
{
|
||||
id: "profile",
|
||||
label: t("profile_nav_label"),
|
||||
description: t("profile_nav_description"),
|
||||
icon: User,
|
||||
},
|
||||
{
|
||||
id: "api-key",
|
||||
label: t("api_key_nav_label"),
|
||||
|
|
@ -310,6 +51,9 @@ export default function UserSettingsPage() {
|
|||
onClose={() => setIsSidebarOpen(false)}
|
||||
navItems={navItems}
|
||||
/>
|
||||
{activeSection === "profile" && (
|
||||
<ProfileContent onMenuClick={() => setIsSidebarOpen(true)} />
|
||||
)}
|
||||
{activeSection === "api-key" && (
|
||||
<ApiKeyContent onMenuClick={() => setIsSidebarOpen(true)} />
|
||||
)}
|
||||
|
|
|
|||
19
surfsense_web/atoms/user/user-mutation.atoms.ts
Normal file
19
surfsense_web/atoms/user/user-mutation.atoms.ts
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
import { atomWithMutation, queryClientAtom } from "jotai-tanstack-query";
|
||||
import type { UpdateUserRequest } from "@/contracts/types/user.types";
|
||||
import { userApiService } from "@/lib/apis/user-api.service";
|
||||
import { cacheKeys } from "@/lib/query-client/cache-keys";
|
||||
|
||||
export const updateUserMutationAtom = atomWithMutation((get) => {
|
||||
const queryClient = get(queryClientAtom);
|
||||
|
||||
return {
|
||||
mutationKey: cacheKeys.user.current(),
|
||||
mutationFn: async (request: UpdateUserRequest) => {
|
||||
return userApiService.updateMe(request);
|
||||
},
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: cacheKeys.user.current() });
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
|
|
@ -19,9 +19,7 @@ import {
|
|||
ChevronRightIcon,
|
||||
CopyIcon,
|
||||
DownloadIcon,
|
||||
FileText,
|
||||
Loader2,
|
||||
PencilIcon,
|
||||
RefreshCwIcon,
|
||||
SquareIcon,
|
||||
} from "lucide-react";
|
||||
|
|
@ -31,7 +29,6 @@ import { createPortal } from "react-dom";
|
|||
import {
|
||||
mentionedDocumentIdsAtom,
|
||||
mentionedDocumentsAtom,
|
||||
messageDocumentsMapAtom,
|
||||
} from "@/atoms/chat/mentioned-documents.atom";
|
||||
import {
|
||||
globalNewLLMConfigsAtom,
|
||||
|
|
@ -42,8 +39,8 @@ import { currentUserAtom } from "@/atoms/user/user-query.atoms";
|
|||
import {
|
||||
ComposerAddAttachment,
|
||||
ComposerAttachments,
|
||||
UserMessageAttachments,
|
||||
} from "@/components/assistant-ui/attachment";
|
||||
import { UserMessage } from "@/components/assistant-ui/user-message";
|
||||
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
|
||||
import {
|
||||
InlineMentionEditor,
|
||||
|
|
@ -639,69 +636,6 @@ const AssistantActionBar: FC = () => {
|
|||
);
|
||||
};
|
||||
|
||||
const UserMessage: FC = () => {
|
||||
const messageId = useAssistantState(({ message }) => message?.id);
|
||||
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
||||
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
|
||||
const hasAttachments = useAssistantState(
|
||||
({ message }) => message?.attachments && message.attachments.length > 0
|
||||
);
|
||||
|
||||
return (
|
||||
<MessagePrimitive.Root
|
||||
className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2"
|
||||
data-role="user"
|
||||
>
|
||||
<div className="aui-user-message-content-wrapper col-start-2 min-w-0">
|
||||
{/* Display attachments and mentioned documents */}
|
||||
{(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && (
|
||||
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
|
||||
{/* Attachments (images show as thumbnails, documents as chips) */}
|
||||
<UserMessageAttachments />
|
||||
{/* Mentioned documents as chips */}
|
||||
{mentionedDocs?.map((doc) => (
|
||||
<span
|
||||
key={`${doc.document_type}:${doc.id}`}
|
||||
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
|
||||
title={doc.title}
|
||||
>
|
||||
<FileText className="size-3" />
|
||||
<span className="max-w-[150px] truncate">{doc.title}</span>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{/* Message bubble with action bar positioned relative to it */}
|
||||
<div className="relative">
|
||||
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
||||
<MessagePrimitive.Parts />
|
||||
</div>
|
||||
<div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1">
|
||||
<UserActionBar />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" />
|
||||
</MessagePrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
||||
const UserActionBar: FC = () => {
|
||||
return (
|
||||
<ActionBarPrimitive.Root
|
||||
hideWhenRunning
|
||||
autohide="not-last"
|
||||
className="aui-user-action-bar-root flex flex-col items-end"
|
||||
>
|
||||
<ActionBarPrimitive.Edit asChild>
|
||||
<TooltipIconButton tooltip="Edit" className="aui-user-action-edit p-4">
|
||||
<PencilIcon />
|
||||
</TooltipIconButton>
|
||||
</ActionBarPrimitive.Edit>
|
||||
</ActionBarPrimitive.Root>
|
||||
);
|
||||
};
|
||||
|
||||
const EditComposer: FC = () => {
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -1,16 +1,54 @@
|
|||
import { ActionBarPrimitive, MessagePrimitive, useAssistantState } from "@assistant-ui/react";
|
||||
import { useAtomValue } from "jotai";
|
||||
import { FileText, PencilIcon } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { type FC, useState } from "react";
|
||||
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
|
||||
import { UserMessageAttachments } from "@/components/assistant-ui/attachment";
|
||||
import { BranchPicker } from "@/components/assistant-ui/branch-picker";
|
||||
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
|
||||
|
||||
interface AuthorMetadata {
|
||||
displayName: string | null;
|
||||
avatarUrl: string | null;
|
||||
}
|
||||
|
||||
const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
|
||||
const [hasError, setHasError] = useState(false);
|
||||
|
||||
const initials = displayName
|
||||
? displayName
|
||||
.split(" ")
|
||||
.map((n) => n[0])
|
||||
.join("")
|
||||
.toUpperCase()
|
||||
.slice(0, 2)
|
||||
: "U";
|
||||
|
||||
if (avatarUrl && !hasError) {
|
||||
return (
|
||||
<img
|
||||
src={avatarUrl}
|
||||
alt={displayName || "User"}
|
||||
className="size-8 rounded-full object-cover"
|
||||
referrerPolicy="no-referrer"
|
||||
onError={() => setHasError(true)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10 text-xs font-medium text-primary">
|
||||
{initials}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const UserMessage: FC = () => {
|
||||
const messageId = useAssistantState(({ message }) => message?.id);
|
||||
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
|
||||
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
|
||||
const metadata = useAssistantState(({ message }) => message?.metadata);
|
||||
const author = metadata?.custom?.author as AuthorMetadata | undefined;
|
||||
const hasAttachments = useAssistantState(
|
||||
({ message }) => message?.attachments && message.attachments.length > 0
|
||||
);
|
||||
|
|
@ -20,34 +58,42 @@ export const UserMessage: FC = () => {
|
|||
className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2"
|
||||
data-role="user"
|
||||
>
|
||||
<div className="aui-user-message-content-wrapper col-start-2 min-w-0">
|
||||
{/* Display attachments and mentioned documents */}
|
||||
{(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && (
|
||||
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
|
||||
{/* Attachments (images show as thumbnails, documents as chips) */}
|
||||
<UserMessageAttachments />
|
||||
{/* Mentioned documents as chips */}
|
||||
{mentionedDocs?.map((doc) => (
|
||||
<span
|
||||
key={`${doc.document_type}:${doc.id}`}
|
||||
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
|
||||
title={doc.title}
|
||||
>
|
||||
<FileText className="size-3" />
|
||||
<span className="max-w-[150px] truncate">{doc.title}</span>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{/* Message bubble with action bar positioned relative to it */}
|
||||
<div className="relative">
|
||||
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
||||
<MessagePrimitive.Parts />
|
||||
</div>
|
||||
<div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1">
|
||||
<UserActionBar />
|
||||
<div className="aui-user-message-content-wrapper col-start-2 min-w-0 flex items-end gap-2">
|
||||
<div className="flex-1 min-w-0">
|
||||
{/* Display attachments and mentioned documents */}
|
||||
{(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && (
|
||||
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
|
||||
{/* Attachments (images show as thumbnails, documents as chips) */}
|
||||
<UserMessageAttachments />
|
||||
{/* Mentioned documents as chips */}
|
||||
{mentionedDocs?.map((doc) => (
|
||||
<span
|
||||
key={`${doc.document_type}:${doc.id}`}
|
||||
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
|
||||
title={doc.title}
|
||||
>
|
||||
<FileText className="size-3" />
|
||||
<span className="max-w-[150px] truncate">{doc.title}</span>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{/* Message bubble with action bar positioned relative to it */}
|
||||
<div className="relative">
|
||||
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
|
||||
<MessagePrimitive.Parts />
|
||||
</div>
|
||||
<div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1">
|
||||
<UserActionBar />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/* User avatar - only shown in shared chats */}
|
||||
{author && (
|
||||
<div className="shrink-0">
|
||||
<UserAvatar displayName={author.displayName} avatarUrl={author.avatarUrl} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" />
|
||||
|
|
|
|||
|
|
@ -354,7 +354,11 @@ export function LayoutDataProvider({
|
|||
onChatDelete={handleChatDelete}
|
||||
onViewAllSharedChats={handleViewAllSharedChats}
|
||||
onViewAllPrivateChats={handleViewAllPrivateChats}
|
||||
user={{ email: user?.email || "", name: user?.email?.split("@")[0] }}
|
||||
user={{
|
||||
email: user?.email || "",
|
||||
name: user?.display_name || user?.email?.split("@")[0],
|
||||
avatarUrl: user?.avatar_url || undefined,
|
||||
}}
|
||||
onSettings={handleSettings}
|
||||
onManageMembers={handleManageMembers}
|
||||
onUserSettings={handleUserSettings}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ export interface SearchSpace {
|
|||
export interface User {
|
||||
email: string;
|
||||
name?: string;
|
||||
avatarUrl?: string;
|
||||
}
|
||||
|
||||
export interface NavItem {
|
||||
|
|
|
|||
|
|
@ -61,6 +61,39 @@ function getInitials(email: string): string {
|
|||
return name.slice(0, 2).toUpperCase();
|
||||
}
|
||||
|
||||
/**
|
||||
* User avatar component - shows image if available, otherwise falls back to initials
|
||||
*/
|
||||
function UserAvatar({
|
||||
avatarUrl,
|
||||
initials,
|
||||
bgColor,
|
||||
}: {
|
||||
avatarUrl?: string;
|
||||
initials: string;
|
||||
bgColor: string;
|
||||
}) {
|
||||
if (avatarUrl) {
|
||||
return (
|
||||
<img
|
||||
src={avatarUrl}
|
||||
alt="User avatar"
|
||||
className="h-8 w-8 shrink-0 rounded-lg object-cover"
|
||||
referrerPolicy="no-referrer"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
|
||||
style={{ backgroundColor: bgColor }}
|
||||
>
|
||||
{initials}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function SidebarUserProfile({
|
||||
user,
|
||||
onUserSettings,
|
||||
|
|
@ -88,12 +121,7 @@ export function SidebarUserProfile({
|
|||
"focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className="flex h-8 w-8 items-center justify-center rounded-lg text-xs font-semibold text-white"
|
||||
style={{ backgroundColor: bgColor }}
|
||||
>
|
||||
{initials}
|
||||
</div>
|
||||
<UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
|
||||
<span className="sr-only">{displayName}</span>
|
||||
</button>
|
||||
</DropdownMenuTrigger>
|
||||
|
|
@ -104,12 +132,7 @@ export function SidebarUserProfile({
|
|||
<DropdownMenuContent className="w-56" side="right" align="end" sideOffset={8}>
|
||||
<DropdownMenuLabel className="font-normal">
|
||||
<div className="flex items-center gap-2">
|
||||
<div
|
||||
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
|
||||
style={{ backgroundColor: bgColor }}
|
||||
>
|
||||
{initials}
|
||||
</div>
|
||||
<UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="truncate text-sm font-medium">{displayName}</p>
|
||||
<p className="truncate text-xs text-muted-foreground">{user.email}</p>
|
||||
|
|
@ -149,13 +172,7 @@ export function SidebarUserProfile({
|
|||
"focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
|
||||
)}
|
||||
>
|
||||
{/* Avatar */}
|
||||
<div
|
||||
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
|
||||
style={{ backgroundColor: bgColor }}
|
||||
>
|
||||
{initials}
|
||||
</div>
|
||||
<UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
|
||||
|
||||
{/* Name and email */}
|
||||
<div className="flex-1 min-w-0">
|
||||
|
|
@ -171,12 +188,7 @@ export function SidebarUserProfile({
|
|||
<DropdownMenuContent className="w-56" side="top" align="start" sideOffset={4}>
|
||||
<DropdownMenuLabel className="font-normal">
|
||||
<div className="flex items-center gap-2">
|
||||
<div
|
||||
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
|
||||
style={{ backgroundColor: bgColor }}
|
||||
>
|
||||
{initials}
|
||||
</div>
|
||||
<UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="truncate text-sm font-medium">{displayName}</p>
|
||||
<p className="truncate text-xs text-muted-foreground">{user.email}</p>
|
||||
|
|
|
|||
|
|
@ -215,6 +215,16 @@ export const DocumentMentionPicker = forwardRef<
|
|||
isSurfsenseDocsLoading) &&
|
||||
currentPage === 0;
|
||||
|
||||
// Split documents into SurfSense docs and user docs for grouped rendering
|
||||
const surfsenseDocsList = useMemo(
|
||||
() => actualDocuments.filter((doc) => doc.document_type === "SURFSENSE_DOCS"),
|
||||
[actualDocuments]
|
||||
);
|
||||
const userDocsList = useMemo(
|
||||
() => actualDocuments.filter((doc) => doc.document_type !== "SURFSENSE_DOCS"),
|
||||
[actualDocuments]
|
||||
);
|
||||
|
||||
// Track already selected documents using unique key (document_type:id) to avoid ID collisions
|
||||
const selectedKeys = useMemo(
|
||||
() => new Set(initialSelectedDocuments.map((d) => `${d.document_type}:${d.id}`)),
|
||||
|
|
@ -324,47 +334,102 @@ export const DocumentMentionPicker = forwardRef<
|
|||
</div>
|
||||
) : (
|
||||
<div className="py-1">
|
||||
{actualDocuments.map((doc) => {
|
||||
const docKey = `${doc.document_type}:${doc.id}`;
|
||||
const isAlreadySelected = selectedKeys.has(docKey);
|
||||
const selectableIndex = selectableDocuments.findIndex(
|
||||
(d) => d.document_type === doc.document_type && d.id === doc.id
|
||||
);
|
||||
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex;
|
||||
{/* SurfSense Documentation Section */}
|
||||
{surfsenseDocsList.length > 0 && (
|
||||
<>
|
||||
<div className="sticky top-0 z-10 px-3 py-2 text-xs font-bold uppercase tracking-wider bg-muted text-foreground/80 border-b border-border">
|
||||
SurfSense Docs
|
||||
</div>
|
||||
{surfsenseDocsList.map((doc) => {
|
||||
const docKey = `${doc.document_type}:${doc.id}`;
|
||||
const isAlreadySelected = selectedKeys.has(docKey);
|
||||
const selectableIndex = selectableDocuments.findIndex(
|
||||
(d) => d.document_type === doc.document_type && d.id === doc.id
|
||||
);
|
||||
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={docKey}
|
||||
ref={(el) => {
|
||||
if (el && selectableIndex >= 0) {
|
||||
itemRefs.current.set(selectableIndex, el);
|
||||
}
|
||||
}}
|
||||
type="button"
|
||||
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
|
||||
onMouseEnter={() => {
|
||||
if (!isAlreadySelected && selectableIndex >= 0) {
|
||||
setHighlightedIndex(selectableIndex);
|
||||
}
|
||||
}}
|
||||
disabled={isAlreadySelected}
|
||||
className={cn(
|
||||
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
|
||||
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
|
||||
isHighlighted && "bg-accent"
|
||||
)}
|
||||
>
|
||||
<span className="shrink-0 text-muted-foreground text-sm">
|
||||
{getConnectorIcon(doc.document_type)}
|
||||
</span>
|
||||
<span className="flex-1 text-sm truncate" title={doc.title}>
|
||||
{doc.title}
|
||||
</span>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* User Documents Section */}
|
||||
{userDocsList.length > 0 && (
|
||||
<>
|
||||
<div className="sticky top-0 z-10 px-3 py-2 text-xs font-bold uppercase tracking-wider bg-muted text-foreground/80 border-b border-border">
|
||||
Your Documents
|
||||
</div>
|
||||
{userDocsList.map((doc) => {
|
||||
const docKey = `${doc.document_type}:${doc.id}`;
|
||||
const isAlreadySelected = selectedKeys.has(docKey);
|
||||
const selectableIndex = selectableDocuments.findIndex(
|
||||
(d) => d.document_type === doc.document_type && d.id === doc.id
|
||||
);
|
||||
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={docKey}
|
||||
ref={(el) => {
|
||||
if (el && selectableIndex >= 0) {
|
||||
itemRefs.current.set(selectableIndex, el);
|
||||
}
|
||||
}}
|
||||
type="button"
|
||||
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
|
||||
onMouseEnter={() => {
|
||||
if (!isAlreadySelected && selectableIndex >= 0) {
|
||||
setHighlightedIndex(selectableIndex);
|
||||
}
|
||||
}}
|
||||
disabled={isAlreadySelected}
|
||||
className={cn(
|
||||
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
|
||||
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
|
||||
isHighlighted && "bg-accent"
|
||||
)}
|
||||
>
|
||||
<span className="shrink-0 text-muted-foreground text-sm">
|
||||
{getConnectorIcon(doc.document_type)}
|
||||
</span>
|
||||
<span className="flex-1 text-sm truncate" title={doc.title}>
|
||||
{doc.title}
|
||||
</span>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
return (
|
||||
<button
|
||||
key={docKey}
|
||||
ref={(el) => {
|
||||
if (el && selectableIndex >= 0) {
|
||||
itemRefs.current.set(selectableIndex, el);
|
||||
}
|
||||
}}
|
||||
type="button"
|
||||
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
|
||||
onMouseEnter={() => {
|
||||
if (!isAlreadySelected && selectableIndex >= 0) {
|
||||
setHighlightedIndex(selectableIndex);
|
||||
}
|
||||
}}
|
||||
disabled={isAlreadySelected}
|
||||
className={cn(
|
||||
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
|
||||
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
|
||||
isHighlighted && "bg-accent"
|
||||
)}
|
||||
>
|
||||
{/* Type icon */}
|
||||
<span className="flex-shrink-0 text-muted-foreground text-sm">
|
||||
{getConnectorIcon(doc.document_type)}
|
||||
</span>
|
||||
{/* Title */}
|
||||
<span className="flex-1 text-sm truncate" title={doc.title}>
|
||||
{doc.title}
|
||||
</span>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
{/* Loading indicator for additional pages */}
|
||||
{isLoadingMore && (
|
||||
<div className="flex items-center justify-center py-2">
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ export const user = z.object({
|
|||
is_verified: z.boolean(),
|
||||
pages_limit: z.number(),
|
||||
pages_used: z.number(),
|
||||
display_name: z.string().nullish(),
|
||||
avatar_url: z.string().nullish(),
|
||||
});
|
||||
|
||||
/**
|
||||
|
|
@ -15,5 +17,20 @@ export const user = z.object({
|
|||
*/
|
||||
export const getMeResponse = user;
|
||||
|
||||
/**
|
||||
* Update current user request
|
||||
*/
|
||||
export const updateUserRequest = z.object({
|
||||
display_name: z.string().nullish(),
|
||||
avatar_url: z.string().nullish(),
|
||||
});
|
||||
|
||||
/**
|
||||
* Update current user response
|
||||
*/
|
||||
export const updateUserResponse = user;
|
||||
|
||||
export type User = z.infer<typeof user>;
|
||||
export type GetMeResponse = z.infer<typeof getMeResponse>;
|
||||
export type UpdateUserRequest = z.infer<typeof updateUserRequest>;
|
||||
export type UpdateUserResponse = z.infer<typeof updateUserResponse>;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
import { getMeResponse } from "@/contracts/types/user.types";
|
||||
import {
|
||||
getMeResponse,
|
||||
updateUserResponse,
|
||||
type UpdateUserRequest,
|
||||
} from "@/contracts/types/user.types";
|
||||
import { baseApiService } from "./base-api.service";
|
||||
|
||||
class UserApiService {
|
||||
|
|
@ -8,6 +12,15 @@ class UserApiService {
|
|||
getMe = async () => {
|
||||
return baseApiService.get(`/users/me`, getMeResponse);
|
||||
};
|
||||
|
||||
/**
|
||||
* Update current authenticated user
|
||||
*/
|
||||
updateMe = async (request: UpdateUserRequest) => {
|
||||
return baseApiService.patch(`/users/me`, updateUserResponse, {
|
||||
body: request,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export const userApiService = new UserApiService();
|
||||
|
|
|
|||
|
|
@ -31,6 +31,9 @@ export interface MessageRecord {
|
|||
role: "user" | "assistant" | "system";
|
||||
content: unknown;
|
||||
created_at: string;
|
||||
author_id?: string | null;
|
||||
author_display_name?: string | null;
|
||||
author_avatar_url?: string | null;
|
||||
}
|
||||
|
||||
export interface ThreadListResponse {
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
/**
|
||||
* Environment configuration for the frontend.
|
||||
*
|
||||
*
|
||||
* This file centralizes access to NEXT_PUBLIC_* environment variables.
|
||||
* For Docker deployments, these placeholders are replaced at container startup
|
||||
* via sed in the entrypoint script.
|
||||
*
|
||||
*
|
||||
* IMPORTANT: Do not use template literals or complex expressions with these values
|
||||
* as it may prevent the sed replacement from working correctly.
|
||||
*/
|
||||
|
|
@ -24,5 +24,5 @@ export const ETL_SERVICE = process.env.NEXT_PUBLIC_ETL_SERVICE || "DOCLING";
|
|||
// Helper to check if local auth is enabled
|
||||
export const isLocalAuth = () => AUTH_TYPE === "LOCAL";
|
||||
|
||||
// Helper to check if Google auth is enabled
|
||||
// Helper to check if Google auth is enabled
|
||||
export const isGoogleAuth = () => AUTH_TYPE === "GOOGLE";
|
||||
|
|
|
|||
|
|
@ -109,6 +109,17 @@
|
|||
"title": "User Settings",
|
||||
"description": "Manage your account settings and API access",
|
||||
"back_to_app": "Back to app",
|
||||
"profile_nav_label": "Profile",
|
||||
"profile_nav_description": "Manage your display name and avatar",
|
||||
"profile_title": "Profile",
|
||||
"profile_description": "Update your personal information",
|
||||
"profile_avatar": "Profile Picture",
|
||||
"profile_display_name": "Display Name",
|
||||
"profile_display_name_hint": "This is how your name appears across the app",
|
||||
"profile_email": "Email",
|
||||
"profile_save": "Save Changes",
|
||||
"profile_saved": "Profile updated successfully",
|
||||
"profile_save_error": "Failed to update profile",
|
||||
"api_key_nav_label": "API Key",
|
||||
"api_key_nav_description": "Manage your API access token",
|
||||
"api_key_title": "API Key",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue