Merge remote-tracking branch 'upstream/dev' into feat/replace-logs

This commit is contained in:
Anish Sarkar 2026-01-15 03:07:20 +05:30
commit 2e0f742000
47 changed files with 2365 additions and 700 deletions

View file

@ -0,0 +1,54 @@
"""Initial schema setup
Revision ID: 0
Revises: None
Creates all tables from SQLAlchemy models. Idempotent - safe to run on existing databases.
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
revision: str = "0"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
from app.db import Base
connection = op.get_bind()
# Create tables
op.execute(sa.text("CREATE EXTENSION IF NOT EXISTS vector"))
Base.metadata.create_all(bind=connection)
# Set up indexes
op.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)"
)
)
op.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))"
)
)
op.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)"
)
)
op.execute(
sa.text(
"CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))"
)
)
def downgrade() -> None:
pass

View file

@ -6,6 +6,8 @@ Revises: 9
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
@ -18,9 +20,37 @@ depends_on: str | Sequence[str] | None = None
CHAT_TYPE_ENUM = "chattype"
def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
),
{"enum_name": enum_name},
)
return result.scalar()
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None:
"""Upgrade schema - replace ChatType enum values with new QNA/REPORT structure."""
# Skip if chats table or chattype enum doesn't exist (fresh database)
if not table_exists("chats") or not enum_exists(CHAT_TYPE_ENUM):
return
# Old enum name for temporary storage
old_enum_name = f"{CHAT_TYPE_ENUM}_old"
@ -72,6 +102,10 @@ def upgrade() -> None:
def downgrade() -> None:
"""Downgrade schema - revert ChatType enum to old GENERAL/DEEP/DEEPER/DEEPEST structure."""
# Skip if chats table or chattype enum doesn't exist
if not table_exists("chats") or not enum_exists(CHAT_TYPE_ENUM):
return
# Old enum name for temporary storage
old_enum_name = f"{CHAT_TYPE_ENUM}_old"

View file

@ -7,22 +7,36 @@ Revises:
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# Import pgvector if needed for other types, though not for this ENUM change
# import pgvector
# revision identifiers, used by Alembic.
revision: str = "1"
down_revision: str | None = None
down_revision: str | None = "0"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
),
{"enum_name": enum_name},
)
return result.scalar()
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Skip if the enum doesn't exist (fresh DB after downgrade - create_db_and_tables will handle it)
if not enum_exists("searchsourceconnectortype"):
return
# Manually add the command to add the enum value
# Note: It's generally better to let autogenerate handle this, but we're bypassing it
op.execute(
@ -51,6 +65,10 @@ END$$;
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Skip if the enum doesn't exist
if not enum_exists("searchsourceconnectortype"):
return
# Downgrading removal of an enum value is complex and potentially dangerous
# if the value is in use. Often omitted or requires manual SQL based on context.
# For now, we'll just pass. If you needed to reverse this, you'd likely

View file

@ -7,6 +7,8 @@ Revises: 23
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
@ -16,11 +18,27 @@ branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None:
"""
Fix any chats with NULL type values by setting them to QNA.
This handles edge cases from previous migrations where type values were not properly migrated.
"""
# Skip if chats table doesn't exist (fresh database)
if not table_exists("chats"):
return
# Update any NULL type values to QNA (the default chat type)
op.execute(
"""

View file

@ -10,6 +10,8 @@ Revises: 33
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers
@ -19,42 +21,59 @@ branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None:
"""Add columns only if they don't already exist (safe for re-runs)."""
# Add 'state_version' column to chats table (default 1)
op.execute("""
ALTER TABLE chats
ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL
""")
# Skip if chats table doesn't exist (fresh database)
if table_exists("chats"):
op.execute("""
ALTER TABLE chats
ADD COLUMN IF NOT EXISTS state_version BIGINT DEFAULT 1 NOT NULL
""")
# Add 'chat_state_version' column to podcasts table
op.execute("""
ALTER TABLE podcasts
ADD COLUMN IF NOT EXISTS chat_state_version BIGINT
""")
if table_exists("podcasts"):
op.execute("""
ALTER TABLE podcasts
ADD COLUMN IF NOT EXISTS chat_state_version BIGINT
""")
# Add 'chat_id' column to podcasts table
op.execute("""
ALTER TABLE podcasts
ADD COLUMN IF NOT EXISTS chat_id INTEGER
""")
# Add 'chat_id' column to podcasts table
op.execute("""
ALTER TABLE podcasts
ADD COLUMN IF NOT EXISTS chat_id INTEGER
""")
def downgrade() -> None:
"""Remove columns only if they exist."""
op.execute("""
ALTER TABLE podcasts
DROP COLUMN IF EXISTS chat_state_version
""")
if table_exists("podcasts"):
op.execute("""
ALTER TABLE podcasts
DROP COLUMN IF EXISTS chat_state_version
""")
op.execute("""
ALTER TABLE podcasts
DROP COLUMN IF EXISTS chat_id
""")
op.execute("""
ALTER TABLE podcasts
DROP COLUMN IF EXISTS chat_id
""")
op.execute("""
ALTER TABLE chats
DROP COLUMN IF EXISTS state_version
""")
if table_exists("chats"):
op.execute("""
ALTER TABLE chats
DROP COLUMN IF EXISTS state_version
""")

View file

@ -62,8 +62,25 @@ def parse_timestamp(ts, fallback):
return fallback
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None:
"""Migrate old chats to new_chat_threads and remove old tables."""
# Skip if chats table doesn't exist (fresh database)
if not table_exists("chats"):
print("[Migration 49] Chats table does not exist, skipping migration")
return
connection = op.get_bind()
# Get all old chats
@ -176,36 +193,49 @@ def upgrade() -> None:
print("[Migration 49] Migration complete!")
def enum_exists(enum_name: str) -> bool:
"""Check if an enum type exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM pg_type WHERE typname = :enum_name)"
),
{"enum_name": enum_name},
)
return result.scalar()
def downgrade() -> None:
"""Recreate old chats table (data cannot be restored)."""
# Recreate chattype enum
# Skip if chats table already exists
if table_exists("chats"):
print("[Migration 49 Downgrade] Chats table already exists, skipping")
return
# Recreate chattype enum if it doesn't exist
if not enum_exists("chattype"):
op.execute(
sa.text("""
CREATE TYPE chattype AS ENUM ('QNA')
""")
)
# Recreate chats table using raw SQL to avoid SQLAlchemy trying to create the enum
op.execute(
sa.text("""
CREATE TYPE chattype AS ENUM ('QNA')
CREATE TABLE chats (
id SERIAL PRIMARY KEY,
type chattype NOT NULL,
title VARCHAR NOT NULL,
initial_connectors VARCHAR[],
messages JSON NOT NULL,
state_version BIGINT NOT NULL DEFAULT 1,
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
)
""")
)
# Recreate chats table
op.create_table(
"chats",
sa.Column("id", sa.Integer(), primary_key=True, index=True),
sa.Column("type", sa.Enum("QNA", name="chattype"), nullable=False),
sa.Column("title", sa.String(), nullable=False, index=True),
sa.Column("initial_connectors", sa.ARRAY(sa.String()), nullable=True),
sa.Column("messages", sa.JSON(), nullable=False),
sa.Column("state_version", sa.BigInteger(), nullable=False, default=1),
sa.Column(
"search_space_id",
sa.Integer(),
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.execute(sa.text("CREATE INDEX ix_chats_id ON chats (id)"))
op.execute(sa.text("CREATE INDEX ix_chats_title ON chats (title)"))
print("[Migration 49 Downgrade] Chats table recreated (data not restored)")

View file

@ -0,0 +1,50 @@
"""allow_multiple_connectors_with_unique_names
Revision ID: 5263aa4e7f94
Revises: a1b2c3d4e5f6
Create Date: 2026-01-13 12:23:31.481643
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '5263aa4e7f94'
down_revision: str | None = 'a1b2c3d4e5f6'
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Upgrade schema."""
# Drop the old unique constraint
op.drop_constraint(
'uq_searchspace_user_connector_type',
'search_source_connectors',
type_='unique'
)
# Create new unique constraint that includes name
op.create_unique_constraint(
'uq_searchspace_user_connector_type_name',
'search_source_connectors',
['search_space_id', 'user_id', 'connector_type', 'name']
)
def downgrade() -> None:
"""Downgrade schema."""
# Drop the new constraint
op.drop_constraint(
'uq_searchspace_user_connector_type_name',
'search_source_connectors',
type_='unique'
)
# Restore the old constraint
op.create_unique_constraint(
'uq_searchspace_user_connector_type',
'search_source_connectors',
['search_space_id', 'user_id', 'connector_type']
)

View file

@ -39,7 +39,7 @@ def upgrade():
"""
)
# Rename columns (only if they exist with old names)
# Rename columns (only if source exists and target doesn't already exist)
op.execute(
"""
DO $$
@ -47,6 +47,9 @@ def upgrade():
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id'
) THEN
ALTER TABLE searchspaces RENAME COLUMN fast_llm_id TO agent_llm_id;
END IF;
@ -61,6 +64,9 @@ def upgrade():
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id'
) THEN
ALTER TABLE searchspaces RENAME COLUMN long_context_llm_id TO document_summary_llm_id;
END IF;
@ -100,7 +106,7 @@ def downgrade():
"""
)
# Rename columns back
# Rename columns back (only if source exists and target doesn't already exist)
op.execute(
"""
DO $$
@ -108,6 +114,9 @@ def downgrade():
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id'
) THEN
ALTER TABLE searchspaces RENAME COLUMN agent_llm_id TO fast_llm_id;
END IF;
@ -122,6 +131,9 @@ def downgrade():
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id'
) THEN
ALTER TABLE searchspaces RENAME COLUMN document_summary_llm_id TO long_context_llm_id;
END IF;

View file

@ -60,14 +60,28 @@ def downgrade() -> None:
connection = op.get_bind()
connection.execute(
# Only update if the target enum value exists (it won't on fresh databases)
result = connection.execute(
text(
"""
UPDATE documents
SET document_type = 'GOOGLE_DRIVE_CONNECTOR'
WHERE document_type = 'GOOGLE_DRIVE_FILE';
SELECT EXISTS (
SELECT 1 FROM pg_type t
JOIN pg_enum e ON t.oid = e.enumtypid
WHERE t.typname = 'documenttype' AND e.enumlabel = 'GOOGLE_DRIVE_CONNECTOR'
);
"""
)
)
enum_exists = result.scalar()
connection.commit()
if enum_exists:
connection.execute(
text(
"""
UPDATE documents
SET document_type = 'GOOGLE_DRIVE_CONNECTOR'
WHERE document_type = 'GOOGLE_DRIVE_FILE';
"""
)
)
connection.commit()

View file

@ -18,59 +18,77 @@ branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Alter Chat table
op.alter_column(
"chats",
"title",
existing_type=sa.String(200),
type_=sa.String(),
existing_nullable=False,
def table_exists(table_name: str) -> bool:
"""Check if a table exists in the database."""
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = :table_name)"
),
{"table_name": table_name},
)
return result.scalar()
def upgrade() -> None:
# Alter Chat table (may not exist on fresh databases, removed in migration 49)
if table_exists("chats"):
op.alter_column(
"chats",
"title",
existing_type=sa.String(200),
type_=sa.String(),
existing_nullable=False,
)
# Alter Document table
op.alter_column(
"documents",
"title",
existing_type=sa.String(200),
type_=sa.String(),
existing_nullable=False,
)
if table_exists("documents"):
op.alter_column(
"documents",
"title",
existing_type=sa.String(200),
type_=sa.String(),
existing_nullable=False,
)
# Alter Podcast table
op.alter_column(
"podcasts",
"title",
existing_type=sa.String(200),
type_=sa.String(),
existing_nullable=False,
)
if table_exists("podcasts"):
op.alter_column(
"podcasts",
"title",
existing_type=sa.String(200),
type_=sa.String(),
existing_nullable=False,
)
def downgrade() -> None:
# Revert Chat table
op.alter_column(
"chats",
"title",
existing_type=sa.String(),
type_=sa.String(200),
existing_nullable=False,
)
if table_exists("chats"):
op.alter_column(
"chats",
"title",
existing_type=sa.String(),
type_=sa.String(200),
existing_nullable=False,
)
# Revert Document table
op.alter_column(
"documents",
"title",
existing_type=sa.String(),
type_=sa.String(200),
existing_nullable=False,
)
if table_exists("documents"):
op.alter_column(
"documents",
"title",
existing_type=sa.String(),
type_=sa.String(200),
existing_nullable=False,
)
# Revert Podcast table
op.alter_column(
"podcasts",
"title",
existing_type=sa.String(),
type_=sa.String(200),
existing_nullable=False,
)
if table_exists("podcasts"):
op.alter_column(
"podcasts",
"title",
existing_type=sa.String(),
type_=sa.String(200),
existing_nullable=False,
)

View file

@ -0,0 +1,72 @@
"""Add display_name and avatar_url columns to user table
This migration adds:
- display_name column for user's full name from OAuth
- avatar_url column for user's profile picture URL from OAuth
Revision ID: 62
Revises: 61
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "62"
down_revision: str | None = "61"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add display_name and avatar_url columns to user table."""
# Add display_name column (nullable for existing users)
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'user' AND column_name = 'display_name'
) THEN
ALTER TABLE "user"
ADD COLUMN display_name VARCHAR;
END IF;
END$$;
"""
)
# Add avatar_url column (nullable for existing users)
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'user' AND column_name = 'avatar_url'
) THEN
ALTER TABLE "user"
ADD COLUMN avatar_url VARCHAR;
END IF;
END$$;
"""
)
def downgrade() -> None:
"""Remove display_name and avatar_url columns from user table."""
op.execute(
"""
ALTER TABLE "user"
DROP COLUMN IF EXISTS avatar_url;
"""
)
op.execute(
"""
ALTER TABLE "user"
DROP COLUMN IF EXISTS display_name;
"""
)

View file

@ -0,0 +1,47 @@
"""Add author_id column to new_chat_messages table
Revision ID: 63
Revises: 62
"""
from collections.abc import Sequence
from alembic import op
revision: str = "63"
down_revision: str | None = "62"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add author_id column to new_chat_messages table."""
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'new_chat_messages' AND column_name = 'author_id'
) THEN
ALTER TABLE new_chat_messages
ADD COLUMN author_id UUID REFERENCES "user"(id) ON DELETE SET NULL;
CREATE INDEX IF NOT EXISTS ix_new_chat_messages_author_id
ON new_chat_messages(author_id);
END IF;
END$$;
"""
)
def downgrade() -> None:
"""Remove author_id column from new_chat_messages table."""
op.execute(
"""
DROP INDEX IF EXISTS ix_new_chat_messages_author_id;
ALTER TABLE new_chat_messages
DROP COLUMN IF EXISTS author_id;
"""
)

View file

@ -0,0 +1,37 @@
"""Add MCP connector type
Revision ID: a1b2c3d4e5f6
Revises: 61
Create Date: 2026-01-09 15:19:51.827647
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'a1b2c3d4e5f6'
down_revision: str | None = '61'
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add MCP_CONNECTOR to SearchSourceConnectorType enum."""
# Add new enum value using raw SQL
op.execute(
"""
ALTER TYPE searchsourceconnectortype ADD VALUE IF NOT EXISTS 'MCP_CONNECTOR';
"""
)
def downgrade() -> None:
"""Remove MCP_CONNECTOR from SearchSourceConnectorType enum."""
# Note: PostgreSQL does not support removing enum values directly.
# To downgrade, you would need to:
# 1. Create a new enum without MCP_CONNECTOR
# 2. Alter the column to use the new enum
# 3. Drop the old enum
# This is left as a manual operation if needed.
pass

View file

@ -20,7 +20,7 @@ from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt,
build_surfsense_system_prompt,
)
from app.agents.new_chat.tools import build_tools
from app.agents.new_chat.tools.registry import build_tools_async
from app.services.connector_service import ConnectorService
# =============================================================================
@ -28,7 +28,7 @@ from app.services.connector_service import ConnectorService
# =============================================================================
def create_surfsense_deep_agent(
async def create_surfsense_deep_agent(
llm: ChatLiteLLM,
search_space_id: int,
db_session: AsyncSession,
@ -120,8 +120,8 @@ def create_surfsense_deep_agent(
"firecrawl_api_key": firecrawl_api_key,
}
# Build tools using the registry
tools = build_tools(
# Build tools using the async registry (includes MCP tools)
tools = await build_tools_async(
dependencies=dependencies,
enabled_tools=enabled_tools,
disabled_tools=disabled_tools,

View file

@ -0,0 +1,188 @@
"""MCP Client Wrapper.
This module provides a client for communicating with MCP servers via stdio transport.
It handles server lifecycle management, tool discovery, and tool execution.
"""
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
logger = logging.getLogger(__name__)
class MCPClient:
"""Client for communicating with an MCP server."""
def __init__(self, command: str, args: list[str], env: dict[str, str] | None = None):
"""Initialize MCP client.
Args:
command: Command to spawn the MCP server (e.g., "uvx", "node")
args: Arguments for the command (e.g., ["mcp-server-git"])
env: Optional environment variables for the server process
"""
self.command = command
self.args = args
self.env = env or {}
self.session: ClientSession | None = None
@asynccontextmanager
async def connect(self):
"""Connect to the MCP server and manage its lifecycle.
Yields:
ClientSession: Active MCP session for making requests
"""
try:
# Merge env vars with current environment
server_env = os.environ.copy()
server_env.update(self.env)
# Create server parameters with env
server_params = StdioServerParameters(
command=self.command,
args=self.args,
env=server_env
)
# Spawn server process and create session
# Note: Cannot combine these context managers because ClientSession
# needs the read/write streams from stdio_client
async with stdio_client(server=server_params) as (read, write):
async with ClientSession(read, write) as session:
# Initialize the connection
await session.initialize()
self.session = session
logger.info(
"Connected to MCP server: %s %s",
self.command,
" ".join(self.args),
)
yield session
except Exception as e:
logger.error("Failed to connect to MCP server: %s", e, exc_info=True)
raise
finally:
self.session = None
logger.info("Disconnected from MCP server: %s", self.command)
async def list_tools(self) -> list[dict[str, Any]]:
"""List all tools available from the MCP server.
Returns:
List of tool definitions with name, description, and input schema
Raises:
RuntimeError: If not connected to server
"""
if not self.session:
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
try:
# Call tools/list RPC method
response = await self.session.list_tools()
tools = []
for tool in response.tools:
tools.append({
"name": tool.name,
"description": tool.description or "",
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
})
logger.info("Listed %d tools from MCP server", len(tools))
return tools
except Exception as e:
logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
raise
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call a tool on the MCP server.
Args:
tool_name: Name of the tool to call
arguments: Arguments to pass to the tool
Returns:
Tool execution result
Raises:
RuntimeError: If not connected to server
"""
if not self.session:
raise RuntimeError("Not connected to MCP server. Use 'async with client.connect():'")
try:
logger.info("Calling MCP tool '%s' with arguments: %s", tool_name, arguments)
# Call tools/call RPC method
response = await self.session.call_tool(tool_name, arguments=arguments)
# Extract content from response
result = []
for content in response.content:
if hasattr(content, "text"):
result.append(content.text)
elif hasattr(content, "data"):
result.append(str(content.data))
else:
result.append(str(content))
result_str = "\n".join(result) if result else ""
logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
return result_str
except RuntimeError as e:
# Handle validation errors from MCP server responses
# Some MCP servers (like server-memory) return extra fields not in their schema
if "Invalid structured content" in str(e):
logger.warning("MCP server returned data not matching its schema, but continuing: %s", e)
# Try to extract result from error message or return a success message
return "Operation completed (server returned unexpected format)"
raise
except (ValueError, TypeError, AttributeError, KeyError) as e:
logger.error("Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True)
return f"Error calling tool: {e!s}"
async def test_mcp_connection(
command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
"""Test connection to an MCP server and fetch available tools.
Args:
command: Command to spawn the MCP server
args: Arguments for the command
env: Optional environment variables
Returns:
Dict with connection status and available tools
"""
client = MCPClient(command, args, env)
try:
async with client.connect():
tools = await client.list_tools()
return {
"status": "success",
"message": f"Connected successfully. Found {len(tools)} tools.",
"tools": tools,
}
except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
return {
"status": "error",
"message": f"Failed to connect: {e!s}",
"tools": [],
}

View file

@ -0,0 +1,189 @@
"""MCP Tool Factory.
This module creates LangChain tools from MCP servers using the Model Context Protocol.
Tools are dynamically discovered from MCP servers - no manual configuration needed.
This implements real MCP protocol support similar to Cursor's implementation.
"""
import logging
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, create_model
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
def _create_dynamic_input_model_from_schema(
tool_name: str, input_schema: dict[str, Any],
) -> type[BaseModel]:
"""Create a Pydantic model from MCP tool's JSON schema.
Args:
tool_name: Name of the tool (used for model class name)
input_schema: JSON schema from MCP server
Returns:
Pydantic model class for tool input validation
"""
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
# Build Pydantic field definitions
field_definitions = {}
for param_name, param_schema in properties.items():
param_description = param_schema.get("description", "")
is_required = param_name in required_fields
# Use Any type for complex schemas to preserve structure
# This allows the MCP server to do its own validation
from typing import Any as AnyType
from pydantic import Field
if is_required:
field_definitions[param_name] = (AnyType, Field(..., description=param_description))
else:
field_definitions[param_name] = (
AnyType | None,
Field(None, description=param_description),
)
# Create dynamic model
model_name = f"{tool_name.replace(' ', '').replace('-', '_')}Input"
return create_model(model_name, **field_definitions)
async def _create_mcp_tool_from_definition(
tool_def: dict[str, Any],
mcp_client: MCPClient,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition.
Args:
tool_def: Tool definition from MCP server with name, description, input_schema
mcp_client: MCP client instance for calling the tool
Returns:
LangChain StructuredTool instance
"""
tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
# Log the actual schema for debugging
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
# Create dynamic input model from schema
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
async def mcp_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via the client."""
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
try:
# Connect to server and call tool
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, kwargs)
return str(result)
except Exception as e:
error_msg = f"MCP tool '{tool_name}' failed: {e!s}"
logger.exception(error_msg)
return f"Error: {error_msg}"
# Create StructuredTool with response_format to preserve exact schema
tool = StructuredTool(
name=tool_name,
description=tool_description,
coroutine=mcp_tool_call,
args_schema=input_model,
# Store the original MCP schema as metadata so we can access it later
metadata={"mcp_input_schema": input_schema},
)
logger.info(f"Created MCP tool: '{tool_name}'")
return tool
async def load_mcp_tools(
session: AsyncSession, search_space_id: int,
) -> list[StructuredTool]:
"""Load all MCP tools from user's active MCP server connectors.
This discovers tools dynamically from MCP servers using the protocol.
Args:
session: Database session
search_space_id: User's search space ID
Returns:
List of LangChain StructuredTool instances
"""
try:
# Fetch all MCP connectors for this search space
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.search_space_id == search_space_id,
),
)
tools: list[StructuredTool] = []
for connector in result.scalars():
try:
# Extract server config
config = connector.config or {}
server_config = config.get("server_config", {})
command = server_config.get("command")
args = server_config.get("args", [])
env = server_config.get("env", {})
if not command:
logger.warning(f"MCP connector {connector.id} missing command, skipping")
continue
# Create MCP client
mcp_client = MCPClient(command, args, env)
# Connect and discover tools
async with mcp_client.connect():
tool_definitions = await mcp_client.list_tools()
logger.info(
f"Discovered {len(tool_definitions)} tools from MCP server "
f"'{command}' (connector {connector.id})"
)
# Create LangChain tools from definitions
for tool_def in tool_definitions:
try:
tool = await _create_mcp_tool_from_definition(tool_def, mcp_client)
tools.append(tool)
except Exception as e:
logger.exception(
f"Failed to create tool '{tool_def.get('name')}' "
f"from connector {connector.id}: {e!s}",
)
except Exception as e:
logger.exception(
f"Failed to load tools from MCP connector {connector.id}: {e!s}",
)
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
return tools
except Exception as e:
logger.exception(f"Failed to load MCP tools: {e!s}")
return []

View file

@ -1,5 +1,4 @@
"""
Tools registry for SurfSense deep agent.
"""Tools registry for SurfSense deep agent.
This module provides a registry pattern for managing tools in the SurfSense agent.
It makes it easy for OSS contributors to add new tools by:
@ -37,6 +36,7 @@ Example of adding a new tool:
),
"""
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
@ -46,6 +46,7 @@ from langchain_core.tools import BaseTool
from .display_image import create_display_image_tool
from .knowledge_base import create_search_knowledge_base_tool
from .link_preview import create_link_preview_tool
from .mcp_tool import load_mcp_tools
from .podcast import create_generate_podcast_tool
from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
@ -57,8 +58,7 @@ from .search_surfsense_docs import create_search_surfsense_docs_tool
@dataclass
class ToolDefinition:
"""
Definition of a tool that can be added to the agent.
"""Definition of a tool that can be added to the agent.
Attributes:
name: Unique identifier for the tool
@ -66,6 +66,7 @@ class ToolDefinition:
factory: Callable that creates the tool. Receives a dict of dependencies.
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
enabled_by_default: Whether the tool is enabled when no explicit config is provided
"""
name: str
@ -178,8 +179,7 @@ def build_tools(
disabled_tools: list[str] | None = None,
additional_tools: list[BaseTool] | None = None,
) -> list[BaseTool]:
"""
Build the list of tools for the agent.
"""Build the list of tools for the agent.
Args:
dependencies: Dict containing all possible dependencies:
@ -206,6 +206,7 @@ def build_tools(
# Add custom tools
tools = build_tools(deps, additional_tools=[my_custom_tool])
"""
# Determine which tools to enable
if enabled_tools is not None:
@ -226,8 +227,9 @@ def build_tools(
# Check that all required dependencies are provided
missing_deps = [dep for dep in tool_def.requires if dep not in dependencies]
if missing_deps:
msg = f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
raise ValueError(
f"Tool '{tool_def.name}' requires dependencies: {missing_deps}"
msg,
)
# Create the tool
@ -239,3 +241,61 @@ def build_tools(
tools.extend(additional_tools)
return tools
async def build_tools_async(
dependencies: dict[str, Any],
enabled_tools: list[str] | None = None,
disabled_tools: list[str] | None = None,
additional_tools: list[BaseTool] | None = None,
include_mcp_tools: bool = True,
) -> list[BaseTool]:
"""Async version of build_tools that also loads MCP tools from database.
Design Note:
This function exists because MCP tools require database queries to load user configs,
while built-in tools are created synchronously from static code.
Alternative: We could make build_tools() itself async and always query the database,
but that would force async everywhere even when only using built-in tools. The current
design keeps the simple case (static tools only) synchronous while supporting dynamic
database-loaded tools through this async wrapper.
Args:
dependencies: Dict containing all possible dependencies
enabled_tools: Explicit list of tool names to enable. If None, uses defaults.
disabled_tools: List of tool names to disable (applied after enabled_tools).
additional_tools: Extra tools to add (e.g., custom tools not in registry).
include_mcp_tools: Whether to load user's MCP tools from database.
Returns:
List of configured tool instances ready for the agent, including MCP tools.
"""
# Build standard tools
tools = build_tools(dependencies, enabled_tools, disabled_tools, additional_tools)
# Load MCP tools if requested and dependencies are available
if (
include_mcp_tools
and "db_session" in dependencies
and "search_space_id" in dependencies
):
try:
mcp_tools = await load_mcp_tools(
dependencies["db_session"], dependencies["search_space_id"],
)
tools.extend(mcp_tools)
logging.info(
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
)
except Exception as e:
# Log error but don't fail - just continue without MCP tools
logging.exception(f"Failed to load MCP tools: {e!s}")
# Log all tools being returned to agent
logging.info(
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
)
return tools

View file

@ -80,6 +80,7 @@ class SearchSourceConnectorType(str, Enum):
WEBCRAWLER_CONNECTOR = "WEBCRAWLER_CONNECTOR"
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR"
MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools
class LiteLLMProvider(str, Enum):
@ -412,8 +413,17 @@ class NewChatMessage(BaseModel, TimestampMixin):
index=True,
)
# Relationship
# Track who sent this message (for shared chats)
author_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Relationships
thread = relationship("NewChatThread", back_populates="messages")
author = relationship("User")
class Document(BaseModel, TimestampMixin):
@ -611,7 +621,8 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
"search_space_id",
"user_id",
"connector_type",
name="uq_searchspace_user_connector_type",
"name",
name="uq_searchspace_user_connector_type_name",
),
)
@ -919,6 +930,10 @@ if config.AUTH_TYPE == "GOOGLE":
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
# User profile from OAuth
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
else:
class User(SQLAlchemyBaseUserTableUUID, Base):
@ -958,6 +973,10 @@ else:
)
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
# User profile (can be set manually for non-OAuth users)
display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True)
engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -411,11 +411,9 @@ async def get_thread_messages(
Requires CHATS_READ permission.
"""
try:
# Get thread with messages
# Get thread first
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(NewChatThread.id == thread_id)
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
@ -434,6 +432,15 @@ async def get_thread_messages(
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
# Get messages with their authors loaded
messages_result = await session.execute(
select(NewChatMessage)
.options(selectinload(NewChatMessage.author))
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at)
)
db_messages = messages_result.scalars().all()
# Return messages in the format expected by assistant-ui
messages = [
NewChatMessageRead(
@ -442,8 +449,11 @@ async def get_thread_messages(
role=msg.role,
content=msg.content,
created_at=msg.created_at,
author_id=msg.author_id,
author_display_name=msg.author.display_name if msg.author else None,
author_avatar_url=msg.author.avatar_url if msg.author else None,
)
for msg in thread.messages
for msg in db_messages
]
return ThreadHistoryLoadResponse(messages=messages)
@ -782,6 +792,7 @@ async def append_message(
thread_id=thread_id,
role=message_role,
content=message.content,
author_id=user.id,
)
session.add(db_message)

View file

@ -7,6 +7,13 @@ PUT /search-source-connectors/{connector_id} - Update a specific connector
DELETE /search-source-connectors/{connector_id} - Delete a specific connector
POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space
MCP (Model Context Protocol) Connector routes:
POST /connectors/mcp - Create a new MCP connector with custom API tools
GET /connectors/mcp - List all MCP connectors for the current user's search space
GET /connectors/mcp/{connector_id} - Get a specific MCP connector with tools config
PUT /connectors/mcp/{connector_id} - Update an MCP connector's tools config
DELETE /connectors/mcp/{connector_id} - Delete an MCP connector
Note: OAuth connectors (Gmail, Drive, Slack, etc.) support multiple accounts per search space.
Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search space.
"""
@ -32,6 +39,9 @@ from app.db import (
)
from app.schemas import (
GoogleDriveIndexRequest,
MCPConnectorCreate,
MCPConnectorRead,
MCPConnectorUpdate,
SearchSourceConnectorBase,
SearchSourceConnectorCreate,
SearchSourceConnectorRead,
@ -128,18 +138,20 @@ async def create_search_source_connector(
# Check if a connector with the same type already exists for this search space
# (for non-OAuth connectors that don't support multiple accounts)
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.connector_type == connector.connector_type,
)
)
existing_connector = result.scalars().first()
if existing_connector:
raise HTTPException(
status_code=409,
detail=f"A connector with type {connector.connector_type} already exists in this search space.",
# Exception: MCP_CONNECTOR can have multiple instances with different names
if connector.connector_type != SearchSourceConnectorType.MCP_CONNECTOR:
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.connector_type == connector.connector_type,
)
)
existing_connector = result.scalars().first()
if existing_connector:
raise HTTPException(
status_code=409,
detail=f"A connector with type {connector.connector_type} already exists in this search space.",
)
# Prepare connector data
connector_data = connector.model_dump()
@ -2054,3 +2066,348 @@ async def run_bookstack_indexing(
indexing_function=index_bookstack_pages,
update_timestamp_func=_update_connector_timestamp_by_id,
)
# =============================================================================
# MCP Connector Routes
# =============================================================================
@router.post("/connectors/mcp", response_model=MCPConnectorRead, status_code=201)
async def create_mcp_connector(
connector_data: MCPConnectorCreate,
search_space_id: int = Query(..., description="Search space ID"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Create a new MCP (Model Context Protocol) connector.
MCP connectors allow users to connect to MCP servers (like in Cursor).
Tools are auto-discovered from the server - no manual configuration needed.
Args:
connector_data: MCP server configuration (command, args, env)
search_space_id: ID of the search space to attach the connector to
session: Database session
user: Current authenticated user
Returns:
Created MCP connector with server configuration
Raises:
HTTPException: If search space not found or permission denied
"""
try:
# Check user has permission to create connectors
await check_permission(
session,
user,
search_space_id,
Permission.CONNECTORS_CREATE.value,
"You don't have permission to create connectors in this search space",
)
# Create the connector with server config
db_connector = SearchSourceConnector(
name=connector_data.name,
connector_type=SearchSourceConnectorType.MCP_CONNECTOR,
is_indexable=False, # MCP connectors are not indexable
config={"server_config": connector_data.server_config.model_dump()},
periodic_indexing_enabled=False,
indexing_frequency_minutes=None,
search_space_id=search_space_id,
user_id=user.id,
)
session.add(db_connector)
await session.commit()
await session.refresh(db_connector)
logger.info(
f"Created MCP connector {db_connector.id} for server '{connector_data.server_config.command}' "
f"for user {user.id} in search space {search_space_id}"
)
# Convert to read schema
connector_read = SearchSourceConnectorRead.model_validate(db_connector)
return MCPConnectorRead.from_connector(connector_read)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to create MCP connector: {e!s}", exc_info=True)
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to create MCP connector: {e!s}"
) from e
@router.get("/connectors/mcp", response_model=list[MCPConnectorRead])
async def list_mcp_connectors(
search_space_id: int = Query(..., description="Search space ID"),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
List all MCP connectors for a search space.
Args:
search_space_id: ID of the search space
session: Database session
user: Current authenticated user
Returns:
List of MCP connectors with their tool configurations
"""
try:
# Check user has permission to read connectors
await check_permission(
session,
user,
search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to view connectors in this search space",
)
# Fetch MCP connectors
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.search_space_id == search_space_id,
)
)
connectors = result.scalars().all()
return [
MCPConnectorRead.from_connector(SearchSourceConnectorRead.model_validate(c))
for c in connectors
]
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to list MCP connectors: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to list MCP connectors: {e!s}"
) from e
@router.get("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
async def get_mcp_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Get a specific MCP connector by ID.
Args:
connector_id: ID of the connector
session: Database session
user: Current authenticated user
Returns:
MCP connector with tool configurations
"""
try:
# Fetch connector
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(status_code=404, detail="MCP connector not found")
# Check user has permission to read connectors
await check_permission(
session,
user,
connector.search_space_id,
Permission.CONNECTORS_READ.value,
"You don't have permission to view this connector",
)
connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get MCP connector: {e!s}", exc_info=True)
raise HTTPException(
status_code=500, detail=f"Failed to get MCP connector: {e!s}"
) from e
@router.put("/connectors/mcp/{connector_id}", response_model=MCPConnectorRead)
async def update_mcp_connector(
connector_id: int,
connector_update: MCPConnectorUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Update an MCP connector.
Args:
connector_id: ID of the connector to update
connector_update: Updated connector data
session: Database session
user: Current authenticated user
Returns:
Updated MCP connector
"""
try:
# Fetch connector
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(status_code=404, detail="MCP connector not found")
# Check user has permission to update connectors
await check_permission(
session,
user,
connector.search_space_id,
Permission.CONNECTORS_UPDATE.value,
"You don't have permission to update this connector",
)
# Update fields
if connector_update.name is not None:
connector.name = connector_update.name
if connector_update.server_config is not None:
connector.config = {
"server_config": connector_update.server_config.model_dump()
}
connector.updated_at = datetime.now(UTC)
await session.commit()
await session.refresh(connector)
logger.info(f"Updated MCP connector {connector_id}")
connector_read = SearchSourceConnectorRead.model_validate(connector)
return MCPConnectorRead.from_connector(connector_read)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to update MCP connector: {e!s}", exc_info=True)
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to update MCP connector: {e!s}"
) from e
@router.delete("/connectors/mcp/{connector_id}", status_code=204)
async def delete_mcp_connector(
connector_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Delete an MCP connector.
Args:
connector_id: ID of the connector to delete
session: Database session
user: Current authenticated user
"""
try:
# Fetch connector
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
)
)
connector = result.scalars().first()
if not connector:
raise HTTPException(status_code=404, detail="MCP connector not found")
# Check user has permission to delete connectors
await check_permission(
session,
user,
connector.search_space_id,
Permission.CONNECTORS_DELETE.value,
"You don't have permission to delete this connector",
)
await session.delete(connector)
await session.commit()
logger.info(f"Deleted MCP connector {connector_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete MCP connector: {e!s}", exc_info=True)
await session.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to delete MCP connector: {e!s}"
) from e
@router.post("/connectors/mcp/test")
async def test_mcp_server_connection(
server_config: dict = Body(...),
user: User = Depends(current_active_user),
):
"""
Test connection to an MCP server and fetch available tools.
This endpoint allows users to test their MCP server configuration
before saving it, similar to Cursor's flow.
Args:
server_config: Server configuration with command, args, env
user: Current authenticated user
Returns:
Connection status and list of available tools
"""
try:
from app.agents.new_chat.tools.mcp_client import test_mcp_connection
command = server_config.get("command")
args = server_config.get("args", [])
env = server_config.get("env", {})
if not command:
raise HTTPException(status_code=400, detail="Server command is required")
# Test the connection
result = await test_mcp_connection(command, args, env)
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to test MCP connection: {e!s}", exc_info=True)
return {
"status": "error",
"message": f"Failed to test connection: {e!s}",
"tools": [],
}

View file

@ -55,6 +55,10 @@ from .rbac_schemas import (
UserSearchSpaceAccess,
)
from .search_source_connector import (
MCPConnectorCreate,
MCPConnectorRead,
MCPConnectorUpdate,
MCPServerConfig,
SearchSourceConnectorBase,
SearchSourceConnectorCreate,
SearchSourceConnectorRead,
@ -108,6 +112,11 @@ __all__ = [
"LogFilter",
"LogRead",
"LogUpdate",
# Search source connector schemas
"MCPConnectorCreate",
"MCPConnectorRead",
"MCPConnectorUpdate",
"MCPServerConfig",
"MembershipRead",
"MembershipReadWithUser",
"MembershipUpdate",
@ -135,7 +144,6 @@ __all__ = [
"RoleCreate",
"RoleRead",
"RoleUpdate",
# Search source connector schemas
"SearchSourceConnectorBase",
"SearchSourceConnectorCreate",
"SearchSourceConnectorRead",

View file

@ -38,6 +38,9 @@ class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
"""Schema for reading a message."""
thread_id: int
author_id: UUID | None = None
author_display_name: str | None = None
author_avatar_url: str | None = None
model_config = ConfigDict(from_attributes=True)

View file

@ -23,7 +23,7 @@ class SearchSourceConnectorBase(BaseModel):
@field_validator("config")
@classmethod
def validate_config_for_connector_type(
cls, config: dict[str, Any], values: dict[str, Any]
cls, config: dict[str, Any], values: dict[str, Any],
) -> dict[str, Any]:
connector_type = values.data.get("connector_type")
return validate_connector_config(connector_type, config)
@ -38,15 +38,18 @@ class SearchSourceConnectorBase(BaseModel):
"""
if self.periodic_indexing_enabled:
if not self.is_indexable:
msg = "periodic_indexing_enabled can only be True for indexable connectors"
raise ValueError(
"periodic_indexing_enabled can only be True for indexable connectors"
msg,
)
if self.indexing_frequency_minutes is None:
msg = "indexing_frequency_minutes is required when periodic_indexing_enabled is True"
raise ValueError(
"indexing_frequency_minutes is required when periodic_indexing_enabled is True"
msg,
)
if self.indexing_frequency_minutes <= 0:
raise ValueError("indexing_frequency_minutes must be greater than 0")
msg = "indexing_frequency_minutes must be greater than 0"
raise ValueError(msg)
return self
@ -70,3 +73,63 @@ class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampMod
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
# =============================================================================
# MCP-specific schemas
# =============================================================================
class MCPServerConfig(BaseModel):
"""Configuration for an MCP server connection (similar to Cursor's config)."""
command: str # e.g., "uvx", "node", "python"
args: list[str] = [] # e.g., ["mcp-server-git", "--repository", "/path"]
env: dict[str, str] = {} # Environment variables for the server process
transport: str = "stdio" # "stdio" | "sse" | "http" (stdio is most common)
class MCPConnectorCreate(BaseModel):
"""Schema for creating an MCP connector."""
name: str
server_config: MCPServerConfig
class MCPConnectorUpdate(BaseModel):
"""Schema for updating an MCP connector."""
name: str | None = None
server_config: MCPServerConfig | None = None
class MCPConnectorRead(BaseModel):
"""Schema for reading an MCP connector with server config."""
id: int
name: str
connector_type: SearchSourceConnectorType
server_config: MCPServerConfig
search_space_id: int
user_id: uuid.UUID
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
@classmethod
def from_connector(cls, connector: SearchSourceConnectorRead) -> "MCPConnectorRead":
"""Convert from base SearchSourceConnectorRead."""
config = connector.config or {}
server_config = MCPServerConfig(**config.get("server_config", {}))
return cls(
id=connector.id,
name=connector.name,
connector_type=connector.connector_type,
server_config=server_config,
search_space_id=connector.search_space_id,
user_id=connector.user_id,
created_at=connector.created_at,
updated_at=connector.updated_at,
)

View file

@ -6,6 +6,8 @@ from fastapi_users import schemas
class UserRead(schemas.BaseUser[uuid.UUID]):
pages_limit: int
pages_used: int
display_name: str | None = None
avatar_url: str | None = None
class UserCreate(schemas.BaseUserCreate):
@ -13,4 +15,5 @@ class UserCreate(schemas.BaseUserCreate):
class UserUpdate(schemas.BaseUserUpdate):
pass
display_name: str | None = None
avatar_url: str | None = None

View file

@ -237,7 +237,7 @@ async def stream_new_chat(
checkpointer = await get_checkpointer()
# Create the deep agent with checkpointer and configurable prompts
agent = create_surfsense_deep_agent(
agent = await create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,

View file

@ -1,6 +1,7 @@
import logging
import uuid
import httpx
from fastapi import Depends, Request, Response
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
@ -46,6 +47,71 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = SECRET
verification_token_secret = SECRET
async def oauth_callback(
self,
oauth_name: str,
access_token: str,
account_id: str,
account_email: str,
expires_at: int | None = None,
refresh_token: str | None = None,
request: Request | None = None,
*,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
"""
Override OAuth callback to capture Google profile data (name, avatar).
"""
# Call parent implementation to create/get user
user = await super().oauth_callback(
oauth_name,
access_token,
account_id,
account_email,
expires_at,
refresh_token,
request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
# Fetch and store Google profile data if not already set
if oauth_name == "google" and (not user.display_name or not user.avatar_url):
try:
async with httpx.AsyncClient() as client:
response = await client.get(
"https://people.googleapis.com/v1/people/me",
params={"personFields": "names,photos"},
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
profile = response.json()
update_dict = {}
# Extract name from names array
names = profile.get("names", [])
if not user.display_name and names:
display_name = names[0].get("displayName")
if display_name:
update_dict["display_name"] = display_name
# Extract photo URL from photos array
photos = profile.get("photos", [])
if not user.avatar_url and photos:
photo_url = photos[0].get("url")
if photo_url:
update_dict["avatar_url"] = photo_url
if update_dict:
user = await self.user_db.update(user, update_dict)
except Exception as e:
logger.warning(f"Failed to fetch Google profile: {e}")
return user
async def on_after_register(self, user: User, request: Request | None = None):
"""
Called after a user registers. Creates a default search space for the user

View file

@ -8,9 +8,9 @@ from typing import Any
from urllib.parse import urlparse
from uuid import UUID
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.sql import func
from app.db import SearchSourceConnector, SearchSourceConnectorType
@ -27,6 +27,7 @@ BASE_NAME_FOR_TYPE = {
SearchSourceConnectorType.DISCORD_CONNECTOR: "Discord",
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "Confluence",
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "Airtable",
SearchSourceConnectorType.MCP_CONNECTOR: "Model Context Protocol (MCP)",
}
@ -75,7 +76,7 @@ def extract_identifier_from_credentials(
if ".atlassian.net" in hostname:
return hostname.replace(".atlassian.net", "")
return hostname
except Exception:
except (ValueError, TypeError, AttributeError):
pass
return None

View file

@ -57,6 +57,9 @@ dependencies = [
"chonkie[all]>=1.5.0",
"langgraph-checkpoint-postgres>=3.0.2",
"psycopg[binary,pool]>=3.3.2",
"mcp>=1.25.0",
"starlette>=0.40.0,<0.51.0",
"sse-starlette>=3.1.1,<3.1.2",
]
[dependency-groups]

View file

@ -79,25 +79,17 @@ export function DocumentsTableShell({
[documents, sortKey, sortDesc]
);
// Filter out SURFSENSE_DOCS for selection purposes
const selectableDocs = React.useMemo(
() => sorted.filter((d) => d.document_type !== "SURFSENSE_DOCS"),
[sorted]
);
const allSelectedOnPage =
selectableDocs.length > 0 && selectableDocs.every((d) => selectedIds.has(d.id));
const someSelectedOnPage =
selectableDocs.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage;
const allSelectedOnPage = sorted.length > 0 && sorted.every((d) => selectedIds.has(d.id));
const someSelectedOnPage = sorted.some((d) => selectedIds.has(d.id)) && !allSelectedOnPage;
const toggleAll = (checked: boolean) => {
const next = new Set(selectedIds);
if (checked)
selectableDocs.forEach((d) => {
sorted.forEach((d) => {
next.add(d.id);
});
else
selectableDocs.forEach((d) => {
sorted.forEach((d) => {
next.delete(d.id);
});
setSelectedIds(next);
@ -238,10 +230,9 @@ export function DocumentsTableShell({
const icon = getDocumentTypeIcon(doc.document_type);
const title = doc.title;
const truncatedTitle = title.length > 30 ? `${title.slice(0, 30)}...` : title;
const isSurfsenseDoc = doc.document_type === "SURFSENSE_DOCS";
return (
<motion.tr
key={`${doc.document_type}-${doc.id}`}
key={doc.id}
initial={{ opacity: 0, y: 10 }}
animate={{
opacity: 1,
@ -258,9 +249,8 @@ export function DocumentsTableShell({
>
<TableCell className="px-4 py-3">
<Checkbox
checked={selectedIds.has(doc.id) && !isSurfsenseDoc}
onCheckedChange={(v) => !isSurfsenseDoc && toggleOne(doc.id, !!v)}
disabled={isSurfsenseDoc}
checked={selectedIds.has(doc.id)}
onCheckedChange={(v) => toggleOne(doc.id, !!v)}
aria-label="Select row"
/>
</TableCell>

View file

@ -18,7 +18,7 @@ import { cacheKeys } from "@/lib/query-client/cache-keys";
import { DocumentsFilters } from "./components/DocumentsFilters";
import { DocumentsTableShell, type SortKey } from "./components/DocumentsTableShell";
import { PaginationControls } from "./components/PaginationControls";
import type { ColumnVisibility, Document } from "./components/types";
import type { ColumnVisibility } from "./components/types";
function useDebounced<T>(value: T, delay = 250) {
const [debounced, setDebounced] = useState(value);
@ -58,39 +58,30 @@ export default function DocumentsTable() {
const { data: rawTypeCounts } = useAtomValue(documentTypeCountsAtom);
const { mutateAsync: deleteDocumentMutation } = useAtomValue(deleteDocumentMutationAtom);
// Filter out SURFSENSE_DOCS from active types for regular documents API
const regularDocumentTypes = useMemo(
() => activeTypes.filter((t) => t !== "SURFSENSE_DOCS"),
[activeTypes]
);
// Check if only SURFSENSE_DOCS is selected (skip regular docs query)
const onlySurfsenseDocsSelected = activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS";
// Build query parameters for fetching documents (excluding SURFSENSE_DOCS type)
// Build query parameters for fetching documents
const queryParams = useMemo(
() => ({
search_space_id: searchSpaceId,
page: pageIndex,
page_size: pageSize,
...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }),
...(activeTypes.length > 0 && { document_types: activeTypes }),
}),
[searchSpaceId, pageIndex, pageSize, regularDocumentTypes]
[searchSpaceId, pageIndex, pageSize, activeTypes]
);
// Build search query parameters (excluding SURFSENSE_DOCS type)
// Build search query parameters
const searchQueryParams = useMemo(
() => ({
search_space_id: searchSpaceId,
page: pageIndex,
page_size: pageSize,
title: debouncedSearch.trim(),
...(regularDocumentTypes.length > 0 && { document_types: regularDocumentTypes }),
...(activeTypes.length > 0 && { document_types: activeTypes }),
}),
[searchSpaceId, pageIndex, pageSize, regularDocumentTypes, debouncedSearch]
[searchSpaceId, pageIndex, pageSize, activeTypes, debouncedSearch]
);
// Use query for fetching documents (disabled when only SURFSENSE_DOCS is selected)
// Use query for fetching documents
const {
data: documentsResponse,
isLoading: isDocumentsLoading,
@ -100,10 +91,10 @@ export default function DocumentsTable() {
queryKey: cacheKeys.documents.globalQueryParams(queryParams),
queryFn: () => documentsApiService.getDocuments({ queryParams }),
staleTime: 3 * 60 * 1000, // 3 minutes
enabled: !!searchSpaceId && !debouncedSearch.trim() && !onlySurfsenseDocsSelected,
enabled: !!searchSpaceId && !debouncedSearch.trim(),
});
// Use query for searching documents (disabled when only SURFSENSE_DOCS is selected)
// Use query for searching documents
const {
data: searchResponse,
isLoading: isSearchLoading,
@ -113,7 +104,7 @@ export default function DocumentsTable() {
queryKey: cacheKeys.documents.globalQueryParams(searchQueryParams),
queryFn: () => documentsApiService.searchDocuments({ queryParams: searchQueryParams }),
staleTime: 3 * 60 * 1000, // 3 minutes
enabled: !!searchSpaceId && !!debouncedSearch.trim() && !onlySurfsenseDocsSelected,
enabled: !!searchSpaceId && !!debouncedSearch.trim(),
});
// Determine if we should show SurfSense docs (when no type filter or SURFSENSE_DOCS is selected)
@ -163,64 +154,16 @@ export default function DocumentsTable() {
}, [rawTypeCounts, surfsenseDocsResponse?.total]);
// Extract documents and total based on search state
const regularDocuments = debouncedSearch.trim()
const documents = debouncedSearch.trim()
? searchResponse?.items || []
: documentsResponse?.items || [];
const regularTotal = debouncedSearch.trim()
? searchResponse?.total || 0
: documentsResponse?.total || 0;
const total = debouncedSearch.trim() ? searchResponse?.total || 0 : documentsResponse?.total || 0;
// Merge regular documents with SurfSense docs
const documents = useMemo(() => {
// If filtering by type and not including SURFSENSE_DOCS, only show regular docs
if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) {
return regularDocuments;
}
// If filtering only by SURFSENSE_DOCS, only show surfsense docs
if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") {
return surfsenseDocsAsDocuments;
}
// Otherwise, merge both (surfsense docs first)
return [...surfsenseDocsAsDocuments, ...regularDocuments];
}, [regularDocuments, surfsenseDocsAsDocuments, activeTypes]);
const loading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading;
const error = debouncedSearch.trim() ? searchError : documentsError;
const total = useMemo(() => {
if (activeTypes.length > 0 && !activeTypes.includes("SURFSENSE_DOCS" as DocumentTypeEnum)) {
return regularTotal;
}
if (activeTypes.length === 1 && activeTypes[0] === "SURFSENSE_DOCS") {
return surfsenseDocsResponse?.total || 0;
}
return regularTotal + (surfsenseDocsResponse?.total || 0);
}, [regularTotal, surfsenseDocsResponse?.total, activeTypes]);
const loading = useMemo(() => {
// If only SURFSENSE_DOCS selected, only check surfsense loading
if (onlySurfsenseDocsSelected) {
return isSurfsenseDocsLoading;
}
// Otherwise check both regular docs and surfsense docs loading
const regularLoading = debouncedSearch.trim() ? isSearchLoading : isDocumentsLoading;
return regularLoading || (showSurfsenseDocs && isSurfsenseDocsLoading);
}, [
onlySurfsenseDocsSelected,
isSurfsenseDocsLoading,
debouncedSearch,
isSearchLoading,
isDocumentsLoading,
showSurfsenseDocs,
]);
const error = useMemo(() => {
// If only SURFSENSE_DOCS selected, no regular docs errors
if (onlySurfsenseDocsSelected) {
return null;
}
return debouncedSearch.trim() ? searchError : documentsError;
}, [onlySurfsenseDocsSelected, debouncedSearch, searchError, documentsError]);
// Display server-filtered results directly
const displayDocs = documents || [];
// Display results directly
const displayDocs = documents;
const displayTotal = total;
const pageStart = pageIndex * pageSize;
const pageEnd = Math.min(pageStart + pageSize, displayTotal);
@ -240,33 +183,16 @@ export default function DocumentsTable() {
if (isRefreshing) return;
setIsRefreshing(true);
try {
const refetchPromises: Promise<unknown>[] = [];
// Only refetch regular documents if not in "only surfsense docs" mode
if (!onlySurfsenseDocsSelected) {
if (debouncedSearch.trim()) {
refetchPromises.push(refetchSearch());
} else {
refetchPromises.push(refetchDocuments());
}
if (debouncedSearch.trim()) {
await refetchSearch();
} else {
await refetchDocuments();
}
if (showSurfsenseDocs) {
refetchPromises.push(refetchSurfsenseDocs());
}
await Promise.all(refetchPromises);
toast.success(t("refresh_success") || "Documents refreshed");
} finally {
setIsRefreshing(false);
}
}, [
debouncedSearch,
refetchSearch,
refetchDocuments,
refetchSurfsenseDocs,
showSurfsenseDocs,
onlySurfsenseDocsSelected,
t,
isRefreshing,
]);
}, [debouncedSearch, refetchSearch, refetchDocuments, t, isRefreshing]);
// Create a delete function for single document deletion
const deleteDocument = useCallback(
@ -357,7 +283,7 @@ export default function DocumentsTable() {
</motion.div>
<DocumentsFilters
typeCounts={typeCounts ?? {}}
typeCounts={rawTypeCounts ?? {}}
selectedIds={selectedIds}
onSearch={setSearch}
searchValue={search}

View file

@ -23,6 +23,7 @@ import {
// extractWriteTodosFromContent,
hydratePlanStateAtom,
} from "@/atoms/chat/plan-state.atom";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { Thread } from "@/components/assistant-ui/thread";
import { ChatHeader } from "@/components/new-chat/chat-header";
import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
@ -185,12 +186,25 @@ function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
}
}
// Build metadata.custom for author display in shared chats
const metadata = msg.author_id
? {
custom: {
author: {
displayName: msg.author_display_name ?? null,
avatarUrl: msg.author_avatar_url ?? null,
},
},
}
: undefined;
return {
id: `msg-${msg.id}`,
role: msg.role,
content,
createdAt: new Date(msg.created_at),
attachments,
metadata,
};
}
@ -238,6 +252,9 @@ export default function NewChatPage() {
const setMessageDocumentsMap = useSetAtom(messageDocumentsMapAtom);
const hydratePlanState = useSetAtom(hydratePlanStateAtom);
// Get current user for author info in shared chats
const { data: currentUser } = useAtomValue(currentUserAtom);
// Create the attachment adapter for file processing
const attachmentAdapter = useMemo(() => createAttachmentAdapter(), []);
@ -306,12 +323,6 @@ export default function NewChatPage() {
if (steps.length > 0) {
restoredThinkingSteps.set(`msg-${msg.id}`, steps);
}
// Hydrate write_todos plan state from persisted tool calls
// Disabled for now
// const writeTodosCalls = extractWriteTodosFromContent(msg.content);
// for (const todoData of writeTodosCalls) {
// hydratePlanState(todoData);
// }
}
if (msg.role === "user") {
const docs = extractMentionedDocuments(msg.content);
@ -448,13 +459,27 @@ export default function NewChatPage() {
// Add user message to state
const userMsgId = `msg-user-${Date.now()}`;
// Include author metadata for shared chats
const authorMetadata =
currentThread?.visibility === "SEARCH_SPACE" && currentUser
? {
custom: {
author: {
displayName: currentUser.display_name ?? null,
avatarUrl: currentUser.avatar_url ?? null,
},
},
}
: undefined;
const userMessage: ThreadMessageLike = {
id: userMsgId,
role: "user",
content: message.content,
createdAt: new Date(),
// Include attachments so they can be displayed
attachments: message.attachments || [],
metadata: authorMetadata,
};
setMessages((prev) => [...prev, userMessage]);
@ -884,6 +909,8 @@ export default function NewChatPage() {
setMentionedDocuments,
setMessageDocumentsMap,
queryClient,
currentThread,
currentUser,
]
);

View file

@ -0,0 +1,123 @@
"use client";
import { Check, Copy, Key, Menu, Shield } from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useTranslations } from "next-intl";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import { Button } from "@/components/ui/button";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import { useApiKey } from "@/hooks/use-api-key";
interface ApiKeyContentProps {
onMenuClick: () => void;
}
export function ApiKeyContent({ onMenuClick }: ApiKeyContentProps) {
const t = useTranslations("userSettings");
const { apiKey, isLoading, copied, copyToClipboard } = useApiKey();
return (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.2, duration: 0.4 }}
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
>
<div className="h-full overflow-y-auto">
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
<AnimatePresence mode="wait">
<motion.div
key="api-key-header"
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
transition={{ duration: 0.3 }}
className="mb-6 md:mb-8"
>
<div className="flex items-center gap-3 md:gap-4">
<Button
variant="outline"
size="icon"
onClick={onMenuClick}
className="h-10 w-10 shrink-0 md:hidden"
>
<Menu className="h-5 w-5" />
</Button>
<motion.div
initial={{ scale: 0.8, opacity: 0 }}
animate={{ scale: 1, opacity: 1 }}
transition={{ delay: 0.1, duration: 0.3 }}
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
>
<Key className="h-5 w-5 text-primary md:h-7 md:w-7" />
</motion.div>
<div className="min-w-0">
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
{t("api_key_title")}
</h1>
<p className="text-sm text-muted-foreground">{t("api_key_description")}</p>
</div>
</div>
</motion.div>
</AnimatePresence>
<AnimatePresence mode="wait">
<motion.div
key="api-key-content"
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -20 }}
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
className="space-y-6"
>
<Alert>
<Shield className="h-4 w-4" />
<AlertTitle>{t("api_key_warning_title")}</AlertTitle>
<AlertDescription>{t("api_key_warning_description")}</AlertDescription>
</Alert>
<div className="rounded-lg border bg-card p-6">
<h3 className="mb-4 font-medium">{t("your_api_key")}</h3>
{isLoading ? (
<div className="h-12 w-full animate-pulse rounded-md bg-muted" />
) : apiKey ? (
<div className="flex items-center gap-2">
<div className="flex-1 overflow-x-auto rounded-md bg-muted p-3 font-mono text-sm">
{apiKey}
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
onClick={copyToClipboard}
className="shrink-0"
>
{copied ? <Check className="h-4 w-4" /> : <Copy className="h-4 w-4" />}
</Button>
</TooltipTrigger>
<TooltipContent>{copied ? t("copied") : t("copy")}</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
) : (
<p className="text-center text-muted-foreground">{t("no_api_key")}</p>
)}
</div>
<div className="rounded-lg border bg-card p-6">
<h3 className="mb-2 font-medium">{t("usage_title")}</h3>
<p className="mb-4 text-sm text-muted-foreground">{t("usage_description")}</p>
<pre className="overflow-x-auto rounded-md bg-muted p-3 text-sm">
<code>Authorization: Bearer {apiKey || "YOUR_API_KEY"}</code>
</pre>
</div>
</motion.div>
</AnimatePresence>
</div>
</div>
</motion.div>
);
}

View file

@ -0,0 +1,181 @@
"use client";
import { useAtomValue } from "jotai";
import { Loader2, Menu, User } from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useTranslations } from "next-intl";
import { useEffect, useState } from "react";
import { toast } from "sonner";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import { updateUserMutationAtom } from "@/atoms/user/user-mutation.atoms";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
interface ProfileContentProps {
onMenuClick: () => void;
}
function AvatarDisplay({ url, fallback }: { url?: string; fallback: string }) {
const [hasError, setHasError] = useState(false);
useEffect(() => {
setHasError(false);
}, [url]);
if (url && !hasError) {
return (
<img
src={url}
alt="Avatar"
className="h-16 w-16 rounded-xl object-cover"
onError={() => setHasError(true)}
/>
);
}
return (
<div className="flex h-16 w-16 items-center justify-center rounded-xl bg-muted text-xl font-semibold text-muted-foreground">
{fallback}
</div>
);
}
export function ProfileContent({ onMenuClick }: ProfileContentProps) {
const t = useTranslations("userSettings");
const { data: user, isLoading: isUserLoading } = useAtomValue(currentUserAtom);
const { mutateAsync: updateUser, isPending } = useAtomValue(updateUserMutationAtom);
const [displayName, setDisplayName] = useState("");
useEffect(() => {
if (user) {
setDisplayName(user.display_name || "");
}
}, [user]);
const getInitials = (email: string) => {
const name = email.split("@")[0];
return name.slice(0, 2).toUpperCase();
};
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
try {
await updateUser({
display_name: displayName || null,
});
toast.success(t("profile_saved"));
} catch {
toast.error(t("profile_save_error"));
}
};
const hasChanges = displayName !== (user?.display_name || "");
return (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.2, duration: 0.4 }}
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
>
<div className="h-full overflow-y-auto">
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
<AnimatePresence mode="wait">
<motion.div
key="profile-header"
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
transition={{ duration: 0.3 }}
className="mb-6 md:mb-8"
>
<div className="flex items-center gap-3 md:gap-4">
<Button
variant="outline"
size="icon"
onClick={onMenuClick}
className="h-10 w-10 shrink-0 md:hidden"
>
<Menu className="h-5 w-5" />
</Button>
<motion.div
initial={{ scale: 0.8, opacity: 0 }}
animate={{ scale: 1, opacity: 1 }}
transition={{ delay: 0.1, duration: 0.3 }}
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
>
<User className="h-5 w-5 text-primary md:h-7 md:w-7" />
</motion.div>
<div className="min-w-0">
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
{t("profile_title")}
</h1>
<p className="text-sm text-muted-foreground">{t("profile_description")}</p>
</div>
</div>
</motion.div>
</AnimatePresence>
<AnimatePresence mode="wait">
<motion.div
key="profile-content"
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -20 }}
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
>
{isUserLoading ? (
<div className="flex items-center justify-center py-12">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : (
<form onSubmit={handleSubmit} className="space-y-6">
<div className="rounded-lg border bg-card p-6">
<div className="flex flex-col gap-6">
<div className="space-y-2">
<Label>{t("profile_avatar")}</Label>
<AvatarDisplay
url={user?.avatar_url || undefined}
fallback={getInitials(user?.email || "")}
/>
</div>
<div className="space-y-2">
<Label htmlFor="display-name">{t("profile_display_name")}</Label>
<Input
id="display-name"
type="text"
placeholder={user?.email?.split("@")[0]}
value={displayName}
onChange={(e) => setDisplayName(e.target.value)}
/>
<p className="text-xs text-muted-foreground">
{t("profile_display_name_hint")}
</p>
</div>
<div className="space-y-2">
<Label>{t("profile_email")}</Label>
<Input type="email" value={user?.email || ""} disabled />
</div>
</div>
</div>
<div className="flex justify-end">
<Button type="submit" disabled={isPending || !hasChanges}>
{isPending && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
{t("profile_save")}
</Button>
</div>
</form>
)}
</motion.div>
</AnimatePresence>
</div>
</div>
</motion.div>
);
}

View file

@ -0,0 +1,155 @@
"use client";
import { ArrowLeft, ChevronRight, X } from "lucide-react";
import type { LucideIcon } from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useTranslations } from "next-intl";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
export interface SettingsNavItem {
id: string;
label: string;
description: string;
icon: LucideIcon;
}
interface UserSettingsSidebarProps {
activeSection: string;
onSectionChange: (section: string) => void;
onBackToApp: () => void;
isOpen: boolean;
onClose: () => void;
navItems: SettingsNavItem[];
}
export function UserSettingsSidebar({
activeSection,
onSectionChange,
onBackToApp,
isOpen,
onClose,
navItems,
}: UserSettingsSidebarProps) {
const t = useTranslations("userSettings");
const handleNavClick = (sectionId: string) => {
onSectionChange(sectionId);
onClose();
};
return (
<>
<AnimatePresence>
{isOpen && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 z-40 bg-background/80 backdrop-blur-sm md:hidden"
onClick={onClose}
/>
)}
</AnimatePresence>
<aside
className={cn(
"fixed left-0 top-0 z-50 md:relative md:z-auto",
"flex h-full w-72 shrink-0 flex-col bg-background md:bg-muted/30",
"md:border-r",
"transition-transform duration-300 ease-out",
"md:translate-x-0",
isOpen ? "translate-x-0" : "-translate-x-full md:translate-x-0"
)}
>
{/* Header with title */}
<div className="space-y-3 p-4">
<div className="flex items-center justify-between">
<Button
variant="ghost"
onClick={onBackToApp}
className="group h-11 justify-start gap-3 px-3 hover:bg-muted"
>
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 transition-colors group-hover:bg-primary/20">
<ArrowLeft className="h-4 w-4 text-primary" />
</div>
<span className="font-medium">{t("back_to_app")}</span>
</Button>
<Button variant="ghost" size="icon" onClick={onClose} className="h-9 w-9 md:hidden">
<X className="h-5 w-5" />
</Button>
</div>
{/* Settings Title */}
<div className="px-3">
<h2 className="text-lg font-semibold text-foreground">{t("title")}</h2>
</div>
</div>
<nav className="flex-1 space-y-1 overflow-y-auto px-3 py-2">
{navItems.map((item, index) => {
const isActive = activeSection === item.id;
const Icon = item.icon;
return (
<motion.button
key={item.id}
initial={{ opacity: 0, x: -10 }}
animate={{ opacity: 1, x: 0 }}
transition={{ delay: 0.1 + index * 0.05, duration: 0.3 }}
onClick={() => handleNavClick(item.id)}
whileHover={{ scale: 1.01 }}
whileTap={{ scale: 0.99 }}
className={cn(
"relative flex w-full items-center gap-3 rounded-xl px-3 py-3 text-left transition-all duration-200",
isActive ? "border border-border bg-muted shadow-sm" : "hover:bg-muted/60"
)}
>
{isActive && (
<motion.div
layoutId="userSettingsActiveIndicator"
className="absolute left-0 top-1/2 h-8 w-1 -translate-y-1/2 rounded-r-full bg-primary"
initial={false}
transition={{
type: "spring",
stiffness: 500,
damping: 35,
}}
/>
)}
<div
className={cn(
"flex h-9 w-9 items-center justify-center rounded-lg transition-colors",
isActive ? "bg-primary/10 text-primary" : "bg-muted text-muted-foreground"
)}
>
<Icon className="h-4 w-4" />
</div>
<div className="min-w-0 flex-1">
<p
className={cn(
"truncate text-sm font-medium transition-colors",
isActive ? "text-foreground" : "text-muted-foreground"
)}
>
{item.label}
</p>
<p className="truncate text-xs text-muted-foreground/70">{item.description}</p>
</div>
<ChevronRight
className={cn(
"h-4 w-4 shrink-0 transition-all",
isActive
? "translate-x-0 text-primary opacity-100"
: "-translate-x-1 text-muted-foreground/40 opacity-0"
)}
/>
</motion.button>
);
})}
</nav>
</aside>
</>
);
}

View file

@ -1,286 +1,27 @@
"use client";
import {
ArrowLeft,
Check,
ChevronRight,
Copy,
Key,
type LucideIcon,
Menu,
Shield,
X,
} from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { Key, User } from "lucide-react";
import { motion } from "motion/react";
import { useRouter } from "next/navigation";
import { useTranslations } from "next-intl";
import { useCallback, useState } from "react";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import { Button } from "@/components/ui/button";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import { useApiKey } from "@/hooks/use-api-key";
import { cn } from "@/lib/utils";
interface SettingsNavItem {
id: string;
label: string;
description: string;
icon: LucideIcon;
}
function UserSettingsSidebar({
activeSection,
onSectionChange,
onBackToApp,
isOpen,
onClose,
navItems,
}: {
activeSection: string;
onSectionChange: (section: string) => void;
onBackToApp: () => void;
isOpen: boolean;
onClose: () => void;
navItems: SettingsNavItem[];
}) {
const t = useTranslations("userSettings");
const handleNavClick = (sectionId: string) => {
onSectionChange(sectionId);
onClose();
};
return (
<>
<AnimatePresence>
{isOpen && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.2 }}
className="fixed inset-0 z-40 bg-background/80 backdrop-blur-sm md:hidden"
onClick={onClose}
/>
)}
</AnimatePresence>
<aside
className={cn(
"fixed left-0 top-0 z-50 md:relative md:z-auto",
"flex h-full w-72 shrink-0 flex-col bg-background md:bg-muted/30",
"md:border-r",
"transition-transform duration-300 ease-out",
"md:translate-x-0",
isOpen ? "translate-x-0" : "-translate-x-full md:translate-x-0"
)}
>
{/* Header with title */}
<div className="space-y-3 p-4">
<div className="flex items-center justify-between">
<Button
variant="ghost"
onClick={onBackToApp}
className="group h-11 justify-start gap-3 px-3 hover:bg-muted"
>
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 transition-colors group-hover:bg-primary/20">
<ArrowLeft className="h-4 w-4 text-primary" />
</div>
<span className="font-medium">{t("back_to_app")}</span>
</Button>
<Button variant="ghost" size="icon" onClick={onClose} className="h-9 w-9 md:hidden">
<X className="h-5 w-5" />
</Button>
</div>
{/* Settings Title */}
<div className="px-3">
<h2 className="text-lg font-semibold text-foreground">{t("title")}</h2>
</div>
</div>
<nav className="flex-1 space-y-1 overflow-y-auto px-3 py-2">
{navItems.map((item, index) => {
const isActive = activeSection === item.id;
const Icon = item.icon;
return (
<motion.button
key={item.id}
initial={{ opacity: 0, x: -10 }}
animate={{ opacity: 1, x: 0 }}
transition={{ delay: 0.1 + index * 0.05, duration: 0.3 }}
onClick={() => handleNavClick(item.id)}
whileHover={{ scale: 1.01 }}
whileTap={{ scale: 0.99 }}
className={cn(
"relative flex w-full items-center gap-3 rounded-xl px-3 py-3 text-left transition-all duration-200",
isActive ? "border border-border bg-muted shadow-sm" : "hover:bg-muted/60"
)}
>
{isActive && (
<motion.div
layoutId="userSettingsActiveIndicator"
className="absolute left-0 top-1/2 h-8 w-1 -translate-y-1/2 rounded-r-full bg-primary"
initial={false}
transition={{
type: "spring",
stiffness: 500,
damping: 35,
}}
/>
)}
<div
className={cn(
"flex h-9 w-9 items-center justify-center rounded-lg transition-colors",
isActive ? "bg-primary/10 text-primary" : "bg-muted text-muted-foreground"
)}
>
<Icon className="h-4 w-4" />
</div>
<div className="min-w-0 flex-1">
<p
className={cn(
"truncate text-sm font-medium transition-colors",
isActive ? "text-foreground" : "text-muted-foreground"
)}
>
{item.label}
</p>
<p className="truncate text-xs text-muted-foreground/70">{item.description}</p>
</div>
<ChevronRight
className={cn(
"h-4 w-4 shrink-0 transition-all",
isActive
? "translate-x-0 text-primary opacity-100"
: "-translate-x-1 text-muted-foreground/40 opacity-0"
)}
/>
</motion.button>
);
})}
</nav>
</aside>
</>
);
}
function ApiKeyContent({ onMenuClick }: { onMenuClick: () => void }) {
const t = useTranslations("userSettings");
const { apiKey, isLoading, copied, copyToClipboard } = useApiKey();
return (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.2, duration: 0.4 }}
className="h-full min-w-0 flex-1 overflow-hidden bg-background"
>
<div className="h-full overflow-y-auto">
<div className="mx-auto max-w-4xl p-4 md:p-6 lg:p-10">
<AnimatePresence mode="wait">
<motion.div
key="api-key-header"
initial={{ opacity: 0, y: 10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
transition={{ duration: 0.3 }}
className="mb-6 md:mb-8"
>
<div className="flex items-center gap-3 md:gap-4">
<Button
variant="outline"
size="icon"
onClick={onMenuClick}
className="h-10 w-10 shrink-0 md:hidden"
>
<Menu className="h-5 w-5" />
</Button>
<motion.div
initial={{ scale: 0.8, opacity: 0 }}
animate={{ scale: 1, opacity: 1 }}
transition={{ delay: 0.1, duration: 0.3 }}
className="flex h-10 w-10 shrink-0 items-center justify-center rounded-lg border border-primary/10 bg-gradient-to-br from-primary/20 to-primary/5 shadow-sm md:h-14 md:w-14 md:rounded-2xl"
>
<Key className="h-5 w-5 text-primary md:h-7 md:w-7" />
</motion.div>
<div className="min-w-0">
<h1 className="truncate text-lg font-bold tracking-tight md:text-2xl">
{t("api_key_title")}
</h1>
<p className="text-sm text-muted-foreground">{t("api_key_description")}</p>
</div>
</div>
</motion.div>
</AnimatePresence>
<AnimatePresence mode="wait">
<motion.div
key="api-key-content"
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -20 }}
transition={{ duration: 0.35, ease: [0.4, 0, 0.2, 1] }}
className="space-y-6"
>
<Alert>
<Shield className="h-4 w-4" />
<AlertTitle>{t("api_key_warning_title")}</AlertTitle>
<AlertDescription>{t("api_key_warning_description")}</AlertDescription>
</Alert>
<div className="rounded-lg border bg-card p-6">
<h3 className="mb-4 font-medium">{t("your_api_key")}</h3>
{isLoading ? (
<div className="h-12 w-full animate-pulse rounded-md bg-muted" />
) : apiKey ? (
<div className="flex items-center gap-2">
<div className="flex-1 overflow-x-auto rounded-md bg-muted p-3 font-mono text-sm">
{apiKey}
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
onClick={copyToClipboard}
className="shrink-0"
>
{copied ? <Check className="h-4 w-4" /> : <Copy className="h-4 w-4" />}
</Button>
</TooltipTrigger>
<TooltipContent>{copied ? t("copied") : t("copy")}</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
) : (
<p className="text-center text-muted-foreground">{t("no_api_key")}</p>
)}
</div>
<div className="rounded-lg border bg-card p-6">
<h3 className="mb-2 font-medium">{t("usage_title")}</h3>
<p className="mb-4 text-sm text-muted-foreground">{t("usage_description")}</p>
<pre className="overflow-x-auto rounded-md bg-muted p-3 text-sm">
<code>Authorization: Bearer {apiKey || "YOUR_API_KEY"}</code>
</pre>
</div>
</motion.div>
</AnimatePresence>
</div>
</div>
</motion.div>
);
}
import { ApiKeyContent } from "./components/ApiKeyContent";
import { ProfileContent } from "./components/ProfileContent";
import { UserSettingsSidebar, type SettingsNavItem } from "./components/UserSettingsSidebar";
export default function UserSettingsPage() {
const t = useTranslations("userSettings");
const router = useRouter();
const [activeSection, setActiveSection] = useState("api-key");
const [activeSection, setActiveSection] = useState("profile");
const [isSidebarOpen, setIsSidebarOpen] = useState(false);
const navItems: SettingsNavItem[] = [
{
id: "profile",
label: t("profile_nav_label"),
description: t("profile_nav_description"),
icon: User,
},
{
id: "api-key",
label: t("api_key_nav_label"),
@ -310,6 +51,9 @@ export default function UserSettingsPage() {
onClose={() => setIsSidebarOpen(false)}
navItems={navItems}
/>
{activeSection === "profile" && (
<ProfileContent onMenuClick={() => setIsSidebarOpen(true)} />
)}
{activeSection === "api-key" && (
<ApiKeyContent onMenuClick={() => setIsSidebarOpen(true)} />
)}

View file

@ -0,0 +1,19 @@
import { atomWithMutation, queryClientAtom } from "jotai-tanstack-query";
import type { UpdateUserRequest } from "@/contracts/types/user.types";
import { userApiService } from "@/lib/apis/user-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
export const updateUserMutationAtom = atomWithMutation((get) => {
const queryClient = get(queryClientAtom);
return {
mutationKey: cacheKeys.user.current(),
mutationFn: async (request: UpdateUserRequest) => {
return userApiService.updateMe(request);
},
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: cacheKeys.user.current() });
},
};
});

View file

@ -19,9 +19,7 @@ import {
ChevronRightIcon,
CopyIcon,
DownloadIcon,
FileText,
Loader2,
PencilIcon,
RefreshCwIcon,
SquareIcon,
} from "lucide-react";
@ -31,7 +29,6 @@ import { createPortal } from "react-dom";
import {
mentionedDocumentIdsAtom,
mentionedDocumentsAtom,
messageDocumentsMapAtom,
} from "@/atoms/chat/mentioned-documents.atom";
import {
globalNewLLMConfigsAtom,
@ -42,8 +39,8 @@ import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import {
ComposerAddAttachment,
ComposerAttachments,
UserMessageAttachments,
} from "@/components/assistant-ui/attachment";
import { UserMessage } from "@/components/assistant-ui/user-message";
import { ConnectorIndicator } from "@/components/assistant-ui/connector-popup";
import {
InlineMentionEditor,
@ -639,69 +636,6 @@ const AssistantActionBar: FC = () => {
);
};
const UserMessage: FC = () => {
const messageId = useAssistantState(({ message }) => message?.id);
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
const hasAttachments = useAssistantState(
({ message }) => message?.attachments && message.attachments.length > 0
);
return (
<MessagePrimitive.Root
className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2"
data-role="user"
>
<div className="aui-user-message-content-wrapper col-start-2 min-w-0">
{/* Display attachments and mentioned documents */}
{(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && (
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
{/* Attachments (images show as thumbnails, documents as chips) */}
<UserMessageAttachments />
{/* Mentioned documents as chips */}
{mentionedDocs?.map((doc) => (
<span
key={`${doc.document_type}:${doc.id}`}
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
title={doc.title}
>
<FileText className="size-3" />
<span className="max-w-[150px] truncate">{doc.title}</span>
</span>
))}
</div>
)}
{/* Message bubble with action bar positioned relative to it */}
<div className="relative">
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
<MessagePrimitive.Parts />
</div>
<div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1">
<UserActionBar />
</div>
</div>
</div>
<BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" />
</MessagePrimitive.Root>
);
};
const UserActionBar: FC = () => {
return (
<ActionBarPrimitive.Root
hideWhenRunning
autohide="not-last"
className="aui-user-action-bar-root flex flex-col items-end"
>
<ActionBarPrimitive.Edit asChild>
<TooltipIconButton tooltip="Edit" className="aui-user-action-edit p-4">
<PencilIcon />
</TooltipIconButton>
</ActionBarPrimitive.Edit>
</ActionBarPrimitive.Root>
);
};
const EditComposer: FC = () => {
return (

View file

@ -1,16 +1,54 @@
import { ActionBarPrimitive, MessagePrimitive, useAssistantState } from "@assistant-ui/react";
import { useAtomValue } from "jotai";
import { FileText, PencilIcon } from "lucide-react";
import type { FC } from "react";
import { type FC, useState } from "react";
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
import { UserMessageAttachments } from "@/components/assistant-ui/attachment";
import { BranchPicker } from "@/components/assistant-ui/branch-picker";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
interface AuthorMetadata {
displayName: string | null;
avatarUrl: string | null;
}
const UserAvatar: FC<AuthorMetadata> = ({ displayName, avatarUrl }) => {
const [hasError, setHasError] = useState(false);
const initials = displayName
? displayName
.split(" ")
.map((n) => n[0])
.join("")
.toUpperCase()
.slice(0, 2)
: "U";
if (avatarUrl && !hasError) {
return (
<img
src={avatarUrl}
alt={displayName || "User"}
className="size-8 rounded-full object-cover"
referrerPolicy="no-referrer"
onError={() => setHasError(true)}
/>
);
}
return (
<div className="flex size-8 items-center justify-center rounded-full bg-primary/10 text-xs font-medium text-primary">
{initials}
</div>
);
};
export const UserMessage: FC = () => {
const messageId = useAssistantState(({ message }) => message?.id);
const messageDocumentsMap = useAtomValue(messageDocumentsMapAtom);
const mentionedDocs = messageId ? messageDocumentsMap[messageId] : undefined;
const metadata = useAssistantState(({ message }) => message?.metadata);
const author = metadata?.custom?.author as AuthorMetadata | undefined;
const hasAttachments = useAssistantState(
({ message }) => message?.attachments && message.attachments.length > 0
);
@ -20,34 +58,42 @@ export const UserMessage: FC = () => {
className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2"
data-role="user"
>
<div className="aui-user-message-content-wrapper col-start-2 min-w-0">
{/* Display attachments and mentioned documents */}
{(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && (
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
{/* Attachments (images show as thumbnails, documents as chips) */}
<UserMessageAttachments />
{/* Mentioned documents as chips */}
{mentionedDocs?.map((doc) => (
<span
key={`${doc.document_type}:${doc.id}`}
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
title={doc.title}
>
<FileText className="size-3" />
<span className="max-w-[150px] truncate">{doc.title}</span>
</span>
))}
</div>
)}
{/* Message bubble with action bar positioned relative to it */}
<div className="relative">
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
<MessagePrimitive.Parts />
</div>
<div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1">
<UserActionBar />
<div className="aui-user-message-content-wrapper col-start-2 min-w-0 flex items-end gap-2">
<div className="flex-1 min-w-0">
{/* Display attachments and mentioned documents */}
{(hasAttachments || (mentionedDocs && mentionedDocs.length > 0)) && (
<div className="flex flex-wrap items-end gap-2 mb-2 justify-end">
{/* Attachments (images show as thumbnails, documents as chips) */}
<UserMessageAttachments />
{/* Mentioned documents as chips */}
{mentionedDocs?.map((doc) => (
<span
key={`${doc.document_type}:${doc.id}`}
className="inline-flex items-center gap-1 px-2 py-0.5 rounded-full bg-primary/10 text-xs font-medium text-primary border border-primary/20"
title={doc.title}
>
<FileText className="size-3" />
<span className="max-w-[150px] truncate">{doc.title}</span>
</span>
))}
</div>
)}
{/* Message bubble with action bar positioned relative to it */}
<div className="relative">
<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
<MessagePrimitive.Parts />
</div>
<div className="aui-user-action-bar-wrapper absolute top-1/2 right-full -translate-y-1/2 pr-1">
<UserActionBar />
</div>
</div>
</div>
{/* User avatar - only shown in shared chats */}
{author && (
<div className="shrink-0">
<UserAvatar displayName={author.displayName} avatarUrl={author.avatarUrl} />
</div>
)}
</div>
<BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" />

View file

@ -354,7 +354,11 @@ export function LayoutDataProvider({
onChatDelete={handleChatDelete}
onViewAllSharedChats={handleViewAllSharedChats}
onViewAllPrivateChats={handleViewAllPrivateChats}
user={{ email: user?.email || "", name: user?.email?.split("@")[0] }}
user={{
email: user?.email || "",
name: user?.display_name || user?.email?.split("@")[0],
avatarUrl: user?.avatar_url || undefined,
}}
onSettings={handleSettings}
onManageMembers={handleManageMembers}
onUserSettings={handleUserSettings}

View file

@ -12,6 +12,7 @@ export interface SearchSpace {
export interface User {
email: string;
name?: string;
avatarUrl?: string;
}
export interface NavItem {

View file

@ -61,6 +61,39 @@ function getInitials(email: string): string {
return name.slice(0, 2).toUpperCase();
}
/**
* User avatar component - shows image if available, otherwise falls back to initials
*/
function UserAvatar({
avatarUrl,
initials,
bgColor,
}: {
avatarUrl?: string;
initials: string;
bgColor: string;
}) {
if (avatarUrl) {
return (
<img
src={avatarUrl}
alt="User avatar"
className="h-8 w-8 shrink-0 rounded-lg object-cover"
referrerPolicy="no-referrer"
/>
);
}
return (
<div
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
);
}
export function SidebarUserProfile({
user,
onUserSettings,
@ -88,12 +121,7 @@ export function SidebarUserProfile({
"focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
)}
>
<div
className="flex h-8 w-8 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
<UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
<span className="sr-only">{displayName}</span>
</button>
</DropdownMenuTrigger>
@ -104,12 +132,7 @@ export function SidebarUserProfile({
<DropdownMenuContent className="w-56" side="right" align="end" sideOffset={8}>
<DropdownMenuLabel className="font-normal">
<div className="flex items-center gap-2">
<div
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
<UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
<div className="flex-1 min-w-0">
<p className="truncate text-sm font-medium">{displayName}</p>
<p className="truncate text-xs text-muted-foreground">{user.email}</p>
@ -149,13 +172,7 @@ export function SidebarUserProfile({
"focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
)}
>
{/* Avatar */}
<div
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
<UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
{/* Name and email */}
<div className="flex-1 min-w-0">
@ -171,12 +188,7 @@ export function SidebarUserProfile({
<DropdownMenuContent className="w-56" side="top" align="start" sideOffset={4}>
<DropdownMenuLabel className="font-normal">
<div className="flex items-center gap-2">
<div
className="flex h-8 w-8 shrink-0 items-center justify-center rounded-lg text-xs font-semibold text-white"
style={{ backgroundColor: bgColor }}
>
{initials}
</div>
<UserAvatar avatarUrl={user.avatarUrl} initials={initials} bgColor={bgColor} />
<div className="flex-1 min-w-0">
<p className="truncate text-sm font-medium">{displayName}</p>
<p className="truncate text-xs text-muted-foreground">{user.email}</p>

View file

@ -215,6 +215,16 @@ export const DocumentMentionPicker = forwardRef<
isSurfsenseDocsLoading) &&
currentPage === 0;
// Split documents into SurfSense docs and user docs for grouped rendering
const surfsenseDocsList = useMemo(
() => actualDocuments.filter((doc) => doc.document_type === "SURFSENSE_DOCS"),
[actualDocuments]
);
const userDocsList = useMemo(
() => actualDocuments.filter((doc) => doc.document_type !== "SURFSENSE_DOCS"),
[actualDocuments]
);
// Track already selected documents using unique key (document_type:id) to avoid ID collisions
const selectedKeys = useMemo(
() => new Set(initialSelectedDocuments.map((d) => `${d.document_type}:${d.id}`)),
@ -324,47 +334,102 @@ export const DocumentMentionPicker = forwardRef<
</div>
) : (
<div className="py-1">
{actualDocuments.map((doc) => {
const docKey = `${doc.document_type}:${doc.id}`;
const isAlreadySelected = selectedKeys.has(docKey);
const selectableIndex = selectableDocuments.findIndex(
(d) => d.document_type === doc.document_type && d.id === doc.id
);
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex;
{/* SurfSense Documentation Section */}
{surfsenseDocsList.length > 0 && (
<>
<div className="sticky top-0 z-10 px-3 py-2 text-xs font-bold uppercase tracking-wider bg-muted text-foreground/80 border-b border-border">
SurfSense Docs
</div>
{surfsenseDocsList.map((doc) => {
const docKey = `${doc.document_type}:${doc.id}`;
const isAlreadySelected = selectedKeys.has(docKey);
const selectableIndex = selectableDocuments.findIndex(
(d) => d.document_type === doc.document_type && d.id === doc.id
);
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex;
return (
<button
key={docKey}
ref={(el) => {
if (el && selectableIndex >= 0) {
itemRefs.current.set(selectableIndex, el);
}
}}
type="button"
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
onMouseEnter={() => {
if (!isAlreadySelected && selectableIndex >= 0) {
setHighlightedIndex(selectableIndex);
}
}}
disabled={isAlreadySelected}
className={cn(
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
isHighlighted && "bg-accent"
)}
>
<span className="shrink-0 text-muted-foreground text-sm">
{getConnectorIcon(doc.document_type)}
</span>
<span className="flex-1 text-sm truncate" title={doc.title}>
{doc.title}
</span>
</button>
);
})}
</>
)}
{/* User Documents Section */}
{userDocsList.length > 0 && (
<>
<div className="sticky top-0 z-10 px-3 py-2 text-xs font-bold uppercase tracking-wider bg-muted text-foreground/80 border-b border-border">
Your Documents
</div>
{userDocsList.map((doc) => {
const docKey = `${doc.document_type}:${doc.id}`;
const isAlreadySelected = selectedKeys.has(docKey);
const selectableIndex = selectableDocuments.findIndex(
(d) => d.document_type === doc.document_type && d.id === doc.id
);
const isHighlighted = !isAlreadySelected && selectableIndex === highlightedIndex;
return (
<button
key={docKey}
ref={(el) => {
if (el && selectableIndex >= 0) {
itemRefs.current.set(selectableIndex, el);
}
}}
type="button"
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
onMouseEnter={() => {
if (!isAlreadySelected && selectableIndex >= 0) {
setHighlightedIndex(selectableIndex);
}
}}
disabled={isAlreadySelected}
className={cn(
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
isHighlighted && "bg-accent"
)}
>
<span className="shrink-0 text-muted-foreground text-sm">
{getConnectorIcon(doc.document_type)}
</span>
<span className="flex-1 text-sm truncate" title={doc.title}>
{doc.title}
</span>
</button>
);
})}
</>
)}
return (
<button
key={docKey}
ref={(el) => {
if (el && selectableIndex >= 0) {
itemRefs.current.set(selectableIndex, el);
}
}}
type="button"
onClick={() => !isAlreadySelected && handleSelectDocument(doc)}
onMouseEnter={() => {
if (!isAlreadySelected && selectableIndex >= 0) {
setHighlightedIndex(selectableIndex);
}
}}
disabled={isAlreadySelected}
className={cn(
"w-full flex items-center gap-2 px-3 py-2 text-left transition-colors",
isAlreadySelected ? "opacity-50 cursor-not-allowed" : "cursor-pointer",
isHighlighted && "bg-accent"
)}
>
{/* Type icon */}
<span className="flex-shrink-0 text-muted-foreground text-sm">
{getConnectorIcon(doc.document_type)}
</span>
{/* Title */}
<span className="flex-1 text-sm truncate" title={doc.title}>
{doc.title}
</span>
</button>
);
})}
{/* Loading indicator for additional pages */}
{isLoadingMore && (
<div className="flex items-center justify-center py-2">

View file

@ -8,6 +8,8 @@ export const user = z.object({
is_verified: z.boolean(),
pages_limit: z.number(),
pages_used: z.number(),
display_name: z.string().nullish(),
avatar_url: z.string().nullish(),
});
/**
@ -15,5 +17,20 @@ export const user = z.object({
*/
export const getMeResponse = user;
/**
* Update current user request
*/
export const updateUserRequest = z.object({
display_name: z.string().nullish(),
avatar_url: z.string().nullish(),
});
/**
* Update current user response
*/
export const updateUserResponse = user;
export type User = z.infer<typeof user>;
export type GetMeResponse = z.infer<typeof getMeResponse>;
export type UpdateUserRequest = z.infer<typeof updateUserRequest>;
export type UpdateUserResponse = z.infer<typeof updateUserResponse>;

View file

@ -1,4 +1,8 @@
import { getMeResponse } from "@/contracts/types/user.types";
import {
getMeResponse,
updateUserResponse,
type UpdateUserRequest,
} from "@/contracts/types/user.types";
import { baseApiService } from "./base-api.service";
class UserApiService {
@ -8,6 +12,15 @@ class UserApiService {
getMe = async () => {
return baseApiService.get(`/users/me`, getMeResponse);
};
/**
* Update current authenticated user
*/
updateMe = async (request: UpdateUserRequest) => {
return baseApiService.patch(`/users/me`, updateUserResponse, {
body: request,
});
};
}
export const userApiService = new UserApiService();

View file

@ -31,6 +31,9 @@ export interface MessageRecord {
role: "user" | "assistant" | "system";
content: unknown;
created_at: string;
author_id?: string | null;
author_display_name?: string | null;
author_avatar_url?: string | null;
}
export interface ThreadListResponse {

View file

@ -1,10 +1,10 @@
/**
* Environment configuration for the frontend.
*
*
* This file centralizes access to NEXT_PUBLIC_* environment variables.
* For Docker deployments, these placeholders are replaced at container startup
* via sed in the entrypoint script.
*
*
* IMPORTANT: Do not use template literals or complex expressions with these values
* as it may prevent the sed replacement from working correctly.
*/
@ -24,5 +24,5 @@ export const ETL_SERVICE = process.env.NEXT_PUBLIC_ETL_SERVICE || "DOCLING";
// Helper to check if local auth is enabled
export const isLocalAuth = () => AUTH_TYPE === "LOCAL";
// Helper to check if Google auth is enabled
// Helper to check if Google auth is enabled
export const isGoogleAuth = () => AUTH_TYPE === "GOOGLE";

View file

@ -109,6 +109,17 @@
"title": "User Settings",
"description": "Manage your account settings and API access",
"back_to_app": "Back to app",
"profile_nav_label": "Profile",
"profile_nav_description": "Manage your display name and avatar",
"profile_title": "Profile",
"profile_description": "Update your personal information",
"profile_avatar": "Profile Picture",
"profile_display_name": "Display Name",
"profile_display_name_hint": "This is how your name appears across the app",
"profile_email": "Email",
"profile_save": "Save Changes",
"profile_saved": "Profile updated successfully",
"profile_save_error": "Failed to update profile",
"api_key_nav_label": "API Key",
"api_key_nav_description": "Manage your API access token",
"api_key_title": "API Key",