Merge branch 'dev' into bugs_prod

This commit is contained in:
Manoj Aggarwal 2026-02-05 10:53:16 -08:00
commit e6c0fabd0a
125 changed files with 5800 additions and 1115 deletions

View file

@ -239,6 +239,7 @@ ENV POSTGRES_DB=surfsense
ENV DATABASE_URL=postgresql+asyncpg://surfsense:surfsense@localhost:5432/surfsense ENV DATABASE_URL=postgresql+asyncpg://surfsense:surfsense@localhost:5432/surfsense
ENV CELERY_BROKER_URL=redis://localhost:6379/0 ENV CELERY_BROKER_URL=redis://localhost:6379/0
ENV CELERY_RESULT_BACKEND=redis://localhost:6379/0 ENV CELERY_RESULT_BACKEND=redis://localhost:6379/0
ENV CELERY_TASK_DEFAULT_QUEUE=surfsense
ENV PYTHONPATH=/app/backend ENV PYTHONPATH=/app/backend
ENV NEXT_FRONTEND_URL=http://localhost:3000 ENV NEXT_FRONTEND_URL=http://localhost:3000
ENV AUTH_TYPE=LOCAL ENV AUTH_TYPE=LOCAL

View file

@ -53,6 +53,8 @@ services:
- DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-surfsense} - DATABASE_URL=postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-surfsense}
- CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0 - CELERY_BROKER_URL=redis://redis:${REDIS_PORT:-6379}/0
- CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0 - CELERY_RESULT_BACKEND=redis://redis:${REDIS_PORT:-6379}/0
# Queue name isolation - prevents task collision if Redis is shared with other apps
- CELERY_TASK_DEFAULT_QUEUE=surfsense
- PYTHONPATH=/app - PYTHONPATH=/app
- UVICORN_LOOP=asyncio - UVICORN_LOOP=asyncio
- UNSTRUCTURED_HAS_PATCHED_LOOP=1 - UNSTRUCTURED_HAS_PATCHED_LOOP=1

View file

@ -3,6 +3,12 @@ DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense
#Celery Config #Celery Config
CELERY_BROKER_URL=redis://localhost:6379/0 CELERY_BROKER_URL=redis://localhost:6379/0
CELERY_RESULT_BACKEND=redis://localhost:6379/0 CELERY_RESULT_BACKEND=redis://localhost:6379/0
# Optional: isolate queues when sharing Redis with other apps
CELERY_TASK_DEFAULT_QUEUE=surfsense
# Redis for app-level features (heartbeats, podcast markers)
# Defaults to CELERY_BROKER_URL when not set
REDIS_APP_URL=redis://localhost:6379/0
#Electric(for migrations only) #Electric(for migrations only)
ELECTRIC_DB_USER=electric ELECTRIC_DB_USER=electric
@ -26,6 +32,11 @@ ELECTRIC_DB_PASSWORD=electric_password
SCHEDULE_CHECKER_INTERVAL=5m SCHEDULE_CHECKER_INTERVAL=5m
SECRET_KEY=SECRET SECRET_KEY=SECRET
# JWT Token Lifetimes (optional, defaults shown)
# ACCESS_TOKEN_LIFETIME_SECONDS=86400 # 1 day
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600 # 2 weeks
NEXT_FRONTEND_URL=http://localhost:3000 NEXT_FRONTEND_URL=http://localhost:3000
# Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS) # Backend URL for OAuth callbacks (optional, set when behind reverse proxy with HTTPS)

View file

@ -0,0 +1,187 @@
"""Add created_by_id column to documents table for document ownership tracking
Revision ID: 86
Revises: 85
Create Date: 2026-02-02
Changes:
1. Add created_by_id column (UUID, nullable, foreign key to user.id)
2. Create index on created_by_id for performance
3. Backfill existing documents with search space owner's user_id (with progress indicator)
"""
import sys
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "86"
down_revision: str | None = "85"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
# Batch size for backfill operation
BATCH_SIZE = 5000
def upgrade() -> None:
"""Add created_by_id column to documents and backfill with search space owner."""
# 1. Add created_by_id column (nullable for backward compatibility)
print("Step 1/4: Adding created_by_id column...")
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'documents' AND column_name = 'created_by_id'
) THEN
ALTER TABLE documents
ADD COLUMN created_by_id UUID;
END IF;
END$$;
"""
)
print(" Done: created_by_id column added.")
# 2. Create index on created_by_id for efficient queries
print("Step 2/4: Creating index on created_by_id...")
op.execute(
"""
CREATE INDEX IF NOT EXISTS ix_documents_created_by_id
ON documents (created_by_id);
"""
)
print(" Done: Index created.")
# 3. Add foreign key constraint with ON DELETE SET NULL
# First check if constraint already exists
print("Step 3/4: Adding foreign key constraint...")
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE constraint_name = 'fk_documents_created_by_id'
AND table_name = 'documents'
) THEN
ALTER TABLE documents
ADD CONSTRAINT fk_documents_created_by_id
FOREIGN KEY (created_by_id) REFERENCES "user"(id)
ON DELETE SET NULL;
END IF;
END$$;
"""
)
print(" Done: Foreign key constraint added.")
# 4. Backfill existing documents with search space owner's user_id
# Process in batches with progress indicator
print("Step 4/4: Backfilling created_by_id for existing documents...")
connection = op.get_bind()
# Get total count of documents that need backfilling
result = connection.execute(
sa.text("""
SELECT COUNT(*) FROM documents WHERE created_by_id IS NULL
""")
)
total_count = result.scalar()
if total_count == 0:
print(" No documents need backfilling. Skipping.")
return
print(f" Total documents to backfill: {total_count:,}")
processed = 0
batch_num = 0
while processed < total_count:
batch_num += 1
# Update a batch of documents using a subquery to limit the update
# We use ctid (tuple identifier) for efficient batching in PostgreSQL
result = connection.execute(
sa.text("""
UPDATE documents
SET created_by_id = searchspaces.user_id
FROM searchspaces
WHERE documents.search_space_id = searchspaces.id
AND documents.created_by_id IS NULL
AND documents.id IN (
SELECT d.id FROM documents d
WHERE d.created_by_id IS NULL
LIMIT :batch_size
)
"""),
{"batch_size": BATCH_SIZE},
)
rows_updated = result.rowcount
if rows_updated == 0:
# No more rows to update
break
processed += rows_updated
progress_pct = min(100.0, (processed / total_count) * 100)
# Print progress with carriage return for in-place update
sys.stdout.write(
f"\r Progress: {processed:,}/{total_count:,} documents ({progress_pct:.1f}%) - Batch {batch_num}"
)
sys.stdout.flush()
# Final newline after progress
print()
print(f" Done: Backfilled {processed:,} documents.")
def downgrade() -> None:
"""Remove created_by_id column from documents."""
# Drop foreign key constraint
op.execute(
"""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE constraint_name = 'fk_documents_created_by_id'
AND table_name = 'documents'
) THEN
ALTER TABLE documents
DROP CONSTRAINT fk_documents_created_by_id;
END IF;
END$$;
"""
)
# Drop index
op.execute(
"""
DROP INDEX IF EXISTS ix_documents_created_by_id;
"""
)
# Drop column
op.execute(
"""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'documents' AND column_name = 'created_by_id'
) THEN
ALTER TABLE documents
DROP COLUMN created_by_id;
END IF;
END$$;
"""
)

View file

@ -0,0 +1,170 @@
"""Add connector_id column to documents table for linking documents to their source connector
Revision ID: 87
Revises: 86
Create Date: 2026-02-02
Changes:
1. Add connector_id column (Integer, nullable, foreign key to search_source_connectors.id)
2. Create index on connector_id for efficient bulk deletion queries
3. SET NULL on delete - allows controlled cleanup in application code
4. Backfill existing documents based on document_type and search_space_id matching
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "87"
down_revision: str | None = "86"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add connector_id column to documents and backfill from existing connectors."""
# 1. Add connector_id column (nullable - for manually uploaded docs without connector)
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'documents' AND column_name = 'connector_id'
) THEN
ALTER TABLE documents
ADD COLUMN connector_id INTEGER;
END IF;
END$$;
"""
)
# 2. Create index on connector_id for efficient cleanup queries
op.execute(
"""
CREATE INDEX IF NOT EXISTS ix_documents_connector_id
ON documents (connector_id);
"""
)
# 3. Add foreign key constraint with ON DELETE SET NULL
# SET NULL allows us to delete documents in controlled batches before deleting connector
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE constraint_name = 'fk_documents_connector_id'
AND table_name = 'documents'
) THEN
ALTER TABLE documents
ADD CONSTRAINT fk_documents_connector_id
FOREIGN KEY (connector_id) REFERENCES search_source_connectors(id)
ON DELETE SET NULL;
END IF;
END$$;
"""
)
# 4. Backfill existing documents with connector_id based on document_type matching
# This maps document types to their corresponding connector types
# Only backfills for documents in search spaces that have exactly one connector of that type
# Map of document_type -> connector_type for backfilling
document_connector_mappings = [
("NOTION_CONNECTOR", "NOTION_CONNECTOR"),
("SLACK_CONNECTOR", "SLACK_CONNECTOR"),
("TEAMS_CONNECTOR", "TEAMS_CONNECTOR"),
("GITHUB_CONNECTOR", "GITHUB_CONNECTOR"),
("LINEAR_CONNECTOR", "LINEAR_CONNECTOR"),
("DISCORD_CONNECTOR", "DISCORD_CONNECTOR"),
("JIRA_CONNECTOR", "JIRA_CONNECTOR"),
("CONFLUENCE_CONNECTOR", "CONFLUENCE_CONNECTOR"),
("CLICKUP_CONNECTOR", "CLICKUP_CONNECTOR"),
("GOOGLE_CALENDAR_CONNECTOR", "GOOGLE_CALENDAR_CONNECTOR"),
("GOOGLE_GMAIL_CONNECTOR", "GOOGLE_GMAIL_CONNECTOR"),
("GOOGLE_DRIVE_FILE", "GOOGLE_DRIVE_CONNECTOR"),
("AIRTABLE_CONNECTOR", "AIRTABLE_CONNECTOR"),
("LUMA_CONNECTOR", "LUMA_CONNECTOR"),
("ELASTICSEARCH_CONNECTOR", "ELASTICSEARCH_CONNECTOR"),
("BOOKSTACK_CONNECTOR", "BOOKSTACK_CONNECTOR"),
("CIRCLEBACK", "CIRCLEBACK_CONNECTOR"),
("OBSIDIAN_CONNECTOR", "OBSIDIAN_CONNECTOR"),
("COMPOSIO_GOOGLE_DRIVE_CONNECTOR", "COMPOSIO_GOOGLE_DRIVE_CONNECTOR"),
("COMPOSIO_GMAIL_CONNECTOR", "COMPOSIO_GMAIL_CONNECTOR"),
("COMPOSIO_GOOGLE_CALENDAR_CONNECTOR", "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"),
("CRAWLED_URL", "WEBCRAWLER_CONNECTOR"),
]
for doc_type, connector_type in document_connector_mappings:
# Backfill connector_id for documents where:
# 1. Document has this document_type
# 2. Document doesn't already have a connector_id
# 3. There's exactly one connector of this type in the same search space
# This safely handles most cases while avoiding ambiguity
op.execute(
f"""
UPDATE documents d
SET connector_id = (
SELECT ssc.id
FROM search_source_connectors ssc
WHERE ssc.search_space_id = d.search_space_id
AND ssc.connector_type = '{connector_type}'
LIMIT 1
)
WHERE d.document_type = '{doc_type}'
AND d.connector_id IS NULL
AND EXISTS (
SELECT 1 FROM search_source_connectors ssc
WHERE ssc.search_space_id = d.search_space_id
AND ssc.connector_type = '{connector_type}'
);
"""
)
def downgrade() -> None:
"""Remove connector_id column from documents."""
# Drop foreign key constraint
op.execute(
"""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE constraint_name = 'fk_documents_connector_id'
AND table_name = 'documents'
) THEN
ALTER TABLE documents
DROP CONSTRAINT fk_documents_connector_id;
END IF;
END$$;
"""
)
# Drop index
op.execute(
"""
DROP INDEX IF EXISTS ix_documents_connector_id;
"""
)
# Drop column
op.execute(
"""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'documents' AND column_name = 'connector_id'
) THEN
ALTER TABLE documents
DROP COLUMN connector_id;
END IF;
END$$;
"""
)

View file

@ -0,0 +1,58 @@
"""Make podcast_transcript nullable
Revision ID: 88
Revises: 87
Create Date: 2026-02-02
The podcast workflow now creates a podcast record with PENDING status first,
then fills in the transcript after generation completes. This requires
podcast_transcript to be nullable.
"""
from collections.abc import Sequence
from alembic import op
revision: str = "88"
down_revision: str | None = "87"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Make podcast_transcript nullable and remove the server default
op.execute(
"""
ALTER TABLE podcasts
ALTER COLUMN podcast_transcript DROP NOT NULL;
"""
)
op.execute(
"""
ALTER TABLE podcasts
ALTER COLUMN podcast_transcript DROP DEFAULT;
"""
)
def downgrade() -> None:
# Set empty JSON for any NULL values before adding NOT NULL constraint
op.execute(
"""
UPDATE podcasts
SET podcast_transcript = '{}'::jsonb
WHERE podcast_transcript IS NULL;
"""
)
op.execute(
"""
ALTER TABLE podcasts
ALTER COLUMN podcast_transcript SET DEFAULT '{}';
"""
)
op.execute(
"""
ALTER TABLE podcasts
ALTER COLUMN podcast_transcript SET NOT NULL;
"""
)

View file

@ -0,0 +1,46 @@
"""Make podcast file_location nullable
Revision ID: 89
Revises: 88
Create Date: 2026-02-03
The podcast workflow creates a podcast record with PENDING status first,
then fills in the file_location after audio generation completes. This requires
file_location to be nullable.
"""
from collections.abc import Sequence
from alembic import op
revision: str = "89"
down_revision: str | None = "88"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Make file_location nullable
op.execute(
"""
ALTER TABLE podcasts
ALTER COLUMN file_location DROP NOT NULL;
"""
)
def downgrade() -> None:
# Set empty string for any NULL values before adding NOT NULL constraint
op.execute(
"""
UPDATE podcasts
SET file_location = ''
WHERE file_location IS NULL;
"""
)
op.execute(
"""
ALTER TABLE podcasts
ALTER COLUMN file_location SET NOT NULL;
"""
)

View file

@ -0,0 +1,66 @@
"""Add public_sharing permissions to existing roles
Revision ID: 90
Revises: 89
Create Date: 2026-02-02
"""
from sqlalchemy import text
from alembic import op
revision = "90"
down_revision = "89"
branch_labels = None
depends_on = None
def upgrade():
connection = op.get_bind()
connection.execute(
text(
"""
UPDATE search_space_roles
SET permissions = array_append(permissions, 'public_sharing:view')
WHERE name IN ('Editor', 'Viewer')
AND NOT ('public_sharing:view' = ANY(permissions))
"""
)
)
connection.execute(
text(
"""
UPDATE search_space_roles
SET permissions = array_append(permissions, 'public_sharing:create')
WHERE name = 'Editor'
AND NOT ('public_sharing:create' = ANY(permissions))
"""
)
)
def downgrade():
connection = op.get_bind()
connection.execute(
text(
"""
UPDATE search_space_roles
SET permissions = array_remove(permissions, 'public_sharing:view')
WHERE name IN ('Editor', 'Viewer')
"""
)
)
connection.execute(
text(
"""
UPDATE search_space_roles
SET permissions = array_remove(permissions, 'public_sharing:create')
WHERE name = 'Editor'
"""
)
)

View file

@ -0,0 +1,33 @@
"""Add DISCORD_JOIN to incentive task type enum
Revision ID: 91
Revises: 90
Changes:
1. Add DISCORD_JOIN value to incentivetasktype enum
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "91"
down_revision: str | None = "90"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add DISCORD_JOIN to incentivetasktype enum."""
op.execute("ALTER TYPE incentivetasktype ADD VALUE IF NOT EXISTS 'DISCORD_JOIN'")
def downgrade() -> None:
"""Remove DISCORD_JOIN from incentivetasktype enum.
Note: PostgreSQL doesn't support removing values from enums directly.
This would require recreating the enum type, which is complex and risky.
For safety, we leave the enum value in place during downgrade.
"""
pass

View file

@ -0,0 +1,92 @@
"""Add refresh_tokens table for user session management
Revision ID: 92
Revises: 91
Changes:
1. Create refresh_tokens table with columns:
- id (primary key)
- user_id (foreign key to user)
- token_hash (unique, indexed)
- expires_at (indexed)
- is_revoked
- family_id (indexed, for token rotation tracking)
- created_at, updated_at (timestamps)
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "92"
down_revision: str | None = "91"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Create refresh_tokens table (idempotent)."""
# Check if table already exists
connection = op.get_bind()
result = connection.execute(
sa.text(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'refresh_tokens')"
)
)
table_exists = result.scalar()
if not table_exists:
op.create_table(
"refresh_tokens",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("user_id", UUID(as_uuid=True), nullable=False),
sa.Column("token_hash", sa.String(256), nullable=False),
sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column("is_revoked", sa.Boolean(), nullable=False, default=False),
sa.Column("family_id", UUID(as_uuid=True), nullable=False),
sa.Column(
"created_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
ondelete="CASCADE",
),
)
# Create indexes if they don't exist
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_user_id ON refresh_tokens (user_id)"
)
op.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS ix_refresh_tokens_token_hash ON refresh_tokens (token_hash)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_expires_at ON refresh_tokens (expires_at)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS ix_refresh_tokens_family_id ON refresh_tokens (family_id)"
)
def downgrade() -> None:
"""Drop refresh_tokens table (idempotent)."""
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_family_id")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_expires_at")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_token_hash")
op.execute("DROP INDEX IF EXISTS ix_refresh_tokens_user_id")
op.execute("DROP TABLE IF EXISTS refresh_tokens")

View file

@ -21,8 +21,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Podcast, PodcastStatus from app.db import Podcast, PodcastStatus
# Redis connection for tracking active podcast tasks # Redis connection for tracking active podcast tasks
# Uses the same Redis instance as Celery # Defaults to the Celery broker when REDIS_APP_URL is not set
REDIS_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") REDIS_URL = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
)
_redis_client: redis.Redis | None = None _redis_client: redis.Redis | None = None

View file

@ -12,6 +12,7 @@ from app.agents.new_chat.checkpointer import (
from app.config import config, initialize_llm_router from app.config import config, initialize_llm_router
from app.db import User, create_db_and_tables, get_async_session from app.db import User, create_db_and_tables, get_async_session
from app.routes import router as crud_router from app.routes import router as crud_router
from app.routes.auth_routes import router as auth_router
from app.schemas import UserCreate, UserRead, UserUpdate from app.schemas import UserCreate, UserRead, UserUpdate
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
from app.users import SECRET, auth_backend, current_active_user, fastapi_users from app.users import SECRET, auth_backend, current_active_user, fastapi_users
@ -111,6 +112,9 @@ app.include_router(
tags=["users"], tags=["users"],
) )
# Include custom auth routes (refresh token, logout)
app.include_router(auth_router)
if config.AUTH_TYPE == "GOOGLE": if config.AUTH_TYPE == "GOOGLE":
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse

View file

@ -26,6 +26,7 @@ def init_worker(**kwargs):
# Get Celery configuration from environment # Get Celery configuration from environment
CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") CELERY_BROKER_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0") CELERY_RESULT_BACKEND = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")
CELERY_TASK_DEFAULT_QUEUE = os.getenv("CELERY_TASK_DEFAULT_QUEUE", "surfsense")
# Get schedule checker interval from environment # Get schedule checker interval from environment
# Format: "<number><unit>" where unit is 'm' (minutes) or 'h' (hours) # Format: "<number><unit>" where unit is 'm' (minutes) or 'h' (hours)
@ -80,6 +81,7 @@ celery_app = Celery(
"app.tasks.celery_tasks.blocknote_migration_tasks", "app.tasks.celery_tasks.blocknote_migration_tasks",
"app.tasks.celery_tasks.document_reindex_tasks", "app.tasks.celery_tasks.document_reindex_tasks",
"app.tasks.celery_tasks.stale_notification_cleanup_task", "app.tasks.celery_tasks.stale_notification_cleanup_task",
"app.tasks.celery_tasks.connector_deletion_task",
], ],
) )
@ -91,6 +93,9 @@ celery_app.conf.update(
result_serializer="json", result_serializer="json",
timezone="UTC", timezone="UTC",
enable_utc=True, enable_utc=True,
task_default_queue=CELERY_TASK_DEFAULT_QUEUE,
task_default_exchange=CELERY_TASK_DEFAULT_QUEUE,
task_default_routing_key=CELERY_TASK_DEFAULT_QUEUE,
# Task execution settings # Task execution settings
task_track_started=True, task_track_started=True,
task_time_limit=28800, # 8 hour hard limit task_time_limit=28800, # 8 hour hard limit

View file

@ -255,6 +255,14 @@ class Config:
# OAuth JWT # OAuth JWT
SECRET_KEY = os.getenv("SECRET_KEY") SECRET_KEY = os.getenv("SECRET_KEY")
# JWT Token Lifetimes
ACCESS_TOKEN_LIFETIME_SECONDS = int(
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day
)
REFRESH_TOKEN_LIFETIME_SECONDS = int(
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
)
# ETL Service # ETL Service
ETL_SERVICE = os.getenv("ETL_SERVICE") ETL_SERVICE = os.getenv("ETL_SERVICE")

View file

@ -122,8 +122,52 @@ global_llm_configs:
use_default_system_instructions: false use_default_system_instructions: false
citations_enabled: true citations_enabled: true
# Example: Groq - Fast inference # Example: Azure OpenAI GPT-4o
# IMPORTANT: For Azure deployments, always include 'base_model' in litellm_params
# to enable accurate token counting, cost tracking, and max token limits
- id: -5 - id: -5
name: "Global Azure GPT-4o"
description: "Azure OpenAI GPT-4o deployment"
provider: "AZURE"
# model_name format for Azure: azure/<your-deployment-name>
model_name: "azure/gpt-4o-deployment"
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
api_version: "2024-02-15-preview" # Azure API version
rpm: 1000
tpm: 150000
litellm_params:
temperature: 0.7
max_tokens: 4000
# REQUIRED for Azure: Specify the underlying OpenAI model
# This fixes "Could not identify azure model" warnings
# Common base_model values: gpt-4, gpt-4-turbo, gpt-4o, gpt-4o-mini, gpt-3.5-turbo
base_model: "gpt-4o"
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: Azure OpenAI GPT-4 Turbo
- id: -6
name: "Global Azure GPT-4 Turbo"
description: "Azure OpenAI GPT-4 Turbo deployment"
provider: "AZURE"
model_name: "azure/gpt-4-turbo-deployment"
api_key: "your-azure-api-key-here"
api_base: "https://your-resource.openai.azure.com"
api_version: "2024-02-15-preview"
rpm: 500
tpm: 100000
litellm_params:
temperature: 0.7
max_tokens: 4000
base_model: "gpt-4-turbo" # Maps to gpt-4-turbo-preview
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: Groq - Fast inference
- id: -7
name: "Global Groq Llama 3" name: "Global Groq Llama 3"
description: "Ultra-fast Llama 3 70B via Groq" description: "Ultra-fast Llama 3 70B via Groq"
provider: "GROQ" provider: "GROQ"
@ -150,3 +194,11 @@ global_llm_configs:
# - All standard LiteLLM providers are supported # - All standard LiteLLM providers are supported
# - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute) # - rpm/tpm: Optional rate limits for load balancing (requests/tokens per minute)
# These help the router distribute load evenly and avoid rate limit errors # These help the router distribute load evenly and avoid rate limit errors
#
# AZURE-SPECIFIC NOTES:
# - Always add 'base_model' in litellm_params for Azure deployments
# - This fixes "Could not identify azure model 'X'" warnings
# - base_model should match the underlying OpenAI model (e.g., gpt-4o, gpt-4-turbo, gpt-3.5-turbo)
# - model_name format: "azure/<your-deployment-name>"
# - api_version: Use a recent Azure API version (e.g., "2024-02-15-preview")
# - See: https://docs.litellm.ai/docs/proxy/cost_tracking#spend-tracking-for-azure-openai-models

View file

@ -71,6 +71,14 @@ class AirtableHistoryConnector:
config_data = connector.config.copy() config_data = connector.config.copy()
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Airtable access token not found. "
"Please reconnect your Airtable account."
)
# Decrypt credentials if they are encrypted # Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False) token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY: if token_encrypted and config.SECRET_KEY:
@ -98,6 +106,14 @@ class AirtableHistoryConnector:
f"Failed to decrypt Airtable credentials: {e!s}" f"Failed to decrypt Airtable credentials: {e!s}"
) from e ) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Airtable access token is invalid or empty. "
"Please reconnect your Airtable account."
)
try: try:
self._credentials = AirtableAuthCredentialsBase.from_dict(config_data) self._credentials = AirtableAuthCredentialsBase.from_dict(config_data)
except Exception as e: except Exception as e:

View file

@ -394,6 +394,8 @@ async def _process_gmail_message_batch(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)
documents_indexed += 1 documents_indexed += 1

View file

@ -442,6 +442,8 @@ async def index_composio_google_calendar(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)
documents_indexed += 1 documents_indexed += 1

View file

@ -4,6 +4,7 @@ Composio Google Drive Connector Module.
Provides Google Drive specific methods for data retrieval and indexing via Composio. Provides Google Drive specific methods for data retrieval and indexing via Composio.
""" """
import contextlib
import hashlib import hashlib
import json import json
import logging import logging
@ -179,13 +180,14 @@ class ComposioGoogleDriveConnector(ComposioConnector):
) )
async def get_drive_file_content( async def get_drive_file_content(
self, file_id: str self, file_id: str, original_mime_type: str | None = None
) -> tuple[bytes | None, str | None]: ) -> tuple[bytes | None, str | None]:
""" """
Download file content from Google Drive via Composio. Download file content from Google Drive via Composio.
Args: Args:
file_id: Google Drive file ID. file_id: Google Drive file ID.
original_mime_type: Original MIME type (used to detect Google Workspace files for export).
Returns: Returns:
Tuple of (file content bytes, error message). Tuple of (file content bytes, error message).
@ -200,6 +202,31 @@ class ComposioGoogleDriveConnector(ComposioConnector):
connected_account_id=connected_account_id, connected_account_id=connected_account_id,
entity_id=entity_id, entity_id=entity_id,
file_id=file_id, file_id=file_id,
original_mime_type=original_mime_type,
)
async def get_file_metadata(
self, file_id: str
) -> tuple[dict[str, Any] | None, str | None]:
"""
Get metadata for a specific file from Google Drive.
Args:
file_id: The ID of the file to get metadata for.
Returns:
Tuple of (metadata dict, error message).
"""
connected_account_id = await self.get_connected_account_id()
if not connected_account_id:
return None, "No connected account ID found"
entity_id = await self.get_entity_id()
service = await self._get_service()
return await service.get_file_metadata(
connected_account_id=connected_account_id,
entity_id=entity_id,
file_id=file_id,
) )
async def get_drive_start_page_token(self) -> tuple[str | None, str | None]: async def get_drive_start_page_token(self) -> tuple[str | None, str | None]:
@ -292,8 +319,10 @@ async def _process_file_content(
if isinstance(content, str): if isinstance(content, str):
content = content.encode("utf-8") content = content.encode("utf-8")
# Check if this is a binary file # Check if this is a binary file based on extension or MIME type
if _is_binary_file(file_name, mime_type): is_binary = _is_binary_file(file_name, mime_type)
if is_binary:
# Use ETL service for binary files (PDF, Office docs, etc.) # Use ETL service for binary files (PDF, Office docs, etc.)
temp_file_path = None temp_file_path = None
try: try:
@ -316,7 +345,7 @@ async def _process_file_content(
return extracted_text return extracted_text
else: else:
# Fallback if extraction fails # Fallback if extraction fails
logger.warning(f"Could not extract text from binary file {file_name}") logger.warning(f"ETL returned empty for binary file {file_name}")
return f"# {file_name}\n\n[Binary file - text extraction failed]\n\n**File ID:** {file_id}\n**Type:** {mime_type}\n" return f"# {file_name}\n\n[Binary file - text extraction failed]\n\n**File ID:** {file_id}\n**Type:** {mime_type}\n"
except Exception as e: except Exception as e:
@ -327,10 +356,8 @@ async def _process_file_content(
finally: finally:
# Cleanup temp file # Cleanup temp file
if temp_file_path and os.path.exists(temp_file_path): if temp_file_path and os.path.exists(temp_file_path):
try: with contextlib.suppress(Exception):
os.unlink(temp_file_path) os.unlink(temp_file_path)
except Exception as e:
logger.debug(f"Could not delete temp file {temp_file_path}: {e}")
else: else:
# Text file - try to decode as UTF-8 # Text file - try to decode as UTF-8
try: try:
@ -372,9 +399,13 @@ async def _extract_text_with_etl(
from logging import ERROR, getLogger from logging import ERROR, getLogger
etl_service = config.ETL_SERVICE etl_service = config.ETL_SERVICE
logger.debug(
f"[_extract_text_with_etl] START - file_path={file_path}, file_name={file_name}, etl_service={etl_service}"
)
try: try:
if etl_service == "UNSTRUCTURED": if etl_service == "UNSTRUCTURED":
logger.debug("[_extract_text_with_etl] Using UNSTRUCTURED ETL")
from langchain_unstructured import UnstructuredLoader from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown from app.utils.document_converters import convert_document_to_markdown
@ -390,11 +421,20 @@ async def _extract_text_with_etl(
) )
docs = await loader.aload() docs = await loader.aload()
logger.debug(
f"[_extract_text_with_etl] UNSTRUCTURED loaded {len(docs) if docs else 0} docs"
)
if docs: if docs:
return await convert_document_to_markdown(docs) result = await convert_document_to_markdown(docs)
logger.debug(
f"[_extract_text_with_etl] UNSTRUCTURED result: {len(result) if result else 0} chars"
)
return result
logger.debug("[_extract_text_with_etl] UNSTRUCTURED returned no docs")
return None return None
elif etl_service == "LLAMACLOUD": elif etl_service == "LLAMACLOUD":
logger.debug("[_extract_text_with_etl] Using LLAMACLOUD ETL")
from app.tasks.document_processors.file_processors import ( from app.tasks.document_processors.file_processors import (
parse_with_llamacloud_retry, parse_with_llamacloud_retry,
) )
@ -413,11 +453,22 @@ async def _extract_text_with_etl(
markdown_documents = await result.aget_markdown_documents( markdown_documents = await result.aget_markdown_documents(
split_by_page=False split_by_page=False
) )
logger.debug(
f"[_extract_text_with_etl] LLAMACLOUD got {len(markdown_documents) if markdown_documents else 0} markdown docs"
)
if markdown_documents: if markdown_documents:
return markdown_documents[0].text text = markdown_documents[0].text
logger.debug(
f"[_extract_text_with_etl] LLAMACLOUD result: {len(text) if text else 0} chars"
)
return text
logger.debug(
"[_extract_text_with_etl] LLAMACLOUD returned no markdown docs"
)
return None return None
elif etl_service == "DOCLING": elif etl_service == "DOCLING":
logger.debug("[_extract_text_with_etl] Using DOCLING ETL")
from app.services.docling_service import create_docling_service from app.services.docling_service import create_docling_service
docling_service = create_docling_service() docling_service = create_docling_service()
@ -441,16 +492,30 @@ async def _extract_text_with_etl(
result = await docling_service.process_document( result = await docling_service.process_document(
file_path, file_name file_path, file_name
) )
logger.debug(
f"[_extract_text_with_etl] DOCLING result keys: {list(result.keys()) if result else 'None'}"
)
finally: finally:
pdfminer_logger.setLevel(original_level) pdfminer_logger.setLevel(original_level)
return result.get("content") content = result.get("content")
logger.debug(
f"[_extract_text_with_etl] DOCLING content: {len(content) if content else 0} chars"
)
return content
else: else:
logger.warning(f"Unknown ETL service: {etl_service}") logger.warning(
f"[_extract_text_with_etl] Unknown ETL service: {etl_service}"
)
return None return None
except Exception as e: except Exception as e:
logger.error(f"ETL extraction failed for {file_name}: {e!s}") logger.error(
f"[_extract_text_with_etl] ETL extraction EXCEPTION for {file_name}: {e!s}"
)
import traceback
logger.error(f"[_extract_text_with_etl] Traceback: {traceback.format_exc()}")
return None return None
@ -979,7 +1044,7 @@ async def _index_composio_drive_full_scan(
all_files.extend(folder_files[:max_files_per_folder]) all_files.extend(folder_files[:max_files_per_folder])
logger.info(f"Found {len(folder_files)} files in folder {folder_name}") logger.info(f"Found {len(folder_files)} files in folder {folder_name}")
# Add specifically selected files # Add specifically selected files - fetch metadata to get mimeType
for selected_file in selected_files: for selected_file in selected_files:
file_id = selected_file.get("id") file_id = selected_file.get("id")
file_name = selected_file.get("name", "Unknown") file_name = selected_file.get("name", "Unknown")
@ -987,14 +1052,35 @@ async def _index_composio_drive_full_scan(
if not file_id: if not file_id:
continue continue
# Add file info (we'll fetch content later during indexing) # Fetch file metadata to get proper mimeType
all_files.append( metadata, meta_error = await composio_connector.get_file_metadata(file_id)
{ if metadata and not meta_error:
"id": file_id, all_files.append(
"name": file_name, {
"mimeType": "", # Will be determined later "id": file_id,
} "name": metadata.get("name") or file_name,
) "mimeType": metadata.get("mimeType", ""),
"modifiedTime": metadata.get("modifiedTime", ""),
"createdTime": metadata.get("createdTime", ""),
}
)
logger.info(
f"Fetched metadata for UI-selected file: {file_name} "
f"(mimeType={metadata.get('mimeType', 'unknown')})"
)
else:
# Fallback if metadata fetch fails - content-based detection will handle it
logger.warning(
f"Could not fetch metadata for file {file_name}: {meta_error}. "
f"Falling back to content-based detection."
)
all_files.append(
{
"id": file_id,
"name": file_name,
"mimeType": "", # Content-based detection will handle this
}
)
else: else:
# No selection specified - fetch all files (original behavior) # No selection specified - fetch all files (original behavior)
page_token = None page_token = None
@ -1128,8 +1214,10 @@ async def _process_single_drive_file(
session, unique_identifier_hash session, unique_identifier_hash
) )
# Get file content # Get file content (pass mime_type for Google Workspace export handling)
content, content_error = await composio_connector.get_drive_file_content(file_id) content, content_error = await composio_connector.get_drive_file_content(
file_id, original_mime_type=mime_type
)
if content_error or not content: if content_error or not content:
logger.warning(f"Could not get content for file {file_name}: {content_error}") logger.warning(f"Could not get content for file {file_name}: {content_error}")
@ -1248,7 +1336,6 @@ async def _process_single_drive_file(
"file_name": file_name, "file_name": file_name,
"FILE_NAME": file_name, # For compatibility "FILE_NAME": file_name, # For compatibility
"mime_type": mime_type, "mime_type": mime_type,
"connector_id": connector_id,
"toolkit_id": "googledrive", "toolkit_id": "googledrive",
"source": "composio", "source": "composio",
}, },
@ -1258,6 +1345,8 @@ async def _process_single_drive_file(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -87,6 +87,14 @@ class ConfluenceHistoryConnector:
if is_oauth: if is_oauth:
# OAuth 2.0 authentication # OAuth 2.0 authentication
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Confluence access token not found. "
"Please reconnect your Confluence account."
)
# Decrypt credentials if they are encrypted # Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False) token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY: if token_encrypted and config.SECRET_KEY:
@ -118,6 +126,14 @@ class ConfluenceHistoryConnector:
f"Failed to decrypt Confluence credentials: {e!s}" f"Failed to decrypt Confluence credentials: {e!s}"
) from e ) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Confluence access token is invalid or empty. "
"Please reconnect your Confluence account."
)
try: try:
self._credentials = AtlassianAuthCredentialsBase.from_dict( self._credentials = AtlassianAuthCredentialsBase.from_dict(
config_data config_data

View file

@ -25,6 +25,7 @@ async def download_and_process_file(
session: AsyncSession, session: AsyncSession,
task_logger: TaskLoggingService, task_logger: TaskLoggingService,
log_entry: Log, log_entry: Log,
connector_id: int | None = None,
) -> tuple[Any, str | None, dict[str, Any] | None]: ) -> tuple[Any, str | None, dict[str, Any] | None]:
""" """
Download Google Drive file and process using Surfsense file processors. Download Google Drive file and process using Surfsense file processors.
@ -37,6 +38,7 @@ async def download_and_process_file(
session: Database session session: Database session
task_logger: Task logging service task_logger: Task logging service
log_entry: Log entry for tracking log_entry: Log entry for tracking
connector_id: ID of the connector (for de-indexing support)
Returns: Returns:
Tuple of (Document object if successful, error message if failed, file metadata dict) Tuple of (Document object if successful, error message if failed, file metadata dict)
@ -92,6 +94,9 @@ async def download_and_process_file(
"source_connector": "google_drive", "source_connector": "google_drive",
}, },
} }
# Include connector_id for de-indexing support
if connector_id is not None:
connector_info["connector_id"] = connector_id
# Add additional Drive metadata if available # Add additional Drive metadata if available
if "modifiedTime" in file: if "modifiedTime" in file:

View file

@ -127,7 +127,12 @@ async def get_valid_credentials(
) )
creds_dict["_token_encrypted"] = True creds_dict["_token_encrypted"] = True
connector.config = creds_dict # IMPORTANT: Merge new credentials with existing config to preserve
# user settings like selected_folders, selected_files, indexing_options,
# folder_tokens, etc. that would otherwise be wiped on token refresh.
existing_config = connector.config.copy() if connector.config else {}
existing_config.update(creds_dict)
connector.config = existing_config
flag_modified(connector, "config") flag_modified(connector, "config")
await session.commit() await session.commit()

View file

@ -86,6 +86,14 @@ class JiraHistoryConnector:
if is_oauth: if is_oauth:
# OAuth 2.0 authentication # OAuth 2.0 authentication
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Jira access token not found. "
"Please reconnect your Jira account."
)
if not config.SECRET_KEY: if not config.SECRET_KEY:
raise ValueError( raise ValueError(
"SECRET_KEY not configured but tokens are marked as encrypted" "SECRET_KEY not configured but tokens are marked as encrypted"
@ -119,6 +127,14 @@ class JiraHistoryConnector:
f"Failed to decrypt Jira credentials: {e!s}" f"Failed to decrypt Jira credentials: {e!s}"
) from e ) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Jira access token is invalid or empty. "
"Please reconnect your Jira account."
)
try: try:
self._credentials = AtlassianAuthCredentialsBase.from_dict( self._credentials = AtlassianAuthCredentialsBase.from_dict(
config_data config_data

View file

@ -116,6 +116,14 @@ class LinearConnector:
config_data = connector.config.copy() config_data = connector.config.copy()
# Check if access_token exists before processing
raw_access_token = config_data.get("access_token")
if not raw_access_token:
raise ValueError(
"Linear access token not found. "
"Please reconnect your Linear account."
)
# Decrypt credentials if they are encrypted # Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False) token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY: if token_encrypted and config.SECRET_KEY:
@ -143,6 +151,14 @@ class LinearConnector:
f"Failed to decrypt Linear credentials: {e!s}" f"Failed to decrypt Linear credentials: {e!s}"
) from e ) from e
# Final validation after decryption
final_token = config_data.get("access_token")
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
raise ValueError(
"Linear access token is invalid or empty. "
"Please reconnect your Linear account."
)
try: try:
self._credentials = LinearAuthCredentialsBase.from_dict(config_data) self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
except Exception as e: except Exception as e:

View file

@ -164,6 +164,7 @@ class IncentiveTaskType(str, Enum):
GITHUB_STAR = "GITHUB_STAR" GITHUB_STAR = "GITHUB_STAR"
REDDIT_FOLLOW = "REDDIT_FOLLOW" REDDIT_FOLLOW = "REDDIT_FOLLOW"
DISCORD_JOIN = "DISCORD_JOIN"
# Future tasks can be added here: # Future tasks can be added here:
# GITHUB_ISSUE = "GITHUB_ISSUE" # GITHUB_ISSUE = "GITHUB_ISSUE"
# SOCIAL_SHARE = "SOCIAL_SHARE" # SOCIAL_SHARE = "SOCIAL_SHARE"
@ -185,6 +186,12 @@ INCENTIVE_TASKS_CONFIG = {
"pages_reward": 100, "pages_reward": 100,
"action_url": "https://www.reddit.com/r/SurfSense/", "action_url": "https://www.reddit.com/r/SurfSense/",
}, },
IncentiveTaskType.DISCORD_JOIN: {
"title": "Join our Discord",
"description": "Join the SurfSense community on Discord",
"pages_reward": 100,
"action_url": "https://discord.gg/ejRNvftDp9",
},
# Future tasks can be configured here: # Future tasks can be configured here:
# IncentiveTaskType.GITHUB_ISSUE: { # IncentiveTaskType.GITHUB_ISSUE: {
# "title": "Create an issue", # "title": "Create an issue",
@ -257,6 +264,11 @@ class Permission(str, Enum):
SETTINGS_UPDATE = "settings:update" SETTINGS_UPDATE = "settings:update"
SETTINGS_DELETE = "settings:delete" # Delete the entire search space SETTINGS_DELETE = "settings:delete" # Delete the entire search space
# Public Sharing
PUBLIC_SHARING_VIEW = "public_sharing:view"
PUBLIC_SHARING_CREATE = "public_sharing:create"
PUBLIC_SHARING_DELETE = "public_sharing:delete"
# Full access wildcard # Full access wildcard
FULL_ACCESS = "*" FULL_ACCESS = "*"
@ -299,6 +311,9 @@ DEFAULT_ROLE_PERMISSIONS = {
Permission.ROLES_READ.value, Permission.ROLES_READ.value,
# Settings (view only, no update or delete) # Settings (view only, no update or delete)
Permission.SETTINGS_VIEW.value, Permission.SETTINGS_VIEW.value,
# Public Sharing (can create and view, no delete)
Permission.PUBLIC_SHARING_VIEW.value,
Permission.PUBLIC_SHARING_CREATE.value,
], ],
"Viewer": [ "Viewer": [
# Documents (read only) # Documents (read only)
@ -322,6 +337,8 @@ DEFAULT_ROLE_PERMISSIONS = {
Permission.ROLES_READ.value, Permission.ROLES_READ.value,
# Settings (view only) # Settings (view only)
Permission.SETTINGS_VIEW.value, Permission.SETTINGS_VIEW.value,
# Public Sharing (view only)
Permission.PUBLIC_SHARING_VIEW.value,
], ],
} }
@ -751,7 +768,27 @@ class Document(BaseModel, TimestampMixin):
search_space_id = Column( search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
) )
# Track who created/uploaded this document
created_by_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="SET NULL"),
nullable=True, # Nullable for backward compatibility with existing records
index=True,
)
# Track which connector created this document (for cleanup on connector deletion)
connector_id = Column(
Integer,
ForeignKey("search_source_connectors.id", ondelete="SET NULL"),
nullable=True, # Nullable for manually uploaded docs without connector
index=True,
)
# Relationships
search_space = relationship("SearchSpace", back_populates="documents") search_space = relationship("SearchSpace", back_populates="documents")
created_by = relationship("User", back_populates="documents")
connector = relationship("SearchSourceConnector", back_populates="documents")
chunks = relationship( chunks = relationship(
"Chunk", back_populates="document", cascade="all, delete-orphan" "Chunk", back_populates="document", cascade="all, delete-orphan"
) )
@ -980,6 +1017,9 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
) )
# Documents created by this connector (for cleanup on connector deletion)
documents = relationship("Document", back_populates="connector")
class NewLLMConfig(BaseModel, TimestampMixin): class NewLLMConfig(BaseModel, TimestampMixin):
""" """
@ -1286,6 +1326,13 @@ if config.AUTH_TYPE == "GOOGLE":
passive_deletes=True, passive_deletes=True,
) )
# Documents created/uploaded by this user
documents = relationship(
"Document",
back_populates="created_by",
passive_deletes=True,
)
# User memories for personalized AI responses # User memories for personalized AI responses
memories = relationship( memories = relationship(
"UserMemory", "UserMemory",
@ -1314,6 +1361,13 @@ if config.AUTH_TYPE == "GOOGLE":
display_name = Column(String, nullable=True) display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True) avatar_url = Column(String, nullable=True)
# Refresh tokens for this user
refresh_tokens = relationship(
"RefreshToken",
back_populates="user",
cascade="all, delete-orphan",
)
else: else:
class User(SQLAlchemyBaseUserTableUUID, Base): class User(SQLAlchemyBaseUserTableUUID, Base):
@ -1344,6 +1398,13 @@ else:
passive_deletes=True, passive_deletes=True,
) )
# Documents created/uploaded by this user
documents = relationship(
"Document",
back_populates="created_by",
passive_deletes=True,
)
# User memories for personalized AI responses # User memories for personalized AI responses
memories = relationship( memories = relationship(
"UserMemory", "UserMemory",
@ -1372,6 +1433,43 @@ else:
display_name = Column(String, nullable=True) display_name = Column(String, nullable=True)
avatar_url = Column(String, nullable=True) avatar_url = Column(String, nullable=True)
# Refresh tokens for this user
refresh_tokens = relationship(
"RefreshToken",
back_populates="user",
cascade="all, delete-orphan",
)
class RefreshToken(Base, TimestampMixin):
"""
Stores refresh tokens for user session management.
Each row represents one device/session.
"""
__tablename__ = "refresh_tokens"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user = relationship("User", back_populates="refresh_tokens")
token_hash = Column(String(256), unique=True, nullable=False, index=True)
expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
is_revoked = Column(Boolean, default=False, nullable=False)
family_id = Column(UUID(as_uuid=True), nullable=False, index=True)
@property
def is_expired(self) -> bool:
return datetime.now(UTC) >= self.expires_at
@property
def is_valid(self) -> bool:
return not self.is_expired and not self.is_revoked
engine = create_async_engine(DATABASE_URL) engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False) async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -104,3 +104,33 @@ SUMMARY_PROMPT = (
SUMMARY_PROMPT_TEMPLATE = PromptTemplate( SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
input_variables=["document"], template=SUMMARY_PROMPT input_variables=["document"], template=SUMMARY_PROMPT
) )
# =============================================================================
# Chat Title Generation Prompt
# =============================================================================
TITLE_GENERATION_PROMPT = """Generate a concise, descriptive title for the following conversation.
<rules>
- The title MUST be between 1 and 6 words
- The title MUST be on a single line
- Capture the main topic or intent of the conversation
- Do NOT use quotes, punctuation, or formatting
- Do NOT include words like "Chat about" or "Discussion of"
- Return ONLY the title, nothing else
</rules>
<user_query>
{user_query}
</user_query>
<assistant_response>
{assistant_response}
</assistant_response>
Title:"""
TITLE_GENERATION_PROMPT_TEMPLATE = PromptTemplate(
input_variables=["user_query", "assistant_response"],
template=TITLE_GENERATION_PROMPT,
)

View file

@ -0,0 +1,93 @@
"""Authentication routes for refresh token management."""
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from app.db import User, async_session_maker
from app.schemas.auth import (
LogoutAllResponse,
LogoutRequest,
LogoutResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from app.users import current_active_user, get_jwt_strategy
from app.utils.refresh_tokens import (
revoke_all_user_tokens,
revoke_refresh_token,
rotate_refresh_token,
validate_refresh_token,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
@router.post("/refresh", response_model=RefreshTokenResponse)
async def refresh_access_token(request: RefreshTokenRequest):
"""
Exchange a valid refresh token for a new access token and refresh token.
Implements token rotation for security.
"""
token_record = await validate_refresh_token(request.refresh_token)
if not token_record:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired refresh token",
)
# Get user from token record
async with async_session_maker() as session:
result = await session.execute(
select(User).where(User.id == token_record.user_id)
)
user = result.scalars().first()
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
)
# Generate new access token
strategy = get_jwt_strategy()
access_token = await strategy.write_token(user)
# Rotate refresh token
new_refresh_token = await rotate_refresh_token(token_record)
logger.info(f"Refreshed token for user {user.id}")
return RefreshTokenResponse(
access_token=access_token,
refresh_token=new_refresh_token,
)
@router.post("/revoke", response_model=LogoutResponse)
async def revoke_token(request: LogoutRequest):
"""
Logout current device by revoking the provided refresh token.
Does not require authentication - just the refresh token.
"""
revoked = await revoke_refresh_token(request.refresh_token)
if revoked:
logger.info("User logged out from current device - token revoked")
else:
logger.warning("Logout called but no matching token found to revoke")
return LogoutResponse()
@router.post("/logout-all", response_model=LogoutAllResponse)
async def logout_all_devices(user: User = Depends(current_active_user)):
"""
Logout from all devices by revoking all refresh tokens for the user.
Requires valid access token.
"""
await revoke_all_user_tokens(user.id)
logger.info(f"User {user.id} logged out from all devices")
return LogoutAllResponse()

View file

@ -9,8 +9,12 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import SearchSourceConnector, SearchSourceConnectorType, get_async_session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -212,6 +216,7 @@ def format_circleback_meeting_to_markdown(payload: CirclebackWebhookPayload) ->
async def receive_circleback_webhook( async def receive_circleback_webhook(
search_space_id: int, search_space_id: int,
payload: CirclebackWebhookPayload, payload: CirclebackWebhookPayload,
session: AsyncSession = Depends(get_async_session),
): ):
""" """
Receive and process a Circleback webhook. Receive and process a Circleback webhook.
@ -223,6 +228,7 @@ async def receive_circleback_webhook(
Args: Args:
search_space_id: The ID of the search space to save the document to search_space_id: The ID of the search space to save the document to
payload: The Circleback webhook payload containing meeting data payload: The Circleback webhook payload containing meeting data
session: Database session for looking up the connector
Returns: Returns:
Success message with document details Success message with document details
@ -236,6 +242,26 @@ async def receive_circleback_webhook(
f"Received Circleback webhook for meeting {payload.id} in search space {search_space_id}" f"Received Circleback webhook for meeting {payload.id} in search space {search_space_id}"
) )
# Look up the Circleback connector for this search space
connector_result = await session.execute(
select(SearchSourceConnector.id).where(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CIRCLEBACK_CONNECTOR,
)
)
connector_id = connector_result.scalar_one_or_none()
if connector_id:
logger.info(
f"Found Circleback connector {connector_id} for search space {search_space_id}"
)
else:
logger.warning(
f"No Circleback connector found for search space {search_space_id}. "
"Document will be created without connector_id."
)
# Convert to markdown # Convert to markdown
markdown_content = format_circleback_meeting_to_markdown(payload) markdown_content = format_circleback_meeting_to_markdown(payload)
@ -264,6 +290,7 @@ async def receive_circleback_webhook(
markdown_content=markdown_content, markdown_content=markdown_content,
metadata=meeting_metadata, metadata=meeting_metadata,
search_space_id=search_space_id, search_space_id=search_space_id,
connector_id=connector_id,
) )
logger.info( logger.info(

View file

@ -20,6 +20,7 @@ from pydantic import ValidationError
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.config import config from app.config import config
from app.db import ( from app.db import (
@ -41,10 +42,6 @@ from app.utils.connector_naming import (
) )
from app.utils.oauth_security import OAuthStateManager from app.utils.oauth_security import OAuthStateManager
# Note: We no longer use check_duplicate_connector for Composio connectors because
# Composio generates a new connected_account_id each time, even for the same Google account.
# Instead, we check for existing connectors by type/space/user and update them.
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@ -255,11 +252,6 @@ async def composio_callback(
"connectedAccountId" "connectedAccountId"
) or query_params.get("connected_account_id") ) or query_params.get("connected_account_id")
# DEBUG: Log query parameter received
logger.info(
f"DEBUG: Callback received - connectedAccountId: {query_params.get('connectedAccountId')}, connected_account_id: {query_params.get('connected_account_id')}, using: {final_connected_account_id}"
)
# If we still don't have a connected_account_id, warn but continue # If we still don't have a connected_account_id, warn but continue
# (the connector will be created but indexing won't work until updated) # (the connector will be created but indexing won't work until updated)
if not final_connected_account_id: if not final_connected_account_id:
@ -272,6 +264,9 @@ async def composio_callback(
f"Successfully got connected_account_id: {final_connected_account_id}" f"Successfully got connected_account_id: {final_connected_account_id}"
) )
# Build entity_id for Composio API calls (same format as used in initiate)
entity_id = f"surfsense_{user_id}"
# Build connector config # Build connector config
connector_config = { connector_config = {
"composio_connected_account_id": final_connected_account_id, "composio_connected_account_id": final_connected_account_id,
@ -289,20 +284,51 @@ async def composio_callback(
) )
connector_type = SearchSourceConnectorType(connector_type_str) connector_type = SearchSourceConnectorType(connector_type_str)
# Check for existing connector of the same type for this user/space # Get the base name for this connector type (e.g., "Google Drive", "Gmail")
# When reconnecting, Composio gives a new connected_account_id, so we need to base_name = get_base_name_for_type(connector_type)
# check by connector_type, user_id, and search_space_id instead of connected_account_id
# FIRST: Get the email for this connected account
# This is needed to determine if it's a reconnection (same email) or new account
email = None
try:
email = await service.get_connected_account_email(
connected_account_id=final_connected_account_id,
entity_id=entity_id,
toolkit_id=toolkit_id,
)
if email:
logger.info(f"Retrieved email {email} for {toolkit_id} connector")
except Exception as email_error:
logger.warning(f"Could not get email for connector: {email_error!s}")
# Generate the connector name (with email if available)
# Format: "Gmail (Composio) - john@gmail.com" or "Gmail (Composio) 1" if no email
if email:
connector_name = f"{base_name} (Composio) - {email}"
else:
# Fallback to generic naming if email not available
count = await count_connectors_of_type(
session, connector_type, space_id, user_id
)
if count == 0:
connector_name = f"{base_name} (Composio) 1"
else:
connector_name = f"{base_name} (Composio) {count + 1}"
# Check if a connector with this SAME name already exists (reconnection case)
# This allows multiple accounts (different emails) while supporting reconnection
existing_connector_result = await session.execute( existing_connector_result = await session.execute(
select(SearchSourceConnector).where( select(SearchSourceConnector).where(
SearchSourceConnector.connector_type == connector_type, SearchSourceConnector.connector_type == connector_type,
SearchSourceConnector.search_space_id == space_id, SearchSourceConnector.search_space_id == space_id,
SearchSourceConnector.user_id == user_id, SearchSourceConnector.user_id == user_id,
SearchSourceConnector.name == connector_name,
) )
) )
existing_connector = existing_connector_result.scalars().first() existing_connector = existing_connector_result.scalars().first()
if existing_connector: if existing_connector:
# Delete the old Composio connected account before updating # This is a RECONNECTION of the same account - update existing connector
old_connected_account_id = existing_connector.config.get( old_connected_account_id = existing_connector.config.get(
"composio_connected_account_id" "composio_connected_account_id"
) )
@ -319,46 +345,37 @@ async def composio_callback(
f"Deleted old Composio connected account {old_connected_account_id} " f"Deleted old Composio connected account {old_connected_account_id} "
f"before updating connector {existing_connector.id}" f"before updating connector {existing_connector.id}"
) )
else:
logger.warning(
f"Failed to delete old Composio connected account {old_connected_account_id}"
)
except Exception as delete_error: except Exception as delete_error:
# Log but don't fail - the old account may already be deleted
logger.warning( logger.warning(
f"Error deleting old Composio connected account {old_connected_account_id}: {delete_error!s}" f"Error deleting old Composio connected account {old_connected_account_id}: {delete_error!s}"
) )
# Update existing connector with new connected_account_id # Update existing connector with new connected_account_id
# Merge new credentials with existing config to preserve user settings
logger.info( logger.info(
f"Updating existing Composio connector {existing_connector.id} with new connected_account_id {final_connected_account_id}" f"Reconnecting existing Composio connector {existing_connector.id} ({connector_name}) "
f"with new connected_account_id {final_connected_account_id}"
) )
existing_connector.config = connector_config existing_config = (
existing_connector.config.copy() if existing_connector.config else {}
)
existing_config.update(connector_config)
existing_connector.config = existing_config
flag_modified(existing_connector, "config")
await session.commit() await session.commit()
await session.refresh(existing_connector) await session.refresh(existing_connector)
# Get the frontend connector ID based on toolkit_id
frontend_connector_id = TOOLKIT_TO_FRONTEND_CONNECTOR_ID.get( frontend_connector_id = TOOLKIT_TO_FRONTEND_CONNECTOR_ID.get(
toolkit_id, "composio-connector" toolkit_id, "composio-connector"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={existing_connector.id}&view=configure"
) )
# This is a NEW account - create a new connector
try: try:
# Count existing connectors of this type to determine the number logger.info(f"Creating new Composio connector: {connector_name}")
count = await count_connectors_of_type(
session, connector_type, space_id, user_id
)
# Generate base name (e.g., "Gmail", "Google Drive")
base_name = get_base_name_for_type(connector_type)
# Format: "Gmail (Composio) 1", "Gmail (Composio) 2", etc.
if count == 0:
connector_name = f"{base_name} (Composio) 1"
else:
connector_name = f"{base_name} (Composio) {count + 1}"
db_connector = SearchSourceConnector( db_connector = SearchSourceConnector(
name=connector_name, name=connector_name,
@ -382,7 +399,7 @@ async def composio_callback(
toolkit_id, "composio-connector" toolkit_id, "composio-connector"
) )
return RedirectResponse( return RedirectResponse(
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={db_connector.id}" url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector={frontend_connector_id}&connectorId={db_connector.id}&view=configure"
) )
except IntegrityError as e: except IntegrityError as e:

View file

@ -45,9 +45,9 @@ from app.schemas.new_chat import (
NewChatThreadUpdate, NewChatThreadUpdate,
NewChatThreadVisibilityUpdate, NewChatThreadVisibilityUpdate,
NewChatThreadWithMessages, NewChatThreadWithMessages,
PublicChatSnapshotCreateResponse,
PublicChatSnapshotListResponse,
RegenerateRequest, RegenerateRequest,
SnapshotCreateResponse,
SnapshotListResponse,
ThreadHistoryLoadResponse, ThreadHistoryLoadResponse,
ThreadListItem, ThreadListItem,
ThreadListResponse, ThreadListResponse,
@ -736,10 +736,11 @@ async def update_thread_visibility(
# ============================================================================= # =============================================================================
@router.post("/threads/{thread_id}/snapshots", response_model=SnapshotCreateResponse) @router.post(
"/threads/{thread_id}/snapshots", response_model=PublicChatSnapshotCreateResponse
)
async def create_thread_snapshot( async def create_thread_snapshot(
thread_id: int, thread_id: int,
request: Request,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
@ -747,23 +748,21 @@ async def create_thread_snapshot(
Create a public snapshot of the thread. Create a public snapshot of the thread.
Returns existing snapshot URL if content unchanged (deduplication). Returns existing snapshot URL if content unchanged (deduplication).
Only the thread owner can create snapshots.
""" """
from app.services.public_chat_service import create_snapshot from app.services.public_chat_service import create_snapshot
base_url = str(request.base_url).rstrip("/")
return await create_snapshot( return await create_snapshot(
session=session, session=session,
thread_id=thread_id, thread_id=thread_id,
user=user, user=user,
base_url=base_url,
) )
@router.get("/threads/{thread_id}/snapshots", response_model=SnapshotListResponse) @router.get(
"/threads/{thread_id}/snapshots", response_model=PublicChatSnapshotListResponse
)
async def list_thread_snapshots( async def list_thread_snapshots(
thread_id: int, thread_id: int,
request: Request,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
@ -774,13 +773,11 @@ async def list_thread_snapshots(
""" """
from app.services.public_chat_service import list_snapshots_for_thread from app.services.public_chat_service import list_snapshots_for_thread
base_url = str(request.base_url).rstrip("/") return PublicChatSnapshotListResponse(
return SnapshotListResponse(
snapshots=await list_snapshots_for_thread( snapshots=await list_snapshots_for_thread(
session=session, session=session,
thread_id=thread_id, thread_id=thread_id,
user=user, user=user,
base_url=base_url,
) )
) )
@ -889,30 +886,8 @@ async def append_message(
# Update thread's updated_at timestamp # Update thread's updated_at timestamp
thread.updated_at = datetime.now(UTC) thread.updated_at = datetime.now(UTC)
# Auto-generate title from first user message if title is still default # Note: Title generation now happens in stream_new_chat.py after the first response
if thread.title == "New Chat" and role_str == "user": # using LLM to generate a descriptive title (with truncation as fallback)
# Extract text content for title
content = message.content
if isinstance(content, str):
title_text = content
elif isinstance(content, list):
# Find first text content
title_text = ""
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
title_text = part.get("text", "")
break
elif isinstance(part, str):
title_text = part
break
else:
title_text = str(content)
# Truncate title
if title_text:
thread.title = title_text[:100] + (
"..." if len(title_text) > 100 else ""
)
await session.commit() await session.commit()
await session.refresh(db_message) await session.refresh(db_message)

View file

@ -76,6 +76,7 @@ async def create_note(
document_metadata={"NOTE": True}, document_metadata={"NOTE": True},
embedding=None, # Will be generated on first reindex embedding=None, # Will be generated on first reindex
updated_at=datetime.now(UTC), updated_at=datetime.now(UTC),
created_by_id=user.id, # Track who created this note
) )
session.add(document) session.add(document)
@ -93,6 +94,7 @@ async def create_note(
search_space_id=document.search_space_id, search_space_id=document.search_space_id,
created_at=document.created_at, created_at=document.created_at,
updated_at=document.updated_at, updated_at=document.updated_at,
created_by_id=document.created_by_id,
) )

View file

@ -91,7 +91,10 @@ def get_heartbeat_redis_client() -> redis.Redis:
"""Get or create Redis client for heartbeat tracking.""" """Get or create Redis client for heartbeat tracking."""
global _heartbeat_redis_client global _heartbeat_redis_client
if _heartbeat_redis_client is None: if _heartbeat_redis_client is None:
redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") redis_url = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
)
_heartbeat_redis_client = redis.from_url(redis_url, decode_responses=True) _heartbeat_redis_client = redis.from_url(redis_url, decode_responses=True)
return _heartbeat_redis_client return _heartbeat_redis_client
@ -524,9 +527,17 @@ async def delete_search_source_connector(
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
""" """
Delete a search source connector. Delete a search source connector and all its associated documents.
The deletion runs in background via Celery task. User is notified
via the notification system when complete (no polling required).
Requires CONNECTORS_DELETE permission. Requires CONNECTORS_DELETE permission.
""" """
from app.tasks.celery_tasks.connector_deletion_task import (
delete_connector_with_documents_task,
)
try: try:
# Get the connector first # Get the connector first
result = await session.execute( result = await session.execute(
@ -548,7 +559,12 @@ async def delete_search_source_connector(
"You don't have permission to delete this connector", "You don't have permission to delete this connector",
) )
# Delete any periodic schedule associated with this connector # Store connector info before we queue the deletion task
connector_name = db_connector.name
connector_type = db_connector.connector_type.value
search_space_id = db_connector.search_space_id
# Delete any periodic schedule associated with this connector (lightweight, sync)
if db_connector.periodic_indexing_enabled: if db_connector.periodic_indexing_enabled:
success = delete_periodic_schedule(connector_id) success = delete_periodic_schedule(connector_id)
if not success: if not success:
@ -556,7 +572,7 @@ async def delete_search_source_connector(
f"Failed to delete periodic schedule for connector {connector_id}" f"Failed to delete periodic schedule for connector {connector_id}"
) )
# For Composio connectors, also delete the connected account in Composio # For Composio connectors, delete the connected account in Composio (lightweight API call, sync)
composio_connector_types = [ composio_connector_types = [
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR, SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
@ -588,16 +604,33 @@ async def delete_search_source_connector(
f"Error deleting Composio connected account {composio_connected_account_id}: {composio_error!s}" f"Error deleting Composio connected account {composio_connected_account_id}: {composio_error!s}"
) )
await session.delete(db_connector) # Queue background task to delete documents and connector
await session.commit() # This handles potentially large document counts without blocking the API
return {"message": "Search source connector deleted successfully"} delete_connector_with_documents_task.delay(
connector_id=connector_id,
user_id=str(user.id),
search_space_id=search_space_id,
connector_name=connector_name,
connector_type=connector_type,
)
logger.info(
f"Queued deletion task for connector {connector_id} ({connector_name})"
)
return {
"message": "Connector deletion started. You will be notified when complete.",
"status": "queued",
"connector_id": connector_id,
"connector_name": connector_name,
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail=f"Failed to delete search source connector: {e!s}", detail=f"Failed to start connector deletion: {e!s}",
) from e ) from e

View file

@ -501,3 +501,25 @@ async def update_llm_preferences(
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to update LLM preferences: {e!s}" status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
) from e ) from e
@router.get("/searchspaces/{search_space_id}/snapshots")
async def list_search_space_snapshots(
search_space_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
List all public chat snapshots for a search space.
Requires PUBLIC_SHARING_VIEW permission.
"""
from app.schemas.new_chat import PublicChatSnapshotsBySpaceResponse
from app.services.public_chat_service import list_snapshots_for_search_space
snapshots = await list_snapshots_for_search_space(
session=session,
search_space_id=search_space_id,
user=user,
)
return PublicChatSnapshotsBySpaceResponse(snapshots=snapshots)

View file

@ -1,3 +1,10 @@
from .auth import (
LogoutAllResponse,
LogoutRequest,
LogoutResponse,
RefreshTokenRequest,
RefreshTokenResponse,
)
from .base import IDModel, TimestampModel from .base import IDModel, TimestampModel
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
from .documents import ( from .documents import (
@ -117,6 +124,10 @@ __all__ = [
"LogFilter", "LogFilter",
"LogRead", "LogRead",
"LogUpdate", "LogUpdate",
# Auth schemas
"LogoutAllResponse",
"LogoutRequest",
"LogoutResponse",
# Search source connector schemas # Search source connector schemas
"MCPConnectorCreate", "MCPConnectorCreate",
"MCPConnectorRead", "MCPConnectorRead",
@ -146,6 +157,8 @@ __all__ = [
"PodcastCreate", "PodcastCreate",
"PodcastRead", "PodcastRead",
"PodcastUpdate", "PodcastUpdate",
"RefreshTokenRequest",
"RefreshTokenResponse",
"RoleCreate", "RoleCreate",
"RoleRead", "RoleRead",
"RoleUpdate", "RoleUpdate",

View file

@ -0,0 +1,35 @@
"""Authentication schemas for refresh token endpoints."""
from pydantic import BaseModel
class RefreshTokenRequest(BaseModel):
"""Request body for token refresh endpoint."""
refresh_token: str
class RefreshTokenResponse(BaseModel):
"""Response from token refresh endpoint."""
access_token: str
refresh_token: str
token_type: str = "bearer"
class LogoutRequest(BaseModel):
"""Request body for logout endpoint (current device)."""
refresh_token: str
class LogoutResponse(BaseModel):
"""Response from logout endpoint (current device)."""
detail: str = "Successfully logged out"
class LogoutAllResponse(BaseModel):
"""Response from logout all devices endpoint."""
detail: str = "Successfully logged out from all devices"

View file

@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
from typing import TypeVar from typing import TypeVar
from uuid import UUID
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -51,6 +52,7 @@ class DocumentRead(BaseModel):
created_at: datetime created_at: datetime
updated_at: datetime | None updated_at: datetime | None
search_space_id: int search_space_id: int
created_by_id: UUID | None = None # User who created/uploaded this document
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View file

@ -211,17 +211,17 @@ class RegenerateRequest(BaseModel):
# ============================================================================= # =============================================================================
class SnapshotCreateResponse(BaseModel): class PublicChatSnapshotCreateResponse(BaseModel):
"""Response after creating a public snapshot.""" """Response after creating a public chat snapshot."""
snapshot_id: int snapshot_id: int
share_token: str share_token: str
public_url: str public_url: str
is_new: bool # False if existing snapshot returned (same content) is_new: bool
class SnapshotInfo(BaseModel): class PublicChatSnapshotInfo(BaseModel):
"""Info about a single snapshot.""" """Info about a single public chat snapshot."""
id: int id: int
share_token: str share_token: str
@ -230,10 +230,28 @@ class SnapshotInfo(BaseModel):
message_count: int message_count: int
class SnapshotListResponse(BaseModel): class PublicChatSnapshotListResponse(BaseModel):
"""List of snapshots for a thread.""" """List of public chat snapshots for a thread."""
snapshots: list[SnapshotInfo] snapshots: list[PublicChatSnapshotInfo]
class PublicChatSnapshotDetail(BaseModel):
"""Public chat snapshot with thread context."""
id: int
share_token: str
public_url: str
created_at: datetime
message_count: int
thread_id: int
thread_title: str
class PublicChatSnapshotsBySpaceResponse(BaseModel):
"""List of public chat snapshots for a search space."""
snapshots: list[PublicChatSnapshotDetail]
# ============================================================================= # =============================================================================

View file

@ -5,7 +5,7 @@ Service layer for chat comments and mentions.
from uuid import UUID from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy import delete, select from sqlalchemy import delete, or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@ -103,6 +103,37 @@ async def process_mentions(
return mentions_map return mentions_map
async def get_comment_thread_participants(
session: AsyncSession,
parent_comment_id: int,
exclude_user_ids: set[UUID],
) -> list[UUID]:
"""
Get all unique authors in a comment thread (parent + replies), excluding specified users.
Args:
session: Database session
parent_comment_id: ID of the parent comment
exclude_user_ids: Set of user IDs to exclude (e.g., replier, mentioned users)
Returns:
List of user UUIDs who have participated in the thread
"""
query = select(ChatComment.author_id).where(
or_(
ChatComment.id == parent_comment_id,
ChatComment.parent_id == parent_comment_id,
),
ChatComment.author_id.isnot(None),
)
if exclude_user_ids:
query = query.where(ChatComment.author_id.notin_(list(exclude_user_ids)))
result = await session.execute(query.distinct())
return [row[0] for row in result.fetchall()]
async def get_comments_for_message( async def get_comments_for_message(
session: AsyncSession, session: AsyncSession,
message_id: int, message_id: int,
@ -436,6 +467,31 @@ async def create_reply(
search_space_id=search_space_id, search_space_id=search_space_id,
) )
# Notify thread participants (excluding replier and mentioned users)
mentioned_user_ids = set(mentions_map.keys())
exclude_ids = {user.id} | mentioned_user_ids
participants = await get_comment_thread_participants(
session, comment_id, exclude_ids
)
for participant_id in participants:
if participant_id in mentioned_user_ids:
continue
await NotificationService.comment_reply.notify_comment_reply(
session=session,
user_id=participant_id,
reply_id=reply.id,
parent_comment_id=comment_id,
message_id=parent_comment.message_id,
thread_id=thread.id,
thread_title=thread.title or "Untitled thread",
author_id=str(user.id),
author_name=author_name,
author_avatar_url=user.avatar_url,
author_email=user.email,
content_preview=content_preview[:200],
search_space_id=search_space_id,
)
author = AuthorResponse( author = AuthorResponse(
id=user.id, id=user.id,
display_name=user.display_name, display_name=user.display_name,

View file

@ -15,17 +15,6 @@ from app.config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Mapping of toolkit IDs to their Composio auth config IDs
# These use Composio's managed OAuth (no custom credentials needed)
COMPOSIO_TOOLKIT_AUTH_CONFIGS = {
"googledrive": "default", # Uses Composio's managed Google OAuth
"gmail": "default",
"googlecalendar": "default",
"slack": "default",
"notion": "default",
"github": "default",
}
# Mapping of toolkit IDs to their display names # Mapping of toolkit IDs to their display names
COMPOSIO_TOOLKIT_NAMES = { COMPOSIO_TOOLKIT_NAMES = {
"googledrive": "Google Drive", "googledrive": "Google Drive",
@ -234,134 +223,6 @@ class ComposioService:
logger.error(f"Failed to initiate Composio connection: {e!s}") logger.error(f"Failed to initiate Composio connection: {e!s}")
raise raise
async def get_connected_account(
self, connected_account_id: str
) -> dict[str, Any] | None:
"""
Get details of a connected account.
Args:
connected_account_id: The Composio connected account ID.
Returns:
Connected account details or None if not found.
"""
try:
# Pass connected_account_id as positional argument (not keyword)
account = self.client.connected_accounts.get(connected_account_id)
return {
"id": account.id,
"status": getattr(account, "status", None),
"toolkit": getattr(account, "toolkit", None),
"user_id": getattr(account, "user_id", None),
}
except Exception as e:
logger.error(
f"Failed to get connected account {connected_account_id}: {e!s}"
)
return None
async def list_all_connections(self) -> list[dict[str, Any]]:
"""
List ALL connected accounts (for debugging).
Returns:
List of all connected account details.
"""
try:
accounts_response = self.client.connected_accounts.list()
if hasattr(accounts_response, "items"):
accounts = accounts_response.items
elif hasattr(accounts_response, "__iter__"):
accounts = accounts_response
else:
logger.warning(
f"Unexpected accounts response type: {type(accounts_response)}"
)
return []
result = []
for acc in accounts:
toolkit_raw = getattr(acc, "toolkit", None)
toolkit_info = None
if toolkit_raw:
if isinstance(toolkit_raw, str):
toolkit_info = toolkit_raw
elif hasattr(toolkit_raw, "slug"):
toolkit_info = toolkit_raw.slug
elif hasattr(toolkit_raw, "name"):
toolkit_info = toolkit_raw.name
else:
toolkit_info = str(toolkit_raw)
result.append(
{
"id": acc.id,
"status": getattr(acc, "status", None),
"toolkit": toolkit_info,
"user_id": getattr(acc, "user_id", None),
}
)
return result
except Exception as e:
logger.error(f"Failed to list all connections: {e!s}")
return []
async def list_user_connections(self, user_id: str) -> list[dict[str, Any]]:
"""
List all connected accounts for a user.
Args:
user_id: The user's unique identifier.
Returns:
List of connected account details.
"""
try:
accounts_response = self.client.connected_accounts.list(user_id=user_id)
# Handle paginated response (may have .items attribute) or direct list
if hasattr(accounts_response, "items"):
accounts = accounts_response.items
elif hasattr(accounts_response, "__iter__"):
accounts = accounts_response
else:
logger.warning(
f"Unexpected accounts response type: {type(accounts_response)}"
)
return []
result = []
for acc in accounts:
# Extract toolkit info - might be string or object
toolkit_raw = getattr(acc, "toolkit", None)
toolkit_info = None
if toolkit_raw:
if isinstance(toolkit_raw, str):
toolkit_info = toolkit_raw
elif hasattr(toolkit_raw, "slug"):
toolkit_info = toolkit_raw.slug
elif hasattr(toolkit_raw, "name"):
toolkit_info = toolkit_raw.name
else:
toolkit_info = toolkit_raw
result.append(
{
"id": acc.id,
"status": getattr(acc, "status", None),
"toolkit": toolkit_info,
}
)
logger.info(f"Found {len(result)} connections for user {user_id}: {result}")
return result
except Exception as e:
logger.error(f"Failed to list connections for user {user_id}: {e!s}")
return []
async def delete_connected_account(self, connected_account_id: str) -> bool: async def delete_connected_account(self, connected_account_id: str) -> bool:
""" """
Delete a connected account from Composio. Delete a connected account from Composio.
@ -449,8 +310,11 @@ class ComposioService:
""" """
try: try:
# Composio uses snake_case for parameters # Composio uses snake_case for parameters
# IMPORTANT: Include 'fields' to ensure mimeType is returned in the response
# Without this, Google Drive API may not include mimeType for some files
params = { params = {
"page_size": min(page_size, 100), "page_size": min(page_size, 100),
"fields": "files(id,name,mimeType,modifiedTime,createdTime),nextPageToken",
} }
if folder_id: if folder_id:
# List contents of a specific folder (exclude shortcuts - we don't have access to them) # List contents of a specific folder (exclude shortcuts - we don't have access to them)
@ -498,7 +362,11 @@ class ComposioService:
return [], None, str(e) return [], None, str(e)
async def get_drive_file_content( async def get_drive_file_content(
self, connected_account_id: str, entity_id: str, file_id: str self,
connected_account_id: str,
entity_id: str,
file_id: str,
original_mime_type: str | None = None,
) -> tuple[bytes | None, str | None]: ) -> tuple[bytes | None, str | None]:
""" """
Download file content from Google Drive via Composio. Download file content from Google Drive via Composio.
@ -507,10 +375,13 @@ class ComposioService:
to a local directory, and the local file path is provided in the response. to a local directory, and the local file path is provided in the response.
Response includes: file_path, file_name, size fields. Response includes: file_path, file_name, size fields.
For Google Workspace files (Docs, Sheets, Slides), exports to PDF format.
Args: Args:
connected_account_id: Composio connected account ID. connected_account_id: Composio connected account ID.
entity_id: The entity/user ID that owns the connected account. entity_id: The entity/user ID that owns the connected account.
file_id: Google Drive file ID. file_id: Google Drive file ID.
original_mime_type: Original MIME type of the file (used to detect Google Workspace files).
Returns: Returns:
Tuple of (file content bytes, error message). Tuple of (file content bytes, error message).
@ -518,10 +389,19 @@ class ComposioService:
from pathlib import Path from pathlib import Path
try: try:
params = {"file_id": file_id}
# For Google Workspace files, explicitly export as PDF
# This ensures consistent behavior and proper binary detection
if original_mime_type and original_mime_type.startswith(
"application/vnd.google-apps."
):
params["mime_type"] = "application/pdf"
result = await self.execute_tool( result = await self.execute_tool(
connected_account_id=connected_account_id, connected_account_id=connected_account_id,
tool_name="GOOGLEDRIVE_DOWNLOAD_FILE", tool_name="GOOGLEDRIVE_DOWNLOAD_FILE",
params={"file_id": file_id}, params=params,
entity_id=entity_id, entity_id=entity_id,
) )
@ -651,6 +531,60 @@ class ComposioService:
logger.error(f"Failed to get Drive file content: {e!s}") logger.error(f"Failed to get Drive file content: {e!s}")
return None, str(e) return None, str(e)
async def get_file_metadata(
self, connected_account_id: str, entity_id: str, file_id: str
) -> tuple[dict[str, Any] | None, str | None]:
"""
Get metadata for a specific file from Google Drive.
Args:
connected_account_id: Composio connected account ID.
entity_id: The entity/user ID that owns the connected account.
file_id: The ID of the file to get metadata for.
Returns:
Tuple of (metadata dict, error message).
"""
try:
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLEDRIVE_GET_FILE_METADATA",
params={
"file_id": file_id,
"fields": "id,name,mimeType,modifiedTime,createdTime,size",
},
entity_id=entity_id,
)
if not result.get("success"):
return None, result.get("error", "Unknown error")
data = result.get("data", {})
# Handle nested response structure
if isinstance(data, dict):
inner_data = data.get("data", data)
if isinstance(inner_data, dict):
# Extract metadata fields with fallbacks for camelCase/snake_case
metadata = {
"id": inner_data.get("id") or file_id,
"name": inner_data.get("name", ""),
"mimeType": inner_data.get("mimeType")
or inner_data.get("mime_type", ""),
"modifiedTime": inner_data.get("modifiedTime")
or inner_data.get("modified_time", ""),
"createdTime": inner_data.get("createdTime")
or inner_data.get("created_time", ""),
"size": inner_data.get("size", ""),
}
return metadata, None
return None, "Could not extract metadata from response"
except Exception as e:
logger.error(f"Failed to get file metadata: {e!s}")
return None, str(e)
async def get_drive_start_page_token( async def get_drive_start_page_token(
self, connected_account_id: str, entity_id: str self, connected_account_id: str, entity_id: str
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
@ -945,6 +879,178 @@ class ComposioService:
logger.error(f"Failed to list Calendar events: {e!s}") logger.error(f"Failed to list Calendar events: {e!s}")
return [], str(e) return [], str(e)
# ===== User Info Methods =====
async def get_connected_account_email(
self,
connected_account_id: str,
entity_id: str,
toolkit_id: str,
) -> str | None:
"""
Get the email address associated with a connected account.
Uses toolkit-specific API calls:
- Google Drive: List files and extract owner email
- Gmail: Get user profile
- Google Calendar: List events and extract organizer/creator email
Args:
connected_account_id: Composio connected account ID.
entity_id: The entity/user ID that owns the connected account.
toolkit_id: The toolkit identifier (googledrive, gmail, googlecalendar).
Returns:
Email address string or None if not available.
"""
try:
email = await self._extract_email_for_toolkit(
connected_account_id, entity_id, toolkit_id
)
if email:
logger.info(f"Retrieved email {email} for {toolkit_id} connector")
else:
logger.warning(f"Could not retrieve email for {toolkit_id} connector")
return email
except Exception as e:
logger.error(f"Failed to get email for {toolkit_id} connector: {e!s}")
return None
async def _extract_email_for_toolkit(
self,
connected_account_id: str,
entity_id: str,
toolkit_id: str,
) -> str | None:
"""Extract email based on toolkit type."""
if toolkit_id == "googledrive":
return await self._get_drive_owner_email(connected_account_id, entity_id)
elif toolkit_id == "gmail":
return await self._get_gmail_profile_email(connected_account_id, entity_id)
elif toolkit_id == "googlecalendar":
return await self._get_calendar_user_email(connected_account_id, entity_id)
return None
async def _get_drive_owner_email(
self, connected_account_id: str, entity_id: str
) -> str | None:
"""Get email from Google Drive file owner where me=True."""
# List files owned by the user and find one where owner.me=True
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLEDRIVE_LIST_FILES",
params={
"page_size": 10,
"fields": "files(owners)",
"q": "'me' in owners", # Only files owned by current user
},
entity_id=entity_id,
)
if not result.get("success"):
return None
data = result.get("data", {})
if not isinstance(data, dict):
return None
files = data.get("files") or data.get("data", {}).get("files", [])
for file in files:
owners = file.get("owners", [])
for owner in owners:
# Only return email if this is the current user (me=True)
if owner.get("me") and owner.get("emailAddress"):
return owner.get("emailAddress")
return None
async def _get_gmail_profile_email(
self, connected_account_id: str, entity_id: str
) -> str | None:
"""Get email from Gmail profile."""
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GMAIL_GET_PROFILE",
params={},
entity_id=entity_id,
)
if not result.get("success"):
return None
data = result.get("data", {})
if not isinstance(data, dict):
return None
return data.get("emailAddress") or data.get("data", {}).get("emailAddress")
async def _get_calendar_user_email(
self, connected_account_id: str, entity_id: str
) -> str | None:
"""Get email from Google Calendar primary calendar or event organizer/creator."""
# Method 1: Get primary calendar - the "summary" field is the user's email
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLECALENDAR_GET_CALENDAR",
params={"calendar_id": "primary"},
entity_id=entity_id,
)
if result.get("success"):
data = result.get("data", {})
if isinstance(data, dict):
# Handle nested structure: data['data']['calendar_data']['summary']
calendar_data = (
data.get("data", {}).get("calendar_data", {})
if isinstance(data.get("data"), dict)
else {}
)
summary = (
calendar_data.get("summary")
or calendar_data.get("id")
or data.get("data", {}).get("summary")
or data.get("summary")
)
if summary and "@" in summary:
return summary
# Method 2: Fallback - list events to get calendar summary (owner's email)
result = await self.execute_tool(
connected_account_id=connected_account_id,
tool_name="GOOGLECALENDAR_EVENTS_LIST",
params={"max_results": 20},
entity_id=entity_id,
)
if not result.get("success"):
return None
data = result.get("data", {})
if not isinstance(data, dict):
return None
# The events list response contains 'summary' which is the calendar owner's email
nested_data = data.get("data", {}) if isinstance(data.get("data"), dict) else {}
summary = nested_data.get("summary") or data.get("summary")
if summary and "@" in summary:
return summary
# Method 3: Check event organizers/creators
items = nested_data.get("items", []) or data.get("items", [])
for event in items:
organizer = event.get("organizer", {})
if organizer.get("self"):
return organizer.get("email")
creator = event.get("creator", {})
if creator.get("self"):
return creator.get("email")
return None
# Singleton instance # Singleton instance
_composio_service: ComposioService | None = None _composio_service: ComposioService | None = None

View file

@ -479,6 +479,31 @@ class VercelStreamingService:
}, },
) )
def format_thread_title_update(self, thread_id: int, title: str) -> str:
"""
Format a thread title update notification (SurfSense specific).
This is sent after the first response in a thread to update the
auto-generated title based on the conversation content.
Args:
thread_id: The ID of the thread being updated
title: The new title for the thread
Returns:
str: SSE formatted thread title update data part
Example output:
data: {"type":"data-thread-title-update","data":{"threadId":123,"title":"New Title"}}
"""
return self.format_data(
"thread-title-update",
{
"threadId": thread_id,
"title": title,
},
)
# ========================================================================= # =========================================================================
# Error Part # Error Part
# ========================================================================= # =========================================================================

View file

@ -861,6 +861,98 @@ class MentionNotificationHandler(BaseNotificationHandler):
raise raise
class CommentReplyNotificationHandler(BaseNotificationHandler):
"""Handler for comment reply notifications."""
def __init__(self):
super().__init__("comment_reply")
async def find_notification_by_reply(
self,
session: AsyncSession,
reply_id: int,
user_id: UUID,
) -> Notification | None:
query = select(Notification).where(
Notification.type == self.notification_type,
Notification.user_id == user_id,
Notification.notification_metadata["reply_id"].astext == str(reply_id),
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def notify_comment_reply(
self,
session: AsyncSession,
user_id: UUID,
reply_id: int,
parent_comment_id: int,
message_id: int,
thread_id: int,
thread_title: str,
author_id: str,
author_name: str,
author_avatar_url: str | None,
author_email: str,
content_preview: str,
search_space_id: int,
) -> Notification:
existing = await self.find_notification_by_reply(session, reply_id, user_id)
if existing:
logger.info(
f"Notification already exists for reply {reply_id} to user {user_id}"
)
return existing
title = f"{author_name} replied in a thread"
message = content_preview[:100] + ("..." if len(content_preview) > 100 else "")
metadata = {
"reply_id": reply_id,
"parent_comment_id": parent_comment_id,
"message_id": message_id,
"thread_id": thread_id,
"thread_title": thread_title,
"author_id": author_id,
"author_name": author_name,
"author_avatar_url": author_avatar_url,
"author_email": author_email,
"content_preview": content_preview[:200],
}
try:
notification = Notification(
user_id=user_id,
search_space_id=search_space_id,
type=self.notification_type,
title=title,
message=message,
notification_metadata=metadata,
)
session.add(notification)
await session.commit()
await session.refresh(notification)
logger.info(
f"Created comment_reply notification {notification.id} for user {user_id}"
)
return notification
except Exception as e:
await session.rollback()
if (
"duplicate key" in str(e).lower()
or "unique constraint" in str(e).lower()
):
logger.warning(
f"Duplicate notification for reply {reply_id} to user {user_id}"
)
existing = await self.find_notification_by_reply(
session, reply_id, user_id
)
if existing:
return existing
raise
class PageLimitNotificationHandler(BaseNotificationHandler): class PageLimitNotificationHandler(BaseNotificationHandler):
"""Handler for page limit exceeded notifications.""" """Handler for page limit exceeded notifications."""
@ -959,6 +1051,7 @@ class NotificationService:
connector_indexing = ConnectorIndexingNotificationHandler() connector_indexing = ConnectorIndexingNotificationHandler()
document_processing = DocumentProcessingNotificationHandler() document_processing = DocumentProcessingNotificationHandler()
mention = MentionNotificationHandler() mention = MentionNotificationHandler()
comment_reply = CommentReplyNotificationHandler()
page_limit = PageLimitNotificationHandler() page_limit = PageLimitNotificationHandler()
@staticmethod @staticmethod

View file

@ -25,12 +25,14 @@ from app.db import (
ChatVisibility, ChatVisibility,
NewChatMessage, NewChatMessage,
NewChatThread, NewChatThread,
Permission,
Podcast, Podcast,
PodcastStatus, PodcastStatus,
PublicChatSnapshot, PublicChatSnapshot,
SearchSpaceMembership, SearchSpaceMembership,
User, User,
) )
from app.utils.rbac import check_permission
UI_TOOLS = { UI_TOOLS = {
"display_image", "display_image",
@ -159,7 +161,6 @@ async def create_snapshot(
session: AsyncSession, session: AsyncSession,
thread_id: int, thread_id: int,
user: User, user: User,
base_url: str,
) -> dict: ) -> dict:
""" """
Create a public snapshot of a chat thread. Create a public snapshot of a chat thread.
@ -167,6 +168,9 @@ async def create_snapshot(
Returns existing snapshot if content unchanged (same hash). Returns existing snapshot if content unchanged (same hash).
Returns new snapshot with unique URL if content changed. Returns new snapshot with unique URL if content changed.
""" """
from app.config import config
frontend_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")
result = await session.execute( result = await session.execute(
select(NewChatThread) select(NewChatThread)
.options(selectinload(NewChatThread.messages)) .options(selectinload(NewChatThread.messages))
@ -177,11 +181,13 @@ async def create_snapshot(
if not thread: if not thread:
raise HTTPException(status_code=404, detail="Thread not found") raise HTTPException(status_code=404, detail="Thread not found")
if thread.created_by_id != user.id: await check_permission(
raise HTTPException( session,
status_code=403, user,
detail="Only the creator of this chat can create public snapshots", thread.search_space_id,
) Permission.PUBLIC_SHARING_CREATE.value,
"You don't have permission to create public share links",
)
# Build snapshot data # Build snapshot data
user_cache: dict[UUID, dict] = {} user_cache: dict[UUID, dict] = {}
@ -246,7 +252,7 @@ async def create_snapshot(
return { return {
"snapshot_id": existing.id, "snapshot_id": existing.id,
"share_token": existing.share_token, "share_token": existing.share_token,
"public_url": f"{base_url}/public/{existing.share_token}", "public_url": f"{frontend_url}/public/{existing.share_token}",
"is_new": False, "is_new": False,
} }
@ -279,7 +285,7 @@ async def create_snapshot(
return { return {
"snapshot_id": snapshot.id, "snapshot_id": snapshot.id,
"share_token": snapshot.share_token, "share_token": snapshot.share_token,
"public_url": f"{base_url}/public/{snapshot.share_token}", "public_url": f"{frontend_url}/public/{snapshot.share_token}",
"is_new": True, "is_new": True,
} }
@ -348,10 +354,10 @@ async def list_snapshots_for_thread(
session: AsyncSession, session: AsyncSession,
thread_id: int, thread_id: int,
user: User, user: User,
base_url: str,
) -> list[dict]: ) -> list[dict]:
"""List all public snapshots for a thread.""" """List all public snapshots for a thread."""
# Verify ownership from app.config import config
result = await session.execute( result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id) select(NewChatThread).filter(NewChatThread.id == thread_id)
) )
@ -360,13 +366,15 @@ async def list_snapshots_for_thread(
if not thread: if not thread:
raise HTTPException(status_code=404, detail="Thread not found") raise HTTPException(status_code=404, detail="Thread not found")
if thread.created_by_id != user.id: # Check permission to view public share links
raise HTTPException( await check_permission(
status_code=403, session,
detail="Only the creator can view snapshots", user,
) thread.search_space_id,
Permission.PUBLIC_SHARING_VIEW.value,
"You don't have permission to view public share links",
)
# Get snapshots
result = await session.execute( result = await session.execute(
select(PublicChatSnapshot) select(PublicChatSnapshot)
.filter(PublicChatSnapshot.thread_id == thread_id) .filter(PublicChatSnapshot.thread_id == thread_id)
@ -374,11 +382,13 @@ async def list_snapshots_for_thread(
) )
snapshots = result.scalars().all() snapshots = result.scalars().all()
frontend_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")
return [ return [
{ {
"id": s.id, "id": s.id,
"share_token": s.share_token, "share_token": s.share_token,
"public_url": f"{base_url}/public/{s.share_token}", "public_url": f"{frontend_url}/public/{s.share_token}",
"created_at": s.created_at.isoformat() if s.created_at else None, "created_at": s.created_at.isoformat() if s.created_at else None,
"message_count": len(s.message_ids) if s.message_ids else 0, "message_count": len(s.message_ids) if s.message_ids else 0,
} }
@ -386,6 +396,54 @@ async def list_snapshots_for_thread(
] ]
async def list_snapshots_for_search_space(
session: AsyncSession,
search_space_id: int,
user: User,
) -> list[dict]:
"""List all public snapshots for a search space."""
from app.config import config
await check_permission(
session,
user,
search_space_id,
Permission.PUBLIC_SHARING_VIEW.value,
"You don't have permission to view public share links",
)
result = await session.execute(
select(PublicChatSnapshot)
.join(NewChatThread, PublicChatSnapshot.thread_id == NewChatThread.id)
.filter(NewChatThread.search_space_id == search_space_id)
.order_by(PublicChatSnapshot.created_at.desc())
)
snapshots = result.scalars().all()
snapshot_thread_ids = [s.thread_id for s in snapshots]
thread_result = await session.execute(
select(NewChatThread.id, NewChatThread.title).filter(
NewChatThread.id.in_(snapshot_thread_ids)
)
)
thread_titles = {row[0]: row[1] for row in thread_result.fetchall()}
frontend_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")
return [
{
"id": s.id,
"share_token": s.share_token,
"public_url": f"{frontend_url}/public/{s.share_token}",
"created_at": s.created_at.isoformat() if s.created_at else None,
"message_count": len(s.message_ids) if s.message_ids else 0,
"thread_id": s.thread_id,
"thread_title": thread_titles.get(s.thread_id, "Untitled"),
}
for s in snapshots
]
# ============================================================================= # =============================================================================
# Snapshot Deletion # Snapshot Deletion
# ============================================================================= # =============================================================================
@ -412,11 +470,13 @@ async def delete_snapshot(
if not snapshot: if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found") raise HTTPException(status_code=404, detail="Snapshot not found")
if snapshot.thread.created_by_id != user.id: await check_permission(
raise HTTPException( session,
status_code=403, user,
detail="Only the creator can delete snapshots", snapshot.thread.search_space_id,
) Permission.PUBLIC_SHARING_DELETE.value,
"You don't have permission to delete public share links",
)
await session.delete(snapshot) await session.delete(snapshot)
await session.commit() await session.commit()

View file

@ -0,0 +1,269 @@
"""Celery task for background connector deletion.
This task handles the deletion of all documents associated with a connector
in the background, then deletes the connector itself. User is notified via
the notification system when complete (no polling required).
Features:
- Batch deletion to handle large document counts
- Automatic retry on failure
- Progress tracking via notifications
- Handles both success and failure notifications
"""
import asyncio
import logging
from uuid import UUID
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
from app.db import Document, Notification, SearchSourceConnector
logger = logging.getLogger(__name__)
# Batch size for document deletion
DELETION_BATCH_SIZE = 500
def _get_celery_session_maker():
"""Create async session maker for Celery tasks."""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool,
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False), engine
@celery_app.task(
bind=True,
name="delete_connector_with_documents",
max_retries=3,
default_retry_delay=60,
autoretry_for=(Exception,),
retry_backoff=True,
)
def delete_connector_with_documents_task(
self,
connector_id: int,
user_id: str,
search_space_id: int,
connector_name: str,
connector_type: str,
):
"""
Background task to delete a connector and all its associated documents.
Creates a notification when complete (success or failure).
No polling required - user sees notification in UI.
Args:
connector_id: ID of the connector to delete
user_id: ID of the user who initiated the deletion
search_space_id: ID of the search space
connector_name: Name of the connector (for notification message)
connector_type: Type of the connector (for logging)
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
_delete_connector_async(
connector_id=connector_id,
user_id=user_id,
search_space_id=search_space_id,
connector_name=connector_name,
connector_type=connector_type,
)
)
finally:
loop.close()
async def _delete_connector_async(
connector_id: int,
user_id: str,
search_space_id: int,
connector_name: str,
connector_type: str,
) -> dict:
"""
Async implementation of connector deletion.
Steps:
1. Count total documents to delete
2. Delete documents in batches (chunks cascade automatically)
3. Delete the connector record
4. Create success notification
On failure, creates failure notification and re-raises exception.
"""
session_maker, engine = _get_celery_session_maker()
total_deleted = 0
try:
async with session_maker() as session:
# Step 1: Count total documents for this connector
count_result = await session.execute(
select(func.count(Document.id)).where(
Document.connector_id == connector_id
)
)
total_docs = count_result.scalar() or 0
logger.info(
f"Starting deletion of connector {connector_id} ({connector_name}). "
f"Documents to delete: {total_docs}"
)
# Step 2: Delete documents in batches
while True:
# Get batch of document IDs
result = await session.execute(
select(Document.id)
.where(Document.connector_id == connector_id)
.limit(DELETION_BATCH_SIZE)
)
doc_ids = [row[0] for row in result.fetchall()]
if not doc_ids:
break
# Delete this batch (chunks are deleted via CASCADE)
await session.execute(delete(Document).where(Document.id.in_(doc_ids)))
await session.commit()
total_deleted += len(doc_ids)
logger.info(
f"Deleted batch of {len(doc_ids)} documents. "
f"Progress: {total_deleted}/{total_docs}"
)
# Step 3: Delete the connector record
result = await session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == connector_id
)
)
connector = result.scalar_one_or_none()
if connector:
await session.delete(connector)
logger.info(f"Deleted connector record: {connector_id}")
else:
logger.warning(
f"Connector {connector_id} not found - may have been already deleted"
)
# Step 4: Create success notification
doc_text = "document" if total_deleted == 1 else "documents"
notification = Notification(
user_id=UUID(user_id),
search_space_id=search_space_id,
type="connector_deletion",
title=f"{connector_name} removed",
message=f"Cleanup complete. {total_deleted} {doc_text} removed.",
notification_metadata={
"connector_id": connector_id,
"connector_name": connector_name,
"connector_type": connector_type,
"documents_deleted": total_deleted,
"status": "completed",
},
)
session.add(notification)
await session.commit()
logger.info(
f"Connector {connector_id} ({connector_name}) deleted successfully. "
f"Total documents deleted: {total_deleted}"
)
return {
"status": "success",
"connector_id": connector_id,
"connector_name": connector_name,
"documents_deleted": total_deleted,
}
except Exception as e:
logger.error(
f"Failed to delete connector {connector_id} ({connector_name}): {e!s}",
exc_info=True,
)
# Create failure notification
try:
async with session_maker() as session:
notification = Notification(
user_id=UUID(user_id),
search_space_id=search_space_id,
type="connector_deletion",
title=f"Failed to Remove {connector_name}",
message="Something went wrong while removing this connector. Please try again.",
notification_metadata={
"connector_id": connector_id,
"connector_name": connector_name,
"connector_type": connector_type,
"documents_deleted": total_deleted,
"status": "failed",
"error": str(e),
},
)
session.add(notification)
await session.commit()
except Exception as notify_error:
logger.error(
f"Failed to create failure notification: {notify_error!s}",
exc_info=True,
)
# Re-raise to trigger Celery retry
raise
finally:
await engine.dispose()
async def delete_documents_by_connector_id(
session,
connector_id: int,
batch_size: int = DELETION_BATCH_SIZE,
) -> int:
"""
Delete all documents associated with a connector in batches.
This is a utility function that can be used independently of the Celery task
for synchronous deletion scenarios (e.g., small document counts).
Args:
session: AsyncSession instance
connector_id: ID of the connector
batch_size: Number of documents to delete per batch
Returns:
Total number of documents deleted
"""
total_deleted = 0
while True:
result = await session.execute(
select(Document.id)
.where(Document.connector_id == connector_id)
.limit(batch_size)
)
doc_ids = [row[0] for row in result.fetchall()]
if not doc_ids:
break
await session.execute(delete(Document).where(Document.id.in_(doc_ids)))
await session.commit()
total_deleted += len(doc_ids)
return total_deleted

View file

@ -323,6 +323,28 @@ def process_file_upload_task(
user_id: ID of the user user_id: ID of the user
""" """
import asyncio import asyncio
import os
import traceback
logger.info(
f"[process_file_upload] Task started - file: {filename}, "
f"search_space_id: {search_space_id}, user_id: {user_id}"
)
logger.info(f"[process_file_upload] File path: {file_path}")
# Check if file exists and is accessible
if not os.path.exists(file_path):
logger.error(
f"[process_file_upload] File does not exist: {file_path}. "
"The temp file may have been cleaned up before the task ran."
)
return
try:
file_size = os.path.getsize(file_path)
logger.info(f"[process_file_upload] File size: {file_size} bytes")
except Exception as e:
logger.warning(f"[process_file_upload] Could not get file size: {e}")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
@ -331,6 +353,15 @@ def process_file_upload_task(
loop.run_until_complete( loop.run_until_complete(
_process_file_upload(file_path, filename, search_space_id, user_id) _process_file_upload(file_path, filename, search_space_id, user_id)
) )
logger.info(
f"[process_file_upload] Task completed successfully for: {filename}"
)
except Exception as e:
logger.error(
f"[process_file_upload] Task failed for {filename}: {e}\n"
f"Traceback:\n{traceback.format_exc()}"
)
raise
finally: finally:
loop.close() loop.close()
@ -343,16 +374,22 @@ async def _process_file_upload(
from app.tasks.document_processors.file_processors import process_file_in_background from app.tasks.document_processors.file_processors import process_file_in_background
logger.info(f"[_process_file_upload] Starting async processing for: {filename}")
async with get_celery_session_maker()() as session: async with get_celery_session_maker()() as session:
logger.info(f"[_process_file_upload] Database session created for: {filename}")
task_logger = TaskLoggingService(session, search_space_id) task_logger = TaskLoggingService(session, search_space_id)
# Get file size for notification metadata # Get file size for notification metadata
try: try:
file_size = os.path.getsize(file_path) file_size = os.path.getsize(file_path)
except Exception: logger.info(f"[_process_file_upload] File size: {file_size} bytes")
except Exception as e:
logger.warning(f"[_process_file_upload] Could not get file size: {e}")
file_size = None file_size = None
# Create notification for document processing # Create notification for document processing
logger.info(f"[_process_file_upload] Creating notification for: {filename}")
notification = ( notification = (
await NotificationService.document_processing.notify_processing_started( await NotificationService.document_processing.notify_processing_started(
session=session, session=session,
@ -363,6 +400,9 @@ async def _process_file_upload(
file_size=file_size, file_size=file_size,
) )
) )
logger.info(
f"[_process_file_upload] Notification created with ID: {notification.id if notification else 'None'}"
)
log_entry = await task_logger.log_task_start( log_entry = await task_logger.log_task_start(
task_name="process_file_upload", task_name="process_file_upload",
@ -505,6 +545,7 @@ def process_circleback_meeting_task(
markdown_content: str, markdown_content: str,
metadata: dict, metadata: dict,
search_space_id: int, search_space_id: int,
connector_id: int | None = None,
): ):
""" """
Celery task to process Circleback meeting webhook data. Celery task to process Circleback meeting webhook data.
@ -515,6 +556,7 @@ def process_circleback_meeting_task(
markdown_content: Meeting content formatted as markdown markdown_content: Meeting content formatted as markdown
metadata: Meeting metadata dictionary metadata: Meeting metadata dictionary
search_space_id: ID of the search space search_space_id: ID of the search space
connector_id: ID of the Circleback connector (for deletion support)
""" """
import asyncio import asyncio
@ -529,6 +571,7 @@ def process_circleback_meeting_task(
markdown_content, markdown_content,
metadata, metadata,
search_space_id, search_space_id,
connector_id,
) )
) )
finally: finally:
@ -541,6 +584,7 @@ async def _process_circleback_meeting(
markdown_content: str, markdown_content: str,
metadata: dict, metadata: dict,
search_space_id: int, search_space_id: int,
connector_id: int | None = None,
): ):
"""Process Circleback meeting with new session.""" """Process Circleback meeting with new session."""
from app.tasks.document_processors.circleback_processor import ( from app.tasks.document_processors.circleback_processor import (
@ -597,6 +641,7 @@ async def _process_circleback_meeting(
markdown_content=markdown_content, markdown_content=markdown_content,
metadata=metadata, metadata=metadata,
search_space_id=search_space_id, search_space_id=search_space_id,
connector_id=connector_id,
) )
if result: if result:

View file

@ -51,7 +51,10 @@ def _clear_generating_podcast(search_space_id: int) -> None:
import redis import redis
try: try:
redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") redis_url = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
)
client = redis.from_url(redis_url, decode_responses=True) client = redis.from_url(redis_url, decode_responses=True)
key = f"podcast:generating:{search_space_id}" key = f"podcast:generating:{search_space_id}"
client.delete(key) client.delete(key)

View file

@ -36,7 +36,10 @@ def get_redis_client() -> redis.Redis:
"""Get or create Redis client for heartbeat checking.""" """Get or create Redis client for heartbeat checking."""
global _redis_client global _redis_client
if _redis_client is None: if _redis_client is None:
redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") redis_url = os.getenv(
"REDIS_APP_URL",
os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0"),
)
_redis_client = redis.from_url(redis_url, decode_responses=True) _redis_client = redis.from_url(redis_url, decode_responses=True)
return _redis_client return _redis_client

View file

@ -32,6 +32,7 @@ from app.services.chat_session_state_service import (
clear_ai_responding, clear_ai_responding,
set_ai_responding, set_ai_responding,
) )
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
from app.services.connector_service import ConnectorService from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService from app.services.new_streaming_service import VercelStreamingService
from app.utils.content_utils import bootstrap_history_from_db from app.utils.content_utils import bootstrap_history_from_db
@ -1208,6 +1209,59 @@ async def stream_new_chat(
if completion_event: if completion_event:
yield completion_event yield completion_event
# Generate LLM title for new chats after first response
# Check if this is the first assistant response by counting existing assistant messages
from app.db import NewChatMessage, NewChatThread
from sqlalchemy import func
assistant_count_result = await session.execute(
select(func.count(NewChatMessage.id)).filter(
NewChatMessage.thread_id == chat_id,
NewChatMessage.role == "assistant",
)
)
assistant_message_count = assistant_count_result.scalar() or 0
# Only generate title on the first response (no prior assistant messages)
if assistant_message_count == 0:
generated_title = None
try:
# Generate title using the same LLM
title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
# Truncate inputs to avoid context length issues
truncated_query = user_query[:500]
truncated_response = accumulated_text[:1000]
title_result = await title_chain.ainvoke({
"user_query": truncated_query,
"assistant_response": truncated_response,
})
# Extract and clean the title
if title_result and hasattr(title_result, "content"):
raw_title = title_result.content.strip()
# Validate the title (reasonable length)
if raw_title and len(raw_title) <= 100:
# Remove any quotes or extra formatting
generated_title = raw_title.strip('"\'')
except Exception:
generated_title = None
# Only update if LLM succeeded (keep truncated prompt title as fallback)
if generated_title:
# Fetch thread and update title
thread_result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
)
thread = thread_result.scalars().first()
if thread:
thread.title = generated_title
await session.commit()
# Notify frontend of the title update
yield streaming_service.format_thread_title_update(
chat_id, generated_title
)
# Finish the step and message # Finish the step and message
yield streaming_service.format_finish_step() yield streaming_service.format_finish_step()
yield streaming_service.format_finish() yield streaming_service.format_finish()

View file

@ -417,6 +417,8 @@ async def index_airtable_records(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -396,6 +396,8 @@ async def index_bookstack_pages(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -395,6 +395,8 @@ async def index_clickup_tasks(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -402,6 +402,8 @@ async def index_confluence_pages(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -527,6 +527,8 @@ async def index_discord_messages(
content_hash=content_hash, content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash, unique_identifier_hash=unique_identifier_hash,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -292,6 +292,8 @@ async def index_elasticsearch_documents(
document_metadata=metadata, document_metadata=metadata,
search_space_id=search_space_id, search_space_id=search_space_id,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
# Create chunks and attach to document (persist via relationship) # Create chunks and attach to document (persist via relationship)

View file

@ -220,6 +220,7 @@ async def index_github_repos(
user_id=user_id, user_id=user_id,
task_logger=task_logger, task_logger=task_logger,
log_entry=log_entry, log_entry=log_entry,
connector_id=connector_id,
) )
documents_processed += docs_created documents_processed += docs_created
@ -292,6 +293,7 @@ async def _process_repository_digest(
user_id: str, user_id: str,
task_logger: TaskLoggingService, task_logger: TaskLoggingService,
log_entry, log_entry,
connector_id: int,
) -> int: ) -> int:
""" """
Process a repository digest and create documents. Process a repository digest and create documents.
@ -426,6 +428,8 @@ async def _process_repository_digest(
search_space_id=search_space_id, search_space_id=search_space_id,
chunks=chunks_data, chunks=chunks_data,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -520,6 +520,8 @@ async def index_google_calendar_events(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -767,6 +767,7 @@ async def _process_single_file(
session=session, session=session,
task_logger=task_logger, task_logger=task_logger,
log_entry=log_entry, log_entry=log_entry,
connector_id=connector_id,
) )
if error: if error:

View file

@ -413,7 +413,6 @@ async def index_google_gmail_messages(
"subject": subject, "subject": subject,
"sender": sender, "sender": sender,
"date": date_str, "date": date_str,
"connector_id": connector_id,
}, },
content=summary_content, content=summary_content,
content_hash=content_hash, content_hash=content_hash,
@ -421,6 +420,8 @@ async def index_google_gmail_messages(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)
documents_indexed += 1 documents_indexed += 1

View file

@ -380,6 +380,8 @@ async def index_jira_issues(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -413,6 +413,8 @@ async def index_linear_issues(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -476,6 +476,8 @@ async def index_luma_events(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -429,6 +429,7 @@ async def index_notion_pages(
} }
existing_document.chunks = chunks existing_document.chunks = chunks
existing_document.updated_at = get_current_timestamp() existing_document.updated_at = get_current_timestamp()
existing_document.connector_id = connector_id
documents_indexed += 1 documents_indexed += 1
logger.info(f"Successfully updated Notion page: {page_title}") logger.info(f"Successfully updated Notion page: {page_title}")
@ -501,6 +502,8 @@ async def index_notion_pages(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -500,6 +500,8 @@ async def index_obsidian_vault(
embedding=embedding, embedding=embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(new_document) session.add(new_document)

View file

@ -389,6 +389,8 @@ async def index_slack_messages(
content_hash=content_hash, content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash, unique_identifier_hash=unique_identifier_hash,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -430,6 +430,8 @@ async def index_teams_messages(
content_hash=content_hash, content_hash=content_hash,
unique_identifier_hash=unique_identifier_hash, unique_identifier_hash=unique_identifier_hash,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -383,6 +383,8 @@ async def index_crawled_urls(
embedding=summary_embedding, embedding=summary_embedding,
chunks=chunks, chunks=chunks,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -8,10 +8,17 @@ and stores it as searchable documents in the database.
import logging import logging
from typing import Any from typing import Any
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Document, DocumentType from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
SearchSpace,
)
from app.services.llm_service import get_document_summary_llm from app.services.llm_service import get_document_summary_llm
from app.utils.document_converters import ( from app.utils.document_converters import (
create_document_chunks, create_document_chunks,
@ -35,6 +42,7 @@ async def add_circleback_meeting_document(
markdown_content: str, markdown_content: str,
metadata: dict[str, Any], metadata: dict[str, Any],
search_space_id: int, search_space_id: int,
connector_id: int | None = None,
) -> Document | None: ) -> Document | None:
""" """
Process and store a Circleback meeting document. Process and store a Circleback meeting document.
@ -46,6 +54,7 @@ async def add_circleback_meeting_document(
markdown_content: Meeting content formatted as markdown markdown_content: Meeting content formatted as markdown
metadata: Meeting metadata dictionary metadata: Meeting metadata dictionary
search_space_id: ID of the search space search_space_id: ID of the search space
connector_id: ID of the Circleback connector (for deletion support)
Returns: Returns:
Document object if successful, None if failed or duplicate Document object if successful, None if failed or duplicate
@ -125,6 +134,30 @@ async def add_circleback_meeting_document(
**metadata, **metadata,
} }
# Fetch the user who set up the Circleback connector (preferred)
# or fall back to search space owner if no connector found
created_by_user_id = None
# Try to find the Circleback connector for this search space
connector_result = await session.execute(
select(SearchSourceConnector.user_id).where(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.CIRCLEBACK_CONNECTOR,
)
)
connector_user = connector_result.scalar_one_or_none()
if connector_user:
# Use the user who set up the Circleback connector
created_by_user_id = connector_user
else:
# Fallback: use search space owner if no connector found
search_space_result = await session.execute(
select(SearchSpace.user_id).where(SearchSpace.id == search_space_id)
)
created_by_user_id = search_space_result.scalar_one_or_none()
# Update or create document # Update or create document
if existing_document: if existing_document:
# Update existing document # Update existing document
@ -138,6 +171,9 @@ async def add_circleback_meeting_document(
existing_document.blocknote_document = blocknote_json existing_document.blocknote_document = blocknote_json
existing_document.content_needs_reindexing = False existing_document.content_needs_reindexing = False
existing_document.updated_at = get_current_timestamp() existing_document.updated_at = get_current_timestamp()
# Ensure connector_id is set (backfill for documents created before this field)
if connector_id is not None:
existing_document.connector_id = connector_id
await session.commit() await session.commit()
await session.refresh(existing_document) await session.refresh(existing_document)
@ -160,6 +196,8 @@ async def add_circleback_meeting_document(
blocknote_document=blocknote_json, blocknote_document=blocknote_json,
content_needs_reindexing=False, content_needs_reindexing=False,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=created_by_user_id,
connector_id=connector_id,
) )
session.add(document) session.add(document)

View file

@ -185,6 +185,7 @@ async def add_extension_received_document(
unique_identifier_hash=unique_identifier_hash, unique_identifier_hash=unique_identifier_hash,
blocknote_document=blocknote_json, blocknote_document=blocknote_json,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
) )
session.add(document) session.add(document)

View file

@ -526,6 +526,8 @@ async def add_received_file_document_using_unstructured(
blocknote_document=blocknote_json, blocknote_document=blocknote_json,
content_needs_reindexing=False, content_needs_reindexing=False,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector.get("connector_id") if connector else None,
) )
session.add(document) session.add(document)
@ -665,6 +667,8 @@ async def add_received_file_document_using_llamacloud(
blocknote_document=blocknote_json, blocknote_document=blocknote_json,
content_needs_reindexing=False, content_needs_reindexing=False,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector.get("connector_id") if connector else None,
) )
session.add(document) session.add(document)
@ -829,6 +833,8 @@ async def add_received_file_document_using_docling(
blocknote_document=blocknote_json, blocknote_document=blocknote_json,
content_needs_reindexing=False, content_needs_reindexing=False,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector.get("connector_id") if connector else None,
) )
session.add(document) session.add(document)
@ -849,7 +855,7 @@ async def add_received_file_document_using_docling(
async def _update_document_from_connector( async def _update_document_from_connector(
document: Document | None, connector: dict | None, session: AsyncSession document: Document | None, connector: dict | None, session: AsyncSession
) -> None: ) -> None:
"""Helper to update document type and metadata from connector info.""" """Helper to update document type, metadata, and connector_id from connector info."""
if document and connector: if document and connector:
if "type" in connector: if "type" in connector:
document.document_type = connector["type"] document.document_type = connector["type"]
@ -861,6 +867,9 @@ async def _update_document_from_connector(
# Expand existing metadata with connector metadata # Expand existing metadata with connector metadata
merged = {**document.document_metadata, **connector["metadata"]} merged = {**document.document_metadata, **connector["metadata"]}
document.document_metadata = merged document.document_metadata = merged
# Set connector_id if provided for de-indexing support
if "connector_id" in connector:
document.connector_id = connector["connector_id"]
await session.commit() await session.commit()

View file

@ -295,6 +295,8 @@ async def add_received_markdown_file_document(
unique_identifier_hash=primary_hash, unique_identifier_hash=primary_hash,
blocknote_document=blocknote_json, blocknote_document=blocknote_json,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
connector_id=connector.get("connector_id") if connector else None,
) )
session.add(document) session.add(document)

View file

@ -357,6 +357,7 @@ async def add_youtube_video_document(
unique_identifier_hash=unique_identifier_hash, unique_identifier_hash=unique_identifier_hash,
blocknote_document=blocknote_json, blocknote_document=blocknote_json,
updated_at=get_current_timestamp(), updated_at=get_current_timestamp(),
created_by_id=user_id,
) )
session.add(document) session.add(document)

View file

@ -23,17 +23,20 @@ from app.db import (
get_default_roles_config, get_default_roles_config,
get_user_db, get_user_db,
) )
from app.utils.refresh_tokens import create_refresh_token
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BearerResponse(BaseModel): class BearerResponse(BaseModel):
access_token: str access_token: str
refresh_token: str
token_type: str token_type: str
SECRET = config.SECRET_KEY SECRET = config.SECRET_KEY
if config.AUTH_TYPE == "GOOGLE": if config.AUTH_TYPE == "GOOGLE":
from httpx_oauth.clients.google import GoogleOAuth2 from httpx_oauth.clients.google import GoogleOAuth2
@ -183,7 +186,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]: def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24) return JWTStrategy(
secret=SECRET,
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
)
# # COOKIE AUTH | Uncomment if you want to use cookie auth. # # COOKIE AUTH | Uncomment if you want to use cookie auth.
@ -209,9 +215,30 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
# BEARER AUTH CODE. # BEARER AUTH CODE.
class CustomBearerTransport(BearerTransport): class CustomBearerTransport(BearerTransport):
async def get_login_response(self, token: str) -> Response: async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer") import jwt
redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}"
# Decode JWT to get user_id for refresh token creation
try:
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
user_id = uuid.UUID(payload.get("sub"))
refresh_token = await create_refresh_token(user_id)
except Exception as e:
logger.error(f"Failed to create refresh token: {e}")
# Fall back to response without refresh token
refresh_token = ""
bearer_response = BearerResponse(
access_token=token,
refresh_token=refresh_token,
token_type="bearer",
)
if config.AUTH_TYPE == "GOOGLE": if config.AUTH_TYPE == "GOOGLE":
redirect_url = (
f"{config.NEXT_FRONTEND_URL}/auth/callback"
f"?token={bearer_response.access_token}"
f"&refresh_token={bearer_response.refresh_token}"
)
return RedirectResponse(redirect_url, status_code=302) return RedirectResponse(redirect_url, status_code=302)
else: else:
return JSONResponse(bearer_response.model_dump()) return JSONResponse(bearer_response.model_dump())

View file

@ -0,0 +1,153 @@
"""Utilities for managing refresh tokens."""
import hashlib
import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
from sqlalchemy import select, update
from app.config import config
from app.db import RefreshToken, async_session_maker
logger = logging.getLogger(__name__)
def generate_refresh_token() -> str:
"""Generate a cryptographically secure refresh token."""
return secrets.token_urlsafe(32)
def hash_token(token: str) -> str:
"""Hash a token for secure storage."""
return hashlib.sha256(token.encode()).hexdigest()
async def create_refresh_token(
user_id: uuid.UUID,
family_id: uuid.UUID | None = None,
) -> str:
"""
Create and store a new refresh token for a user.
Args:
user_id: The user's ID
family_id: Optional family ID for token rotation
Returns:
The plaintext refresh token
"""
token = generate_refresh_token()
token_hash = hash_token(token)
expires_at = datetime.now(UTC) + timedelta(
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
)
if family_id is None:
family_id = uuid.uuid4()
async with async_session_maker() as session:
refresh_token = RefreshToken(
user_id=user_id,
token_hash=token_hash,
expires_at=expires_at,
family_id=family_id,
)
session.add(refresh_token)
await session.commit()
return token
async def validate_refresh_token(token: str) -> RefreshToken | None:
"""
Validate a refresh token. Handles reuse detection.
Args:
token: The plaintext refresh token
Returns:
RefreshToken if valid, None otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
refresh_token = result.scalars().first()
if not refresh_token:
return None
# Reuse detection: revoked token used while family has active tokens
if refresh_token.is_revoked:
active = await session.execute(
select(RefreshToken).where(
RefreshToken.family_id == refresh_token.family_id,
RefreshToken.is_revoked == False, # noqa: E712
RefreshToken.expires_at > datetime.now(UTC),
)
)
if active.scalars().first():
# Revoke entire family
await session.execute(
update(RefreshToken)
.where(RefreshToken.family_id == refresh_token.family_id)
.values(is_revoked=True)
)
await session.commit()
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
return None
if refresh_token.is_expired:
return None
return refresh_token
async def rotate_refresh_token(old_token: RefreshToken) -> str:
"""Revoke old token and create new one in same family."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.id == old_token.id)
.values(is_revoked=True)
)
await session.commit()
return await create_refresh_token(old_token.user_id, old_token.family_id)
async def revoke_refresh_token(token: str) -> bool:
"""
Revoke a single refresh token by its plaintext value.
Args:
token: The plaintext refresh token
Returns:
True if token was found and revoked, False otherwise
"""
token_hash = hash_token(token)
async with async_session_maker() as session:
result = await session.execute(
update(RefreshToken)
.where(RefreshToken.token_hash == token_hash)
.values(is_revoked=True)
)
await session.commit()
return result.rowcount > 0
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
"""Revoke all refresh tokens for a user (logout all devices)."""
async with async_session_maker() as session:
await session.execute(
update(RefreshToken)
.where(RefreshToken.user_id == user_id)
.values(is_revoked=True)
)
await session.commit()

View file

@ -32,8 +32,6 @@ dependencies = [
"en-core-web-sm@https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl", "en-core-web-sm@https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
"static-ffmpeg>=2.13", "static-ffmpeg>=2.13",
"tavily-python>=0.3.2", "tavily-python>=0.3.2",
"unstructured-client>=0.30.0",
"unstructured[all-docs]>=0.16.25",
"uvicorn[standard]>=0.34.0", "uvicorn[standard]>=0.34.0",
"validators>=0.34.0", "validators>=0.34.0",
"youtube-transcript-api>=1.0.3", "youtube-transcript-api>=1.0.3",
@ -45,7 +43,6 @@ dependencies = [
"firecrawl-py>=4.9.0", "firecrawl-py>=4.9.0",
"boto3>=1.35.0", "boto3>=1.35.0",
"langchain-community>=0.3.31", "langchain-community>=0.3.31",
"langchain-unstructured>=1.0.0",
"litellm>=1.80.10", "litellm>=1.80.10",
"langchain-litellm>=0.3.5", "langchain-litellm>=0.3.5",
"fake-useragent>=2.2.0", "fake-useragent>=2.2.0",
@ -62,6 +59,9 @@ dependencies = [
"deepagents>=0.3.8", "deepagents>=0.3.8",
"langchain>=1.2.6", "langchain>=1.2.6",
"langgraph>=1.0.5", "langgraph>=1.0.5",
"unstructured[all-docs]>=0.18.31",
"unstructured-client>=0.42.3",
"langchain-unstructured>=1.0.1",
] ]
[dependency-groups] [dependency-groups]

View file

@ -39,7 +39,7 @@ backend_pid=$!
sleep 5 sleep 5
echo "Starting Celery Worker..." echo "Starting Celery Worker..."
celery -A app.celery_app worker --loglevel=info & celery -A app.celery_app worker --loglevel=info --autoscale=128,4 &
celery_worker_pid=$! celery_worker_pid=$!
# Wait a bit for worker to initialize # Wait a bit for worker to initialize

File diff suppressed because it is too large Load diff

View file

@ -9,7 +9,6 @@ import {
import { useQueryClient } from "@tanstack/react-query"; import { useQueryClient } from "@tanstack/react-query";
import { useAtomValue, useSetAtom } from "jotai"; import { useAtomValue, useSetAtom } from "jotai";
import { useParams, useSearchParams } from "next/navigation"; import { useParams, useSearchParams } from "next/navigation";
import { useTranslations } from "next-intl";
import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner"; import { toast } from "sonner";
import { z } from "zod"; import { z } from "zod";
@ -39,9 +38,10 @@ import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview"; import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage"; import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
import { RecallMemoryToolUI, SaveMemoryToolUI } from "@/components/tool-ui/user-memory"; import { RecallMemoryToolUI, SaveMemoryToolUI } from "@/components/tool-ui/user-memory";
import { Spinner } from "@/components/ui/spinner"; import { Skeleton } from "@/components/ui/skeleton";
import { useChatSessionStateSync } from "@/hooks/use-chat-session-state"; import { useChatSessionStateSync } from "@/hooks/use-chat-session-state";
import { useMessagesElectric } from "@/hooks/use-messages-electric"; import { useMessagesElectric } from "@/hooks/use-messages-electric";
import { documentsApiService } from "@/lib/apis/documents-api.service";
// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos"; // import { WriteTodosToolUI } from "@/components/tool-ui/write-todos";
import { getBearerToken } from "@/lib/auth-utils"; import { getBearerToken } from "@/lib/auth-utils";
import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter"; import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
@ -53,12 +53,10 @@ import {
} from "@/lib/chat/podcast-state"; } from "@/lib/chat/podcast-state";
import { import {
appendMessage, appendMessage,
type ChatVisibility,
createThread, createThread,
getRegenerateUrl, getRegenerateUrl,
getThreadFull, getThreadFull,
getThreadMessages, getThreadMessages,
type MessageRecord,
type ThreadRecord, type ThreadRecord,
} from "@/lib/chat/thread-persistence"; } from "@/lib/chat/thread-persistence";
import { import {
@ -137,7 +135,6 @@ interface ThinkingStepData {
} }
export default function NewChatPage() { export default function NewChatPage() {
const t = useTranslations("dashboard");
const params = useParams(); const params = useParams();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const [isInitializing, setIsInitializing] = useState(true); const [isInitializing, setIsInitializing] = useState(true);
@ -329,6 +326,33 @@ export default function NewChatPage() {
initializeThread(); initializeThread();
}, [initializeThread]); }, [initializeThread]);
// Prefetch document titles for @ mention picker
// Runs when user lands on page so data is ready when they type @
useEffect(() => {
if (!searchSpaceId) return;
const prefetchParams = {
search_space_id: searchSpaceId,
page: 0,
page_size: 20,
};
queryClient.prefetchQuery({
queryKey: ["document-titles", prefetchParams],
queryFn: () => documentsApiService.searchDocumentTitles({ queryParams: prefetchParams }),
staleTime: 60 * 1000,
});
queryClient.prefetchQuery({
queryKey: ["surfsense-docs-mention", "", false],
queryFn: () =>
documentsApiService.getSurfsenseDocs({
queryParams: { page: 0, page_size: 20 },
}),
staleTime: 3 * 60 * 1000,
});
}, [searchSpaceId, queryClient]);
// Handle scroll to comment from URL query params (e.g., from inbox item click) // Handle scroll to comment from URL query params (e.g., from inbox item click)
const searchParams = useSearchParams(); const searchParams = useSearchParams();
const targetCommentIdParam = searchParams.get("commentId"); const targetCommentIdParam = searchParams.get("commentId");
@ -367,19 +391,6 @@ export default function NewChatPage() {
setIsRunning(false); setIsRunning(false);
}, []); }, []);
// Handle visibility change from ChatShareButton
const handleVisibilityChange = useCallback(
(newVisibility: ChatVisibility) => {
setCurrentThread((prev) => (prev ? { ...prev, visibility: newVisibility } : null));
// Refetch all thread queries so sidebar reflects the change immediately
// Use predicate to match any query that starts with "threads"
queryClient.refetchQueries({
predicate: (query) => Array.isArray(query.queryKey) && query.queryKey[0] === "threads",
});
},
[queryClient]
);
// Handle new message from user // Handle new message from user
const onNew = useCallback( const onNew = useCallback(
async (message: AppendMessage) => { async (message: AppendMessage) => {
@ -426,7 +437,10 @@ export default function NewChatPage() {
let isNewThread = false; let isNewThread = false;
if (!currentThreadId) { if (!currentThreadId) {
try { try {
const newThread = await createThread(searchSpaceId, "New Chat"); // Create thread with truncated prompt as initial title
const initialTitle =
userQuery.trim().slice(0, 100) + (userQuery.trim().length > 100 ? "..." : "");
const newThread = await createThread(searchSpaceId, initialTitle);
currentThreadId = newThread.id; currentThreadId = newThread.id;
setThreadId(currentThreadId); setThreadId(currentThreadId);
// Set currentThread so ChatHeader can show share button immediately // Set currentThread so ChatHeader can show share button immediately
@ -816,6 +830,26 @@ export default function NewChatPage() {
break; break;
} }
case "data-thread-title-update": {
// Handle thread title update from LLM-generated title
const titleData = parsed.data as { threadId: number; title: string };
if (titleData?.title && titleData?.threadId === currentThreadId) {
// Update current thread state with new title
setCurrentThread((prev) =>
prev ? { ...prev, title: titleData.title } : prev
);
// Invalidate thread list to refresh sidebar
queryClient.invalidateQueries({
queryKey: ["threads", String(searchSpaceId)],
});
// Invalidate thread detail for breadcrumb update
queryClient.invalidateQueries({
queryKey: ["threads", String(searchSpaceId), "detail", String(titleData.threadId)],
});
}
break;
}
case "error": case "error":
throw new Error(parsed.errorText || "Server error"); throw new Error(parsed.errorText || "Server error");
} }
@ -1346,14 +1380,11 @@ export default function NewChatPage() {
); );
// Handle reloading/refreshing the last AI response // Handle reloading/refreshing the last AI response
const onReload = useCallback( const onReload = useCallback(async () => {
async (parentId: string | null) => { // parentId is the ID of the message to reload from (the user message)
// parentId is the ID of the message to reload from (the user message) // We call regenerate without a query to use the same query
// We call regenerate without a query to use the same query await handleRegenerate(null);
await handleRegenerate(null); }, [handleRegenerate]);
},
[handleRegenerate]
);
// Create external store runtime with attachment support // Create external store runtime with attachment support
const runtime = useExternalStoreRuntime({ const runtime = useExternalStoreRuntime({
@ -1372,9 +1403,39 @@ export default function NewChatPage() {
// Show loading state only when loading an existing thread // Show loading state only when loading an existing thread
if (isInitializing) { if (isInitializing) {
return ( return (
<div className="flex h-[calc(100vh-64px)] flex-col items-center justify-center gap-4"> <div className="flex h-[calc(100vh-64px)] flex-col bg-background px-4">
<Spinner size="lg" /> <div className="mx-auto w-full max-w-[44rem] flex flex-1 flex-col gap-6 py-8">
<div className="text-sm text-muted-foreground">{t("loading_chat")}</div> {/* User message */}
<div className="flex justify-end">
<Skeleton className="h-12 w-56 rounded-2xl" />
</div>
{/* Assistant message */}
<div className="flex flex-col gap-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-[85%]" />
<Skeleton className="h-4 w-[70%]" />
</div>
{/* User message */}
<div className="flex justify-end">
<Skeleton className="h-12 w-40 rounded-2xl" />
</div>
{/* Assistant message */}
<div className="flex flex-col gap-2">
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-[90%]" />
<Skeleton className="h-4 w-[60%]" />
</div>
</div>
{/* Input bar */}
<div className="sticky bottom-0 pb-6 bg-background">
<div className="mx-auto w-full max-w-[44rem]">
<Skeleton className="h-24 w-full rounded-2xl" />
</div>
</div>
</div> </div>
); );
} }

View file

@ -6,6 +6,7 @@ import {
Brain, Brain,
ChevronRight, ChevronRight,
FileText, FileText,
Globe,
type LucideIcon, type LucideIcon,
Menu, Menu,
MessageSquare, MessageSquare,
@ -16,6 +17,7 @@ import { AnimatePresence, motion } from "motion/react";
import { useParams, useRouter } from "next/navigation"; import { useParams, useRouter } from "next/navigation";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { useCallback, useEffect, useState } from "react"; import { useCallback, useEffect, useState } from "react";
import { PublicChatSnapshotsManager } from "@/components/public-chat-snapshots/public-chat-snapshots-manager";
import { GeneralSettingsManager } from "@/components/settings/general-settings-manager"; import { GeneralSettingsManager } from "@/components/settings/general-settings-manager";
import { LLMRoleManager } from "@/components/settings/llm-role-manager"; import { LLMRoleManager } from "@/components/settings/llm-role-manager";
import { ModelConfigManager } from "@/components/settings/model-config-manager"; import { ModelConfigManager } from "@/components/settings/model-config-manager";
@ -56,6 +58,12 @@ const settingsNavItems: SettingsNavItem[] = [
descriptionKey: "nav_system_instructions_desc", descriptionKey: "nav_system_instructions_desc",
icon: MessageSquare, icon: MessageSquare,
}, },
{
id: "public-links",
labelKey: "nav_public_links",
descriptionKey: "nav_public_links_desc",
icon: Globe,
},
]; ];
function SettingsSidebar({ function SettingsSidebar({
@ -276,6 +284,9 @@ function SettingsContent({
{activeSection === "models" && <ModelConfigManager searchSpaceId={searchSpaceId} />} {activeSection === "models" && <ModelConfigManager searchSpaceId={searchSpaceId} />}
{activeSection === "roles" && <LLMRoleManager searchSpaceId={searchSpaceId} />} {activeSection === "roles" && <LLMRoleManager searchSpaceId={searchSpaceId} />}
{activeSection === "prompts" && <PromptConfigManager searchSpaceId={searchSpaceId} />} {activeSection === "prompts" && <PromptConfigManager searchSpaceId={searchSpaceId} />}
{activeSection === "public-links" && (
<PublicChatSnapshotsManager searchSpaceId={searchSpaceId} />
)}
</motion.div> </motion.div>
</AnimatePresence> </AnimatePresence>
</div> </div>

View file

@ -11,6 +11,7 @@ import {
Crown, Crown,
Edit2, Edit2,
FileText, FileText,
Globe,
Hash, Hash,
Link2, Link2,
LinkIcon, LinkIcon,
@ -206,7 +207,15 @@ export default function TeamManagementPage() {
); );
const handleUpdateRole = useCallback( const handleUpdateRole = useCallback(
async (roleId: number, data: { permissions?: string[] }): Promise<Role> => { async (
roleId: number,
data: {
name?: string;
description?: string | null;
permissions?: string[];
is_default?: boolean;
}
): Promise<Role> => {
const request: UpdateRoleRequest = { const request: UpdateRoleRequest = {
search_space_id: searchSpaceId, search_space_id: searchSpaceId,
role_id: roleId, role_id: roleId,
@ -827,6 +836,12 @@ const CATEGORY_CONFIG: Record<
description: "Manage search space settings", description: "Manage search space settings",
order: 10, order: 10,
}, },
public_sharing: {
label: "Public Chat Sharing",
icon: Globe,
description: "Share chats publicly via links",
order: 11,
},
}; };
const ACTION_LABELS: Record<string, string> = { const ACTION_LABELS: Record<string, string> = {
@ -944,7 +959,7 @@ function RolesTab({
roles, roles,
groupedPermissions, groupedPermissions,
loading, loading,
onUpdateRole: _onUpdateRole, onUpdateRole,
onDeleteRole, onDeleteRole,
onCreateRole, onCreateRole,
canUpdate, canUpdate,
@ -954,7 +969,15 @@ function RolesTab({
roles: Role[]; roles: Role[];
groupedPermissions: Record<string, PermissionWithDescription[]>; groupedPermissions: Record<string, PermissionWithDescription[]>;
loading: boolean; loading: boolean;
onUpdateRole: (roleId: number, data: { permissions?: string[] }) => Promise<Role>; onUpdateRole: (
roleId: number,
data: {
name?: string;
description?: string | null;
permissions?: string[];
is_default?: boolean;
}
) => Promise<Role>;
onDeleteRole: (roleId: number) => Promise<boolean>; onDeleteRole: (roleId: number) => Promise<boolean>;
onCreateRole: (data: CreateRoleRequest["data"]) => Promise<Role>; onCreateRole: (data: CreateRoleRequest["data"]) => Promise<Role>;
canUpdate: boolean; canUpdate: boolean;
@ -962,6 +985,7 @@ function RolesTab({
canCreate: boolean; canCreate: boolean;
}) { }) {
const [showCreateRole, setShowCreateRole] = useState(false); const [showCreateRole, setShowCreateRole] = useState(false);
const [editingRoleId, setEditingRoleId] = useState<number | null>(null);
if (loading) { if (loading) {
return ( return (
@ -997,6 +1021,21 @@ function RolesTab({
/> />
)} )}
{/* Edit Role Form */}
{editingRoleId !== null &&
(() => {
const roleToEdit = roles.find((r) => r.id === editingRoleId);
if (!roleToEdit) return null;
return (
<EditRoleSection
role={roleToEdit}
groupedPermissions={groupedPermissions}
onUpdateRole={onUpdateRole}
onCancel={() => setEditingRoleId(null)}
/>
);
})()}
{/* Roles Grid */} {/* Roles Grid */}
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4"> <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{roles.map((role, index) => ( {roles.map((role, index) => (
@ -1055,13 +1094,9 @@ function RolesTab({
<MoreHorizontal className="h-4 w-4" /> <MoreHorizontal className="h-4 w-4" />
</Button> </Button>
</DropdownMenuTrigger> </DropdownMenuTrigger>
<DropdownMenuContent align="end"> <DropdownMenuContent align="end" onCloseAutoFocus={(e) => e.preventDefault()}>
{canUpdate && ( {canUpdate && (
<DropdownMenuItem <DropdownMenuItem onClick={() => setEditingRoleId(role.id)}>
onClick={() => {
// TODO: Implement edit role dialog/modal
}}
>
<Edit2 className="h-4 w-4 mr-2" /> <Edit2 className="h-4 w-4 mr-2" />
Edit Role Edit Role
</DropdownMenuItem> </DropdownMenuItem>
@ -2026,3 +2061,371 @@ function CreateRoleSection({
</motion.div> </motion.div>
); );
} }
function EditRoleSection({
role,
groupedPermissions,
onUpdateRole,
onCancel,
}: {
role: Role;
groupedPermissions: Record<string, PermissionWithDescription[]>;
onUpdateRole: (
roleId: number,
data: {
name?: string;
description?: string | null;
permissions?: string[];
is_default?: boolean;
}
) => Promise<Role>;
onCancel: () => void;
}) {
const [saving, setSaving] = useState(false);
const [name, setName] = useState(role.name);
const [description, setDescription] = useState(role.description || "");
const [selectedPermissions, setSelectedPermissions] = useState<string[]>(role.permissions);
const [isDefault, setIsDefault] = useState(role.is_default);
const [expandedCategories, setExpandedCategories] = useState<string[]>([]);
// Sort categories by order
const sortedCategories = useMemo(() => {
return Object.keys(groupedPermissions).sort((a, b) => {
const orderA = CATEGORY_CONFIG[a]?.order ?? 99;
const orderB = CATEGORY_CONFIG[b]?.order ?? 99;
return orderA - orderB;
});
}, [groupedPermissions]);
const handleSave = async () => {
if (!name.trim()) {
toast.error("Please enter a role name");
return;
}
setSaving(true);
try {
await onUpdateRole(role.id, {
name: name.trim(),
description: description.trim() || null,
permissions: selectedPermissions,
is_default: isDefault,
});
toast.success("Role updated successfully");
onCancel();
} catch (error) {
console.error("Failed to update role:", error);
toast.error("Failed to update role");
} finally {
setSaving(false);
}
};
const togglePermission = useCallback((perm: string) => {
setSelectedPermissions((prev) =>
prev.includes(perm) ? prev.filter((p) => p !== perm) : [...prev, perm]
);
}, []);
const toggleCategory = useCallback(
(category: string) => {
const categoryPerms = groupedPermissions[category]?.map((p) => p.value) || [];
const allSelected = categoryPerms.every((p) => selectedPermissions.includes(p));
if (allSelected) {
setSelectedPermissions((prev) => prev.filter((p) => !categoryPerms.includes(p)));
} else {
setSelectedPermissions((prev) => [...new Set([...prev, ...categoryPerms])]);
}
},
[groupedPermissions, selectedPermissions]
);
const toggleCategoryExpanded = useCallback((category: string) => {
setExpandedCategories((prev) =>
prev.includes(category) ? prev.filter((c) => c !== category) : [...prev, category]
);
}, []);
const getCategoryStats = useCallback(
(category: string) => {
const perms = groupedPermissions[category] || [];
const selected = perms.filter((p) => selectedPermissions.includes(p.value)).length;
return { selected, total: perms.length, allSelected: selected === perms.length };
},
[groupedPermissions, selectedPermissions]
);
return (
<motion.div
initial={{ opacity: 0, y: -10 }}
animate={{ opacity: 1, y: 0 }}
exit={{ opacity: 0, y: -10 }}
className="mb-6"
>
<Card className="border-primary/20 bg-gradient-to-br from-primary/5 via-background to-background">
<CardHeader className="pb-4">
<div className="flex items-center justify-between">
<div className="flex items-center gap-3">
<div className="h-10 w-10 rounded-xl bg-primary/10 flex items-center justify-center">
<Edit2 className="h-5 w-5 text-primary" />
</div>
<div>
<CardTitle className="text-lg">Edit Role</CardTitle>
<CardDescription className="text-sm">
Modify permissions for "{role.name}"
</CardDescription>
</div>
</div>
<Button variant="ghost" size="icon" onClick={onCancel}>
<Trash2 className="h-4 w-4" />
</Button>
</div>
</CardHeader>
<CardContent className="space-y-6">
{/* Role Details */}
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="edit-role-name">Role Name *</Label>
<Input
id="edit-role-name"
placeholder="e.g., Content Manager"
value={name}
onChange={(e) => setName(e.target.value)}
/>
</div>
<div className="space-y-2">
<Label htmlFor="edit-role-description">Description</Label>
<Input
id="edit-role-description"
placeholder="Brief description of this role"
value={description}
onChange={(e) => setDescription(e.target.value)}
/>
</div>
</div>
{/* Default Role Checkbox */}
<div className="flex items-center gap-3 p-3 rounded-lg bg-muted/50">
<Checkbox
id="edit-is-default"
checked={isDefault}
onCheckedChange={(checked) => setIsDefault(checked === true)}
/>
<div className="flex-1">
<Label htmlFor="edit-is-default" className="cursor-pointer font-medium">
Set as default role
</Label>
<p className="text-xs text-muted-foreground">
New members without a specific role will be assigned this role
</p>
</div>
</div>
{/* Permissions Section */}
<div className="space-y-3">
<div className="flex items-center justify-between">
<Label className="text-sm font-medium">
Permissions ({selectedPermissions.length} selected)
</Label>
<Button
type="button"
variant="ghost"
size="sm"
className="text-xs h-7"
onClick={() =>
setExpandedCategories(
expandedCategories.length === sortedCategories.length ? [] : sortedCategories
)
}
>
{expandedCategories.length === sortedCategories.length
? "Collapse All"
: "Expand All"}
</Button>
</div>
<div className="space-y-2">
{sortedCategories.map((category) => {
const config = CATEGORY_CONFIG[category] || {
label: category,
icon: FileText,
description: "",
order: 99,
};
const IconComponent = config.icon;
const stats = getCategoryStats(category);
const isExpanded = expandedCategories.includes(category);
const perms = groupedPermissions[category] || [];
return (
<div key={category} className="rounded-lg border bg-card overflow-hidden">
{/* Category Header */}
<div
className={cn(
"flex items-center justify-between p-3 cursor-pointer hover:bg-muted/50 transition-colors",
stats.allSelected && "bg-primary/5"
)}
onClick={() => toggleCategoryExpanded(category)}
onKeyDown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
toggleCategoryExpanded(category);
}
}}
tabIndex={0}
role="button"
>
<div className="flex items-center gap-3">
<div
className={cn(
"h-8 w-8 rounded-lg flex items-center justify-center",
stats.selected > 0 ? "bg-primary/10" : "bg-muted"
)}
>
<IconComponent
className={cn(
"h-4 w-4",
stats.selected > 0 ? "text-primary" : "text-muted-foreground"
)}
/>
</div>
<div>
<div className="flex items-center gap-2">
<span className="font-medium text-sm">{config.label}</span>
<Badge
variant={stats.selected > 0 ? "default" : "secondary"}
className="text-xs h-5"
>
{stats.selected}/{stats.total}
</Badge>
</div>
<p className="text-xs text-muted-foreground hidden md:block">
{config.description}
</p>
</div>
</div>
<div className="flex items-center gap-2">
<Checkbox
checked={stats.allSelected}
onCheckedChange={() => toggleCategory(category)}
onClick={(e) => e.stopPropagation()}
aria-label={`Select all ${config.label} permissions`}
/>
<motion.div
animate={{ rotate: isExpanded ? 180 : 0 }}
transition={{ duration: 0.2 }}
>
<svg
className="h-4 w-4 text-muted-foreground"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M19 9l-7 7-7-7"
/>
</svg>
</motion.div>
</div>
</div>
{/* Permissions List */}
{isExpanded && (
<motion.div
initial={{ height: 0, opacity: 0 }}
animate={{ height: "auto", opacity: 1 }}
exit={{ height: 0, opacity: 0 }}
transition={{ duration: 0.2 }}
className="border-t"
>
<div className="p-3 space-y-1">
{perms.map((perm) => {
const action = perm.value.split(":")[1];
const actionConfig = ACTION_DISPLAY[action] || {
label: action,
color: "text-gray-600 bg-gray-500/10",
};
const isSelected = selectedPermissions.includes(perm.value);
return (
<div
key={perm.value}
className={cn(
"flex items-center justify-between p-2 rounded-md cursor-pointer transition-colors",
isSelected
? "bg-primary/10 hover:bg-primary/15"
: "hover:bg-muted/50"
)}
onClick={() => togglePermission(perm.value)}
onKeyDown={(e) => {
if (e.key === "Enter" || e.key === " ") {
e.preventDefault();
togglePermission(perm.value);
}
}}
tabIndex={0}
role="checkbox"
aria-checked={isSelected}
>
<div className="flex items-center gap-3 flex-1 min-w-0">
<Checkbox
checked={isSelected}
onCheckedChange={() => togglePermission(perm.value)}
onClick={(e) => e.stopPropagation()}
/>
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2">
<span
className={cn(
"text-xs font-medium px-2 py-0.5 rounded",
actionConfig.color
)}
>
{actionConfig.label}
</span>
</div>
<p className="text-xs text-muted-foreground mt-0.5 truncate">
{perm.description}
</p>
</div>
</div>
</div>
);
})}
</div>
</motion.div>
)}
</div>
);
})}
</div>
</div>
{/* Actions */}
<div className="flex items-center justify-end gap-3 pt-4 border-t">
<Button variant="outline" onClick={onCancel}>
Cancel
</Button>
<Button onClick={handleSave} disabled={saving || !name.trim()}>
{saving ? (
<>
<Spinner size="sm" className="mr-2" />
Saving...
</>
) : (
<>
<Check className="h-4 w-4 mr-2" />
Save Changes
</>
)}
</Button>
</div>
</CardContent>
</Card>
</motion.div>
);
}

View file

@ -5,6 +5,7 @@ import { ArrowLeft, ChevronRight, X } from "lucide-react";
import { AnimatePresence, motion } from "motion/react"; import { AnimatePresence, motion } from "motion/react";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { APP_VERSION } from "@/lib/env-config";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
export interface SettingsNavItem { export interface SettingsNavItem {
@ -148,6 +149,11 @@ export function UserSettingsSidebar({
); );
})} })}
</nav> </nav>
{/* Version display */}
<div className="mt-auto border-t px-6 py-3">
<p className="text-xs text-muted-foreground/50">v{APP_VERSION}</p>
</div>
</aside> </aside>
</> </>
); );

View file

@ -101,7 +101,7 @@ export default function RootLayout({
attribute="class" attribute="class"
enableSystem enableSystem
disableTransitionOnChange disableTransitionOnChange
defaultTheme="light" defaultTheme="system"
> >
<RootProvider> <RootProvider>
<ReactQueryClientProvider> <ReactQueryClientProvider>

View file

@ -1,31 +0,0 @@
import { atomWithMutation } from "jotai-tanstack-query";
import { toast } from "sonner";
import type {
CreateSnapshotRequest,
CreateSnapshotResponse,
} from "@/contracts/types/chat-threads.types";
import { chatThreadsApiService } from "@/lib/apis/chat-threads-api.service";
export const createSnapshotMutationAtom = atomWithMutation(() => ({
mutationFn: async (request: CreateSnapshotRequest) => {
return chatThreadsApiService.createSnapshot(request);
},
onSuccess: (response: CreateSnapshotResponse) => {
// Construct URL using frontend origin (backend returns its own URL which differs)
const publicUrl = `${window.location.origin}/public/${response.share_token}`;
navigator.clipboard.writeText(publicUrl);
if (response.is_new) {
toast.success("Public link created and copied to clipboard", {
description: "Anyone with this link can view a snapshot of this chat",
});
} else {
toast.success("Public link copied to clipboard", {
description: "This snapshot already exists",
});
}
},
onError: (error: Error) => {
console.error("Failed to create snapshot:", error);
toast.error("Failed to create public link");
},
}));

View file

@ -0,0 +1,53 @@
import { atomWithMutation } from "jotai-tanstack-query";
import { toast } from "sonner";
import type {
PublicChatSnapshotCreateRequest,
PublicChatSnapshotCreateResponse,
PublicChatSnapshotDeleteRequest,
} from "@/contracts/types/chat-threads.types";
import { chatThreadsApiService } from "@/lib/apis/chat-threads-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { queryClient } from "@/lib/query-client/client";
export const createPublicChatSnapshotMutationAtom = atomWithMutation(() => ({
mutationFn: async (request: PublicChatSnapshotCreateRequest) => {
return chatThreadsApiService.createPublicChatSnapshot(request);
},
onSuccess: (response: PublicChatSnapshotCreateResponse) => {
queryClient.invalidateQueries({
queryKey: cacheKeys.publicChatSnapshots.all,
});
const publicUrl = `${window.location.origin}/public/${response.share_token}`;
navigator.clipboard.writeText(publicUrl);
if (response.is_new) {
toast.success("Public link created and copied to clipboard", {
description: "Anyone with this link can view a snapshot of this chat",
});
} else {
toast.success("Public link copied to clipboard", {
description: "This snapshot already exists",
});
}
},
onError: (error: Error) => {
console.error("Failed to create snapshot:", error);
toast.error("Failed to create public link");
},
}));
export const deletePublicChatSnapshotMutationAtom = atomWithMutation(() => ({
mutationFn: async (request: PublicChatSnapshotDeleteRequest) => {
return chatThreadsApiService.deletePublicChatSnapshot(request);
},
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: cacheKeys.publicChatSnapshots.all,
});
toast.success("Public link deleted");
},
onError: (error: Error) => {
console.error("Failed to delete public chat link:", error);
toast.error("Failed to delete public link");
},
}));

View file

@ -0,0 +1,22 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { chatThreadsApiService } from "@/lib/apis/chat-threads-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
export const publicChatSnapshotsAtom = atomWithQuery((get) => {
const searchSpaceId = get(activeSearchSpaceIdAtom);
return {
queryKey: cacheKeys.publicChatSnapshots.bySearchSpace(Number(searchSpaceId) || 0),
enabled: !!searchSpaceId,
staleTime: 5 * 60 * 1000,
queryFn: async () => {
if (!searchSpaceId) {
return { snapshots: [] };
}
return chatThreadsApiService.listPublicChatSnapshotsForSearchSpace({
search_space_id: Number(searchSpaceId),
});
},
};
});

View file

@ -3,7 +3,7 @@
import { useSearchParams } from "next/navigation"; import { useSearchParams } from "next/navigation";
import { useEffect } from "react"; import { useEffect } from "react";
import { useGlobalLoadingEffect } from "@/hooks/use-global-loading"; import { useGlobalLoadingEffect } from "@/hooks/use-global-loading";
import { getAndClearRedirectPath, setBearerToken } from "@/lib/auth-utils"; import { getAndClearRedirectPath, setBearerToken, setRefreshToken } from "@/lib/auth-utils";
import { trackLoginSuccess } from "@/lib/posthog/events"; import { trackLoginSuccess } from "@/lib/posthog/events";
interface TokenHandlerProps { interface TokenHandlerProps {
@ -35,8 +35,9 @@ const TokenHandler = ({
// Only run on client-side // Only run on client-side
if (typeof window === "undefined") return; if (typeof window === "undefined") return;
// Get token from URL parameters // Get tokens from URL parameters
const token = searchParams.get(tokenParamName); const token = searchParams.get(tokenParamName);
const refreshToken = searchParams.get("refresh_token");
if (token) { if (token) {
try { try {
@ -50,10 +51,15 @@ const TokenHandler = ({
// Clear the flag for future logins // Clear the flag for future logins
sessionStorage.removeItem("login_success_tracked"); sessionStorage.removeItem("login_success_tracked");
// Store token in localStorage using both methods for compatibility // Store access token in localStorage using both methods for compatibility
localStorage.setItem(storageKey, token); localStorage.setItem(storageKey, token);
setBearerToken(token); setBearerToken(token);
// Store refresh token if provided
if (refreshToken) {
setRefreshToken(refreshToken);
}
// Check if there's a saved redirect path from before the auth flow // Check if there's a saved redirect path from before the auth flow
const savedRedirectPath = getAndClearRedirectPath(); const savedRedirectPath = getAndClearRedirectPath();

View file

@ -1,7 +1,8 @@
"use client"; "use client";
import { BadgeCheck, LogOut } from "lucide-react"; import { BadgeCheck, Loader2, LogOut } from "lucide-react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useState } from "react";
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar"; import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { import {
@ -13,6 +14,7 @@ import {
DropdownMenuSeparator, DropdownMenuSeparator,
DropdownMenuTrigger, DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu"; } from "@/components/ui/dropdown-menu";
import { logout } from "@/lib/auth-utils";
import { cleanupElectric } from "@/lib/electric/client"; import { cleanupElectric } from "@/lib/electric/client";
import { resetUser, trackLogout } from "@/lib/posthog/events"; import { resetUser, trackLogout } from "@/lib/posthog/events";
@ -26,8 +28,11 @@ export function UserDropdown({
}; };
}) { }) {
const router = useRouter(); const router = useRouter();
const [isLoggingOut, setIsLoggingOut] = useState(false);
const handleLogout = async () => { const handleLogout = async () => {
if (isLoggingOut) return;
setIsLoggingOut(true);
try { try {
// Track logout event and reset PostHog identity // Track logout event and reset PostHog identity
trackLogout(); trackLogout();
@ -41,15 +46,17 @@ export function UserDropdown({
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err); console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
} }
// Revoke refresh token on server and clear all tokens from localStorage
await logout();
if (typeof window !== "undefined") { if (typeof window !== "undefined") {
localStorage.removeItem("surfsense_bearer_token");
window.location.href = "/"; window.location.href = "/";
} }
} catch (error) { } catch (error) {
console.error("Error during logout:", error); console.error("Error during logout:", error);
// Optionally, provide user feedback // Even if there's an error, try to clear tokens and redirect
await logout();
if (typeof window !== "undefined") { if (typeof window !== "undefined") {
localStorage.removeItem("surfsense_bearer_token");
window.location.href = "/"; window.location.href = "/";
} }
} }
@ -85,9 +92,17 @@ export function UserDropdown({
</DropdownMenuItem> </DropdownMenuItem>
</DropdownMenuGroup> </DropdownMenuGroup>
<DropdownMenuSeparator /> <DropdownMenuSeparator />
<DropdownMenuItem onClick={handleLogout} className="text-xs md:text-sm"> <DropdownMenuItem
<LogOut className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4" /> onClick={handleLogout}
Log out className="text-xs md:text-sm"
disabled={isLoggingOut}
>
{isLoggingOut ? (
<Loader2 className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4 animate-spin" />
) : (
<LogOut className="mr-2 h-3.5 w-3.5 md:h-4 md:w-4" />
)}
{isLoggingOut ? "Logging out..." : "Log out"}
</DropdownMenuItem> </DropdownMenuItem>
</DropdownMenuContent> </DropdownMenuContent>
</DropdownMenu> </DropdownMenu>

View file

@ -4,20 +4,19 @@ import {
ErrorPrimitive, ErrorPrimitive,
MessagePrimitive, MessagePrimitive,
useAssistantState, useAssistantState,
useMessage,
} from "@assistant-ui/react"; } from "@assistant-ui/react";
import { useAtom, useAtomValue, useSetAtom } from "jotai"; import { useAtom, useAtomValue } from "jotai";
import { CheckIcon, CopyIcon, DownloadIcon, MessageSquare, RefreshCwIcon } from "lucide-react"; import { CheckIcon, CopyIcon, DownloadIcon, MessageSquare, RefreshCwIcon } from "lucide-react";
import type { FC } from "react"; import type { FC } from "react";
import { useContext, useEffect, useMemo, useRef, useState } from "react"; import { useContext, useEffect, useMemo, useRef, useState } from "react";
import { import {
addingCommentToMessageIdAtom, addingCommentToMessageIdAtom,
clearTargetCommentIdAtom,
commentsCollapsedAtom, commentsCollapsedAtom,
commentsEnabledAtom, commentsEnabledAtom,
targetCommentIdAtom, targetCommentIdAtom,
} from "@/atoms/chat/current-thread.atom"; } from "@/atoms/chat/current-thread.atom";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms"; import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { BranchPicker } from "@/components/assistant-ui/branch-picker";
import { MarkdownText } from "@/components/assistant-ui/markdown-text"; import { MarkdownText } from "@/components/assistant-ui/markdown-text";
import { import {
ThinkingStepsContext, ThinkingStepsContext,
@ -84,7 +83,6 @@ const AssistantMessageInner: FC = () => {
</div> </div>
<div className="aui-assistant-message-footer mt-1 mb-5 ml-2 flex"> <div className="aui-assistant-message-footer mt-1 mb-5 ml-2 flex">
<BranchPicker />
<AssistantActionBar /> <AssistantActionBar />
</div> </div>
</> </>
@ -126,7 +124,6 @@ export const AssistantMessage: FC = () => {
// Target comment navigation - read target from global atom // Target comment navigation - read target from global atom
const targetCommentId = useAtomValue(targetCommentIdAtom); const targetCommentId = useAtomValue(targetCommentIdAtom);
const clearTargetCommentId = useSetAtom(clearTargetCommentIdAtom);
// Check if target comment belongs to this message (including replies) // Check if target comment belongs to this message (including replies)
const hasTargetComment = useMemo(() => { const hasTargetComment = useMemo(() => {
@ -263,6 +260,8 @@ export const AssistantMessage: FC = () => {
}; };
const AssistantActionBar: FC = () => { const AssistantActionBar: FC = () => {
const { isLast } = useMessage();
return ( return (
<ActionBarPrimitive.Root <ActionBarPrimitive.Root
hideWhenRunning hideWhenRunning
@ -285,11 +284,14 @@ const AssistantActionBar: FC = () => {
<DownloadIcon /> <DownloadIcon />
</TooltipIconButton> </TooltipIconButton>
</ActionBarPrimitive.ExportMarkdown> </ActionBarPrimitive.ExportMarkdown>
<ActionBarPrimitive.Reload asChild> {/* Only allow regenerating the last assistant message */}
<TooltipIconButton tooltip="Refresh"> {isLast && (
<RefreshCwIcon /> <ActionBarPrimitive.Reload asChild>
</TooltipIconButton> <TooltipIconButton tooltip="Refresh">
</ActionBarPrimitive.Reload> <RefreshCwIcon />
</TooltipIconButton>
</ActionBarPrimitive.Reload>
)}
</ActionBarPrimitive.Root> </ActionBarPrimitive.Root>
); );
}; };

View file

@ -1,32 +0,0 @@
import { BranchPickerPrimitive } from "@assistant-ui/react";
import { ChevronLeftIcon, ChevronRightIcon } from "lucide-react";
import type { FC } from "react";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { cn } from "@/lib/utils";
export const BranchPicker: FC<BranchPickerPrimitive.Root.Props> = ({ className, ...rest }) => {
return (
<BranchPickerPrimitive.Root
hideWhenSingleBranch
className={cn(
"aui-branch-picker-root -ml-2 mr-2 inline-flex items-center text-muted-foreground text-xs",
className
)}
{...rest}
>
<BranchPickerPrimitive.Previous asChild>
<TooltipIconButton tooltip="Previous">
<ChevronLeftIcon />
</TooltipIconButton>
</BranchPickerPrimitive.Previous>
<span className="aui-branch-picker-state font-medium">
<BranchPickerPrimitive.Number /> / <BranchPickerPrimitive.Count />
</span>
<BranchPickerPrimitive.Next asChild>
<TooltipIconButton tooltip="Next">
<ChevronRightIcon />
</TooltipIconButton>
</BranchPickerPrimitive.Next>
</BranchPickerPrimitive.Root>
);
};

View file

@ -536,10 +536,11 @@ export const InlineMentionEditor = forwardRef<InlineMentionEditorRef, InlineMent
role="textbox" role="textbox"
aria-multiline="true" aria-multiline="true"
/> />
{/* Placeholder */} {/* Placeholder with fade animation on change */}
{isEmpty && ( {isEmpty && (
<div <div
className="absolute top-0 left-0 pointer-events-none text-muted-foreground text-sm" key={placeholder}
className="absolute top-0 left-0 pointer-events-none text-muted-foreground text-sm animate-in fade-in duration-1000"
aria-hidden="true" aria-hidden="true"
> >
{placeholder} {placeholder}

View file

@ -65,6 +65,16 @@ import type { Document } from "@/contracts/types/document.types";
import { useCommentsElectric } from "@/hooks/use-comments-electric"; import { useCommentsElectric } from "@/hooks/use-comments-electric";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
/** Placeholder texts that cycle in new chats when input is empty */
const CYCLING_PLACEHOLDERS = [
"Ask SurfSense anything or @mention docs.",
"Generate a podcast from marketing tips in the company handbook.",
"Sum up our vacation policy from Drive.",
"Give me a brief overview of the most urgent tickets in Jira and Linear.",
"Create a concise table of today's top ten emails and calendar events.",
"Check if this week's Slack messages reference any GitHub issues.",
];
interface ThreadProps { interface ThreadProps {
messageThinkingSteps?: Map<string, ThinkingStep[]>; messageThinkingSteps?: Map<string, ThinkingStep[]>;
header?: React.ReactNode; header?: React.ReactNode;
@ -228,6 +238,30 @@ const Composer: FC = () => {
const isThreadEmpty = useAssistantState(({ thread }) => thread.isEmpty); const isThreadEmpty = useAssistantState(({ thread }) => thread.isEmpty);
const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning); const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning);
// Cycling placeholder state - only cycles in new chats
const [placeholderIndex, setPlaceholderIndex] = useState(0);
// Cycle through placeholders every 4 seconds when thread is empty (new chat)
useEffect(() => {
// Only cycle when thread is empty (new chat)
if (!isThreadEmpty) {
// Reset to first placeholder when chat becomes active
setPlaceholderIndex(0);
return;
}
const intervalId = setInterval(() => {
setPlaceholderIndex((prev) => (prev + 1) % CYCLING_PLACEHOLDERS.length);
}, 6000);
return () => clearInterval(intervalId);
}, [isThreadEmpty]);
// Compute current placeholder - only cycle in new chats
const currentPlaceholder = isThreadEmpty
? CYCLING_PLACEHOLDERS[placeholderIndex]
: CYCLING_PLACEHOLDERS[0];
// Live collaboration state // Live collaboration state
const { data: currentUser } = useAtomValue(currentUserAtom); const { data: currentUser } = useAtomValue(currentUserAtom);
const { data: members } = useAtomValue(membersAtom); const { data: members } = useAtomValue(membersAtom);
@ -410,7 +444,7 @@ const Composer: FC = () => {
<div ref={editorContainerRef} className="aui-composer-input-wrapper px-3 pt-3 pb-6"> <div ref={editorContainerRef} className="aui-composer-input-wrapper px-3 pt-3 pb-6">
<InlineMentionEditor <InlineMentionEditor
ref={editorRef} ref={editorRef}
placeholder="Ask SurfSense or @mention docs" placeholder={currentPlaceholder}
onMentionTrigger={handleMentionTrigger} onMentionTrigger={handleMentionTrigger}
onMentionClose={handleMentionClose} onMentionClose={handleMentionClose}
onChange={handleEditorChange} onChange={handleEditorChange}

View file

@ -4,7 +4,6 @@ import { FileText, PencilIcon } from "lucide-react";
import { type FC, useState } from "react"; import { type FC, useState } from "react";
import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom"; import { messageDocumentsMapAtom } from "@/atoms/chat/mentioned-documents.atom";
import { UserMessageAttachments } from "@/components/assistant-ui/attachment"; import { UserMessageAttachments } from "@/components/assistant-ui/attachment";
import { BranchPicker } from "@/components/assistant-ui/branch-picker";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button"; import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
interface AuthorMetadata { interface AuthorMetadata {
@ -95,24 +94,47 @@ export const UserMessage: FC = () => {
</div> </div>
)} )}
</div> </div>
<BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" />
</MessagePrimitive.Root> </MessagePrimitive.Root>
); );
}; };
const UserActionBar: FC = () => { const UserActionBar: FC = () => {
const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning);
// Get current message ID
const currentMessageId = useAssistantState(({ message }) => message?.id);
// Find the last user message ID in the thread (computed once, memoized by selector)
const lastUserMessageId = useAssistantState(({ thread }) => {
const messages = thread.messages;
for (let i = messages.length - 1; i >= 0; i--) {
if (messages[i].role === "user") {
return messages[i].id;
}
}
return null;
});
// Simple comparison - no iteration needed per message
const isLastUserMessage = currentMessageId === lastUserMessageId;
// Show edit button only on the last user message and when thread is not running
const canEdit = isLastUserMessage && !isThreadRunning;
return ( return (
<ActionBarPrimitive.Root <ActionBarPrimitive.Root
hideWhenRunning hideWhenRunning
autohide="not-last" autohide="not-last"
className="aui-user-action-bar-root flex flex-col items-end" className="aui-user-action-bar-root flex flex-col items-end"
> >
<ActionBarPrimitive.Edit asChild> {/* Only allow editing the last user message */}
<TooltipIconButton tooltip="Edit" className="aui-user-action-edit p-4"> {canEdit && (
<PencilIcon /> <ActionBarPrimitive.Edit asChild>
</TooltipIconButton> <TooltipIconButton tooltip="Edit" className="aui-user-action-edit p-4">
</ActionBarPrimitive.Edit> <PencilIcon />
</TooltipIconButton>
</ActionBarPrimitive.Edit>
)}
</ActionBarPrimitive.Root> </ActionBarPrimitive.Root>
); );
}; };

View file

@ -20,7 +20,10 @@ export function ContactFormGridWithDetails() {
Contact Contact
</h2> </h2>
<p className="mt-8 max-w-lg text-center text-base text-neutral-600 dark:text-neutral-400"> <p className="mt-8 max-w-lg text-center text-base text-neutral-600 dark:text-neutral-400">
We'd love to hear from you. Schedule a meeting or send us an email. We'd love to hear from you!
</p>
<p className="mt-4 max-w-lg text-center text-base text-neutral-600 dark:text-neutral-400">
Schedule a meeting with our Head of Product, Eric Lammertsma, or send us an email.
</p> </p>
<div className="mt-10 flex flex-col items-center gap-6"> <div className="mt-10 flex flex-col items-center gap-6">

View file

@ -14,6 +14,7 @@ import {
} from "@/components/ui/breadcrumb"; } from "@/components/ui/breadcrumb";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { authenticatedFetch, getBearerToken } from "@/lib/auth-utils"; import { authenticatedFetch, getBearerToken } from "@/lib/auth-utils";
import { getThreadFull } from "@/lib/chat/thread-persistence";
import { cacheKeys } from "@/lib/query-client/cache-keys"; import { cacheKeys } from "@/lib/query-client/cache-keys";
interface BreadcrumbItemInterface { interface BreadcrumbItemInterface {
@ -34,6 +35,16 @@ export function DashboardBreadcrumb() {
enabled: !!searchSpaceId, enabled: !!searchSpaceId,
}); });
// Extract chat thread ID from pathname for chat pages
const chatThreadId = segments[2] === "new-chat" && segments[3] ? segments[3] : null;
// Fetch thread details when on a chat page with a thread ID
const { data: threadData } = useQuery({
queryKey: ["threads", searchSpaceId, "detail", chatThreadId],
queryFn: () => getThreadFull(Number(chatThreadId)),
enabled: !!chatThreadId && !!searchSpaceId,
});
// State to store document title for editor breadcrumb // State to store document title for editor breadcrumb
const [documentTitle, setDocumentTitle] = useState<string | null>(null); const [documentTitle, setDocumentTitle] = useState<string | null>(null);
@ -144,10 +155,11 @@ export function DashboardBreadcrumb() {
} }
// Handle new-chat sub-sections (thread IDs) // Handle new-chat sub-sections (thread IDs)
// Don't show thread ID in breadcrumb - users identify chats by content, not by ID // Show the chat title if available, otherwise fall back to "Chat"
if (section === "new-chat") { if (section === "new-chat") {
const chatLabel = threadData?.title || t("chat") || "Chat";
breadcrumbs.push({ breadcrumbs.push({
label: t("chat") || "Chat", label: chatLabel,
}); });
return breadcrumbs; return breadcrumbs;
} }

View file

@ -61,7 +61,8 @@ export function FeaturesCards() {
<CardContent> <CardContent>
<p className="text-sm"> <p className="text-sm">
Choose from 100+ leading LLMs and seamlessly call any model on demand. Choose from 100+ leading LLMs, seamlessly calling any model on demand. Even run
on-prem local LLM inference via vLLM, Ollama, llama.cpp, LM Studio, and more.
</p> </p>
</CardContent> </CardContent>
</Card> </Card>
@ -74,9 +75,9 @@ export function FeaturesCards() {
const CardDecorator = ({ children }: { children: ReactNode }) => ( const CardDecorator = ({ children }: { children: ReactNode }) => (
<div <div
aria-hidden aria-hidden
className="relative mx-auto size-36 [mask-image:radial-gradient(ellipse_50%_50%_at_50%_50%,#000_70%,transparent_100%)]" className="relative mx-auto size-36 mask-[radial-gradient(ellipse_50%_50%_at_50%_50%,#000_70%,transparent_100%)]"
> >
<div className="absolute inset-0 [--border:black] dark:[--border:white] bg-[linear-gradient(to_right,var(--border)_1px,transparent_1px),linear-gradient(to_bottom,var(--border)_1px,transparent_1px)] bg-[size:24px_24px] opacity-10" /> <div className="absolute inset-0 [--border:black] dark:[--border:white] bg-[linear-gradient(to_right,var(--border)_1px,transparent_1px),linear-gradient(to_bottom,var(--border)_1px,transparent_1px)] bg-size-[24px_24px] opacity-10" />
<div className="bg-background absolute inset-0 m-auto flex size-12 items-center justify-center border-t border-l"> <div className="bg-background absolute inset-0 m-auto flex size-12 items-center justify-center border-t border-l">
{children} {children}
</div> </div>

View file

@ -205,8 +205,8 @@ function ContactSalesButton() {
return ( return (
<motion.div whileHover={{ scale: 1.02, y: -2 }} whileTap={{ scale: 0.98 }}> <motion.div whileHover={{ scale: 1.02, y: -2 }} whileTap={{ scale: 0.98 }}>
<Link <Link
href="https://calendly.com/eric-surfsense/surfsense-meeting" href="/contact"
target="_blank" //target="_blank"
rel="noopener noreferrer" rel="noopener noreferrer"
className="group relative z-20 flex h-11 w-full cursor-pointer items-center justify-center gap-2 rounded-xl bg-white px-6 py-2.5 text-sm font-semibold text-neutral-700 shadow-lg ring-1 ring-neutral-200/50 transition-shadow duration-300 hover:shadow-xl sm:w-56 dark:bg-neutral-900 dark:text-neutral-200 dark:ring-neutral-700/50" className="group relative z-20 flex h-11 w-full cursor-pointer items-center justify-center gap-2 rounded-xl bg-white px-6 py-2.5 text-sm font-semibold text-neutral-700 shadow-lg ring-1 ring-neutral-200/50 transition-shadow duration-300 hover:shadow-xl sm:w-56 dark:bg-neutral-900 dark:text-neutral-200 dark:ring-neutral-700/50"
> >
@ -288,7 +288,7 @@ const CollisionMechanism = React.forwardRef<
} }
}; };
const animationInterval = setInterval(checkCollision, 50); const animationInterval = setInterval(checkCollision, 100);
return () => clearInterval(animationInterval); return () => clearInterval(animationInterval);
}, [cycleCollisionDetected, containerRef]); }, [cycleCollisionDetected, containerRef]);
@ -338,7 +338,7 @@ const CollisionMechanism = React.forwardRef<
repeatDelay: beamOptions.repeatDelay || 0, repeatDelay: beamOptions.repeatDelay || 0,
}} }}
className={cn( className={cn(
"absolute left-96 top-20 m-auto h-14 w-px rounded-full bg-linear-to-t from-orange-500 via-yellow-500 to-transparent", "absolute left-96 top-20 m-auto h-14 w-px rounded-full bg-linear-to-t from-orange-500 via-yellow-500 to-transparent will-change-transform",
beamOptions.className beamOptions.className
)} )}
/> />

View file

@ -19,10 +19,9 @@ export const Navbar = () => {
const [isScrolled, setIsScrolled] = useState(false); const [isScrolled, setIsScrolled] = useState(false);
const navItems = [ const navItems = [
{ name: "Contact Us", link: "/contact" },
{ name: "Pricing", link: "/pricing" }, { name: "Pricing", link: "/pricing" },
{ name: "Contact\u00A0Us", link: "/contact" },
{ name: "Changelog", link: "/changelog" }, { name: "Changelog", link: "/changelog" },
// { name: "Sign In", link: "/login" },
{ name: "Docs", link: "/docs" }, { name: "Docs", link: "/docs" },
]; ];
@ -61,10 +60,13 @@ const DesktopNav = ({ navItems, isScrolled }: any) => {
: "bg-transparent border border-transparent" : "bg-transparent border border-transparent"
)} )}
> >
<div className="flex flex-1 flex-row items-center gap-0.5"> <Link
href="/"
className="flex flex-1 flex-row items-center gap-0.5 hover:opacity-80 transition-opacity"
>
<Logo className="h-8 w-8 rounded-md" /> <Logo className="h-8 w-8 rounded-md" />
<span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span> <span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span>
</div> </Link>
<div className="hidden flex-1 flex-row items-center justify-center space-x-2 text-sm font-medium text-zinc-600 transition duration-200 hover:text-zinc-800 lg:flex lg:space-x-2"> <div className="hidden flex-1 flex-row items-center justify-center space-x-2 text-sm font-medium text-zinc-600 transition duration-200 hover:text-zinc-800 lg:flex lg:space-x-2">
{navItems.map((navItem: any, idx: number) => ( {navItems.map((navItem: any, idx: number) => (
<Link <Link
@ -139,10 +141,13 @@ const MobileNav = ({ navItems, isScrolled }: any) => {
)} )}
> >
<div className="flex w-full flex-row items-center justify-between"> <div className="flex w-full flex-row items-center justify-between">
<div className="flex flex-row items-center gap-2"> <Link
href="/"
className="flex flex-row items-center gap-2 hover:opacity-80 transition-opacity"
>
<Logo className="h-8 w-8 rounded-md" /> <Logo className="h-8 w-8 rounded-md" />
<span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span> <span className="dark:text-white/90 text-gray-800 text-lg font-bold">SurfSense</span>
</div> </Link>
<button <button
type="button" type="button"
onClick={() => setOpen(!open)} onClick={() => setOpen(!open)}

View file

@ -2,7 +2,7 @@
import { useQuery, useQueryClient } from "@tanstack/react-query"; import { useQuery, useQueryClient } from "@tanstack/react-query";
import { useAtomValue, useSetAtom } from "jotai"; import { useAtomValue, useSetAtom } from "jotai";
import { AlertTriangle, Inbox, LogOut, SquareLibrary, Trash2 } from "lucide-react"; import { AlertTriangle, Inbox, LogOut, PencilIcon, SquareLibrary, Trash2 } from "lucide-react";
import { useParams, usePathname, useRouter } from "next/navigation"; import { useParams, usePathname, useRouter } from "next/navigation";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { useTheme } from "next-themes"; import { useTheme } from "next-themes";
@ -21,10 +21,12 @@ import {
DialogHeader, DialogHeader,
DialogTitle, DialogTitle,
} from "@/components/ui/dialog"; } from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import { isPageLimitExceededMetadata } from "@/contracts/types/inbox.types"; import { isPageLimitExceededMetadata } from "@/contracts/types/inbox.types";
import { useInbox } from "@/hooks/use-inbox"; import { useInbox } from "@/hooks/use-inbox";
import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service"; import { searchSpacesApiService } from "@/lib/apis/search-spaces-api.service";
import { deleteThread, fetchThreads, updateThread } from "@/lib/chat/thread-persistence"; import { deleteThread, fetchThreads, updateThread } from "@/lib/chat/thread-persistence";
import { logout } from "@/lib/auth-utils";
import { cleanupElectric } from "@/lib/electric/client"; import { cleanupElectric } from "@/lib/electric/client";
import { resetUser, trackLogout } from "@/lib/posthog/events"; import { resetUser, trackLogout } from "@/lib/posthog/events";
import { cacheKeys } from "@/lib/query-client/cache-keys"; import { cacheKeys } from "@/lib/query-client/cache-keys";
@ -109,7 +111,6 @@ export function LayoutDataProvider({
// This ensures each tab has independent pagination and data loading // This ensures each tab has independent pagination and data loading
const userId = user?.id ? String(user.id) : null; const userId = user?.id ? String(user.id) : null;
// Mentions: Only fetch "new_mention" type notifications
const { const {
inboxItems: mentionItems, inboxItems: mentionItems,
unreadCount: mentionUnreadCount, unreadCount: mentionUnreadCount,
@ -121,11 +122,9 @@ export function LayoutDataProvider({
markAllAsRead: markAllMentionsAsRead, markAllAsRead: markAllMentionsAsRead,
} = useInbox(userId, Number(searchSpaceId) || null, "new_mention"); } = useInbox(userId, Number(searchSpaceId) || null, "new_mention");
// Status: Fetch all types (will be filtered client-side to status types)
// We pass null to get all, then InboxSidebar filters to status types
const { const {
inboxItems: statusItems, inboxItems: statusItems,
unreadCount: statusUnreadCount, unreadCount: allUnreadCount,
loading: statusLoading, loading: statusLoading,
loadingMore: statusLoadingMore, loadingMore: statusLoadingMore,
hasMore: statusHasMore, hasMore: statusHasMore,
@ -134,8 +133,8 @@ export function LayoutDataProvider({
markAllAsRead: markAllStatusAsRead, markAllAsRead: markAllStatusAsRead,
} = useInbox(userId, Number(searchSpaceId) || null, null); } = useInbox(userId, Number(searchSpaceId) || null, null);
// Combined unread count for nav badge (mentions take priority for visibility) const totalUnreadCount = allUnreadCount;
const totalUnreadCount = mentionUnreadCount + statusUnreadCount; const statusOnlyUnreadCount = Math.max(0, allUnreadCount - mentionUnreadCount);
// Track seen notification IDs to detect new page_limit_exceeded notifications // Track seen notification IDs to detect new page_limit_exceeded notifications
const seenPageLimitNotifications = useRef<Set<number>>(new Set()); const seenPageLimitNotifications = useRef<Set<number>>(new Set());
@ -207,6 +206,12 @@ export function LayoutDataProvider({
const [chatToDelete, setChatToDelete] = useState<{ id: number; name: string } | null>(null); const [chatToDelete, setChatToDelete] = useState<{ id: number; name: string } | null>(null);
const [isDeletingChat, setIsDeletingChat] = useState(false); const [isDeletingChat, setIsDeletingChat] = useState(false);
// Rename dialog state
const [showRenameChatDialog, setShowRenameChatDialog] = useState(false);
const [chatToRename, setChatToRename] = useState<{ id: number; name: string } | null>(null);
const [newChatTitle, setNewChatTitle] = useState("");
const [isRenamingChat, setIsRenamingChat] = useState(false);
// Delete/Leave search space dialog state // Delete/Leave search space dialog state
const [showDeleteSearchSpaceDialog, setShowDeleteSearchSpaceDialog] = useState(false); const [showDeleteSearchSpaceDialog, setShowDeleteSearchSpaceDialog] = useState(false);
const [showLeaveSearchSpaceDialog, setShowLeaveSearchSpaceDialog] = useState(false); const [showLeaveSearchSpaceDialog, setShowLeaveSearchSpaceDialog] = useState(false);
@ -421,6 +426,12 @@ export function LayoutDataProvider({
setShowDeleteChatDialog(true); setShowDeleteChatDialog(true);
}, []); }, []);
const handleChatRename = useCallback((chat: ChatItem) => {
setChatToRename({ id: chat.id, name: chat.name });
setNewChatTitle(chat.name);
setShowRenameChatDialog(true);
}, []);
const handleChatArchive = useCallback( const handleChatArchive = useCallback(
async (chat: ChatItem) => { async (chat: ChatItem) => {
const newArchivedState = !chat.archived; const newArchivedState = !chat.archived;
@ -464,12 +475,15 @@ export function LayoutDataProvider({
console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err); console.warn("[Logout] Electric cleanup failed (will be handled on next login):", err);
} }
// Revoke refresh token on server and clear all tokens from localStorage
await logout();
if (typeof window !== "undefined") { if (typeof window !== "undefined") {
localStorage.removeItem("surfsense_bearer_token");
router.push("/"); router.push("/");
} }
} catch (error) { } catch (error) {
console.error("Error during logout:", error); console.error("Error during logout:", error);
await logout();
router.push("/"); router.push("/");
} }
}, [router]); }, [router]);
@ -501,6 +515,29 @@ export function LayoutDataProvider({
} }
}, [chatToDelete, queryClient, searchSpaceId, router, currentChatId]); }, [chatToDelete, queryClient, searchSpaceId, router, currentChatId]);
// Rename handler
const confirmRenameChat = useCallback(async () => {
if (!chatToRename || !newChatTitle.trim()) return;
setIsRenamingChat(true);
try {
await updateThread(chatToRename.id, { title: newChatTitle.trim() });
toast.success(tSidebar("chat_renamed") || "Chat renamed");
queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId] });
queryClient.invalidateQueries({ queryKey: ["all-threads", searchSpaceId] });
queryClient.invalidateQueries({ queryKey: ["search-threads", searchSpaceId] });
// Invalidate thread detail for breadcrumb update
queryClient.invalidateQueries({ queryKey: ["threads", searchSpaceId, "detail", String(chatToRename.id)] });
} catch (error) {
console.error("Error renaming thread:", error);
toast.error(tSidebar("error_renaming_chat") || "Failed to rename chat");
} finally {
setIsRenamingChat(false);
setShowRenameChatDialog(false);
setChatToRename(null);
setNewChatTitle("");
}
}, [chatToRename, newChatTitle, queryClient, searchSpaceId, tSidebar]);
// Page usage // Page usage
const pageUsage = user const pageUsage = user
? { ? {
@ -529,6 +566,7 @@ export function LayoutDataProvider({
activeChatId={currentChatId} activeChatId={currentChatId}
onNewChat={handleNewChat} onNewChat={handleNewChat}
onChatSelect={handleChatSelect} onChatSelect={handleChatSelect}
onChatRename={handleChatRename}
onChatDelete={handleChatDelete} onChatDelete={handleChatDelete}
onChatArchive={handleChatArchive} onChatArchive={handleChatArchive}
onViewAllSharedChats={handleViewAllSharedChats} onViewAllSharedChats={handleViewAllSharedChats}
@ -561,7 +599,7 @@ export function LayoutDataProvider({
}, },
status: { status: {
items: statusItems, items: statusItems,
unreadCount: statusUnreadCount, unreadCount: statusOnlyUnreadCount,
loading: statusLoading, loading: statusLoading,
loadingMore: statusLoadingMore, loadingMore: statusLoadingMore,
hasMore: statusHasMore, hasMore: statusHasMore,
@ -620,6 +658,57 @@ export function LayoutDataProvider({
</DialogContent> </DialogContent>
</Dialog> </Dialog>
{/* Rename Chat Dialog */}
<Dialog open={showRenameChatDialog} onOpenChange={setShowRenameChatDialog}>
<DialogContent className="sm:max-w-md">
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<PencilIcon className="h-5 w-5" />
<span>{tSidebar("rename_chat") || "Rename Chat"}</span>
</DialogTitle>
<DialogDescription>
{tSidebar("rename_chat_description") || "Enter a new name for this conversation."}
</DialogDescription>
</DialogHeader>
<Input
value={newChatTitle}
onChange={(e) => setNewChatTitle(e.target.value)}
placeholder={tSidebar("chat_title_placeholder") || "Chat title"}
onKeyDown={(e) => {
if (e.key === "Enter" && !isRenamingChat && newChatTitle.trim()) {
confirmRenameChat();
}
}}
/>
<DialogFooter className="flex gap-2 sm:justify-end">
<Button
variant="outline"
onClick={() => setShowRenameChatDialog(false)}
disabled={isRenamingChat}
>
{tCommon("cancel")}
</Button>
<Button
onClick={confirmRenameChat}
disabled={isRenamingChat || !newChatTitle.trim()}
className="gap-2"
>
{isRenamingChat ? (
<>
<span className="h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent" />
{tSidebar("renaming") || "Renaming..."}
</>
) : (
<>
<PencilIcon className="h-4 w-4" />
{tSidebar("rename") || "Rename"}
</>
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* Delete Search Space Dialog */} {/* Delete Search Space Dialog */}
<Dialog open={showDeleteSearchSpaceDialog} onOpenChange={setShowDeleteSearchSpaceDialog}> <Dialog open={showDeleteSearchSpaceDialog} onOpenChange={setShowDeleteSearchSpaceDialog}>
<DialogContent className="sm:max-w-md"> <DialogContent className="sm:max-w-md">

View file

@ -54,6 +54,7 @@ interface LayoutShellProps {
activeChatId?: number | null; activeChatId?: number | null;
onNewChat: () => void; onNewChat: () => void;
onChatSelect: (chat: ChatItem) => void; onChatSelect: (chat: ChatItem) => void;
onChatRename?: (chat: ChatItem) => void;
onChatDelete?: (chat: ChatItem) => void; onChatDelete?: (chat: ChatItem) => void;
onChatArchive?: (chat: ChatItem) => void; onChatArchive?: (chat: ChatItem) => void;
onViewAllSharedChats?: () => void; onViewAllSharedChats?: () => void;
@ -90,6 +91,7 @@ export function LayoutShell({
activeChatId, activeChatId,
onNewChat, onNewChat,
onChatSelect, onChatSelect,
onChatRename,
onChatDelete, onChatDelete,
onChatArchive, onChatArchive,
onViewAllSharedChats, onViewAllSharedChats,
@ -147,6 +149,7 @@ export function LayoutShell({
activeChatId={activeChatId} activeChatId={activeChatId}
onNewChat={onNewChat} onNewChat={onNewChat}
onChatSelect={onChatSelect} onChatSelect={onChatSelect}
onChatRename={onChatRename}
onChatDelete={onChatDelete} onChatDelete={onChatDelete}
onChatArchive={onChatArchive} onChatArchive={onChatArchive}
onViewAllSharedChats={onViewAllSharedChats} onViewAllSharedChats={onViewAllSharedChats}
@ -215,6 +218,7 @@ export function LayoutShell({
activeChatId={activeChatId} activeChatId={activeChatId}
onNewChat={onNewChat} onNewChat={onNewChat}
onChatSelect={onChatSelect} onChatSelect={onChatSelect}
onChatRename={onChatRename}
onChatDelete={onChatDelete} onChatDelete={onChatDelete}
onChatArchive={onChatArchive} onChatArchive={onChatArchive}
onViewAllSharedChats={onViewAllSharedChats} onViewAllSharedChats={onViewAllSharedChats}

View file

@ -1,6 +1,6 @@
"use client"; "use client";
import { ArchiveIcon, MessageSquare, MoreHorizontal, RotateCcwIcon, Trash2 } from "lucide-react"; import { ArchiveIcon, MessageSquare, MoreHorizontal, PencilIcon, RotateCcwIcon, Trash2 } from "lucide-react";
import { useTranslations } from "next-intl"; import { useTranslations } from "next-intl";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { import {
@ -17,6 +17,7 @@ interface ChatListItemProps {
isActive?: boolean; isActive?: boolean;
archived?: boolean; archived?: boolean;
onClick?: () => void; onClick?: () => void;
onRename?: () => void;
onArchive?: () => void; onArchive?: () => void;
onDelete?: () => void; onDelete?: () => void;
} }
@ -26,6 +27,7 @@ export function ChatListItem({
isActive, isActive,
archived, archived,
onClick, onClick,
onRename,
onArchive, onArchive,
onDelete, onDelete,
}: ChatListItemProps) { }: ChatListItemProps) {
@ -57,15 +59,26 @@ export function ChatListItem({
<span className="sr-only">{t("more_options")}</span> <span className="sr-only">{t("more_options")}</span>
</Button> </Button>
</DropdownMenuTrigger> </DropdownMenuTrigger>
<DropdownMenuContent align="end" side="right"> <DropdownMenuContent align="end" side="right">
{onArchive && ( {onRename && (
<DropdownMenuItem <DropdownMenuItem
onClick={(e) => { onClick={(e) => {
e.stopPropagation(); e.stopPropagation();
onArchive(); onRename();
}} }}
> >
{archived ? ( <PencilIcon className="mr-2 h-4 w-4" />
<span>{t("rename") || "Rename"}</span>
</DropdownMenuItem>
)}
{onArchive && (
<DropdownMenuItem
onClick={(e) => {
e.stopPropagation();
onArchive();
}}
>
{archived ? (
<> <>
<RotateCcwIcon className="mr-2 h-4 w-4" /> <RotateCcwIcon className="mr-2 h-4 w-4" />
<span>{t("unarchive") || "Restore"}</span> <span>{t("unarchive") || "Restore"}</span>

View file

@ -4,7 +4,6 @@ import { useAtom } from "jotai";
import { import {
AlertCircle, AlertCircle,
AlertTriangle, AlertTriangle,
AtSign,
BellDot, BellDot,
Check, Check,
CheckCheck, CheckCheck,
@ -15,6 +14,7 @@ import {
Inbox, Inbox,
LayoutGrid, LayoutGrid,
ListFilter, ListFilter,
MessageSquare,
Search, Search,
X, X,
} from "lucide-react"; } from "lucide-react";
@ -46,6 +46,7 @@ import { Tabs, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import { import {
isCommentReplyMetadata,
isConnectorIndexingMetadata, isConnectorIndexingMetadata,
isNewMentionMetadata, isNewMentionMetadata,
isPageLimitExceededMetadata, isPageLimitExceededMetadata,
@ -133,7 +134,7 @@ function getConnectorTypeDisplayName(connectorType: string): string {
); );
} }
type InboxTab = "mentions" | "status"; type InboxTab = "comments" | "status";
type InboxFilter = "all" | "unread"; type InboxFilter = "all" | "unread";
// Tab-specific data source with independent pagination // Tab-specific data source with independent pagination
@ -186,7 +187,7 @@ export function InboxSidebar({
const [, setTargetCommentId] = useAtom(setTargetCommentIdAtom); const [, setTargetCommentId] = useAtom(setTargetCommentIdAtom);
const [searchQuery, setSearchQuery] = useState(""); const [searchQuery, setSearchQuery] = useState("");
const [activeTab, setActiveTab] = useState<InboxTab>("mentions"); const [activeTab, setActiveTab] = useState<InboxTab>("comments");
const [activeFilter, setActiveFilter] = useState<InboxFilter>("all"); const [activeFilter, setActiveFilter] = useState<InboxFilter>("all");
const [selectedConnector, setSelectedConnector] = useState<string | null>(null); const [selectedConnector, setSelectedConnector] = useState<string | null>(null);
const [mounted, setMounted] = useState(false); const [mounted, setMounted] = useState(false);
@ -233,19 +234,25 @@ export function InboxSidebar({
} }
}, [activeTab]); }, [activeTab]);
// Get current tab's data source - each tab has independent data and pagination // Both tabs now derive items from status (all types), so use status for pagination
const currentDataSource = activeTab === "mentions" ? mentions : status; const { loading, loadingMore = false, hasMore = false, loadMore } = status;
const { loading, loadingMore = false, hasMore = false, loadMore } = currentDataSource;
// Status tab includes: connector indexing, document processing, page limit exceeded // Comments tab: mentions and comment replies
// Filter to only show status notification types const commentsItems = useMemo(
() =>
status.items.filter((item) => item.type === "new_mention" || item.type === "comment_reply"),
[status.items]
);
// Status tab: connector indexing, document processing, page limit exceeded, connector deletion
const statusItems = useMemo( const statusItems = useMemo(
() => () =>
status.items.filter( status.items.filter(
(item) => (item) =>
item.type === "connector_indexing" || item.type === "connector_indexing" ||
item.type === "document_processing" || item.type === "document_processing" ||
item.type === "page_limit_exceeded" item.type === "page_limit_exceeded" ||
item.type === "connector_deletion"
), ),
[status.items] [status.items]
); );
@ -269,8 +276,8 @@ export function InboxSidebar({
})); }));
}, [statusItems]); }, [statusItems]);
// Get items for current tab - mentions use their source directly, status uses filtered items // Get items for current tab
const displayItems = activeTab === "mentions" ? mentions.items : statusItems; const displayItems = activeTab === "comments" ? commentsItems : statusItems;
// Filter items based on filter type, connector filter, and search query // Filter items based on filter type, connector filter, and search query
const filteredItems = useMemo(() => { const filteredItems = useMemo(() => {
@ -333,9 +340,15 @@ export function InboxSidebar({
return () => observer.disconnect(); return () => observer.disconnect();
}, [loadMore, hasMore, loadingMore, open, searchQuery]); }, [loadMore, hasMore, loadingMore, open, searchQuery]);
// Use unread counts from data sources (more accurate than client-side counting) // Unread counts derived from filtered items
const unreadMentionsCount = mentions.unreadCount; const unreadCommentsCount = useMemo(
const unreadStatusCount = status.unreadCount; () => commentsItems.filter((item) => !item.read).length,
[commentsItems]
);
const unreadStatusCount = useMemo(
() => statusItems.filter((item) => !item.read).length,
[statusItems]
);
const handleItemClick = useCallback( const handleItemClick = useCallback(
async (item: InboxItem) => { async (item: InboxItem) => {
@ -346,19 +359,15 @@ export function InboxSidebar({
} }
if (item.type === "new_mention") { if (item.type === "new_mention") {
// Use type guard for safe metadata access
if (isNewMentionMetadata(item.metadata)) { if (isNewMentionMetadata(item.metadata)) {
const searchSpaceId = item.search_space_id; const searchSpaceId = item.search_space_id;
const threadId = item.metadata.thread_id; const threadId = item.metadata.thread_id;
const commentId = item.metadata.comment_id; const commentId = item.metadata.comment_id;
if (searchSpaceId && threadId) { if (searchSpaceId && threadId) {
// Pre-set target comment ID before navigation
// This also ensures comments panel is not collapsed
if (commentId) { if (commentId) {
setTargetCommentId(commentId); setTargetCommentId(commentId);
} }
const url = commentId const url = commentId
? `/dashboard/${searchSpaceId}/new-chat/${threadId}?commentId=${commentId}` ? `/dashboard/${searchSpaceId}/new-chat/${threadId}?commentId=${commentId}`
: `/dashboard/${searchSpaceId}/new-chat/${threadId}`; : `/dashboard/${searchSpaceId}/new-chat/${threadId}`;
@ -367,6 +376,24 @@ export function InboxSidebar({
router.push(url); router.push(url);
} }
} }
} else if (item.type === "comment_reply") {
if (isCommentReplyMetadata(item.metadata)) {
const searchSpaceId = item.search_space_id;
const threadId = item.metadata.thread_id;
const replyId = item.metadata.reply_id;
if (searchSpaceId && threadId) {
if (replyId) {
setTargetCommentId(replyId);
}
const url = replyId
? `/dashboard/${searchSpaceId}/new-chat/${threadId}?commentId=${replyId}`
: `/dashboard/${searchSpaceId}/new-chat/${threadId}`;
onOpenChange(false);
onCloseMobileSidebar?.();
router.push(url);
}
}
} else if (item.type === "page_limit_exceeded") { } else if (item.type === "page_limit_exceeded") {
// Navigate to the upgrade/more-pages page // Navigate to the upgrade/more-pages page
if (isPageLimitExceededMetadata(item.metadata)) { if (isPageLimitExceededMetadata(item.metadata)) {
@ -410,24 +437,29 @@ export function InboxSidebar({
}; };
const getStatusIcon = (item: InboxItem) => { const getStatusIcon = (item: InboxItem) => {
// For mentions, show the author's avatar with initials fallback // For mentions and comment replies, show the author's avatar
if (item.type === "new_mention") { if (item.type === "new_mention" || item.type === "comment_reply") {
// Use type guard for safe metadata access const metadata =
if (isNewMentionMetadata(item.metadata)) { item.type === "new_mention"
const authorName = item.metadata.author_name; ? isNewMentionMetadata(item.metadata)
const avatarUrl = item.metadata.author_avatar_url; ? item.metadata
const authorEmail = item.metadata.author_email; : null
: isCommentReplyMetadata(item.metadata)
? item.metadata
: null;
if (metadata) {
return ( return (
<Avatar className="h-8 w-8"> <Avatar className="h-8 w-8">
{avatarUrl && <AvatarImage src={avatarUrl} alt={authorName || "User"} />} {metadata.author_avatar_url && (
<AvatarImage src={metadata.author_avatar_url} alt={metadata.author_name || "User"} />
)}
<AvatarFallback className="text-[10px] bg-primary/10 text-primary"> <AvatarFallback className="text-[10px] bg-primary/10 text-primary">
{getInitials(authorName, authorEmail)} {getInitials(metadata.author_name, metadata.author_email)}
</AvatarFallback> </AvatarFallback>
</Avatar> </Avatar>
); );
} }
// Fallback for invalid metadata
return ( return (
<Avatar className="h-8 w-8"> <Avatar className="h-8 w-8">
<AvatarFallback className="text-[10px] bg-primary/10 text-primary"> <AvatarFallback className="text-[10px] bg-primary/10 text-primary">
@ -480,10 +512,10 @@ export function InboxSidebar({
}; };
const getEmptyStateMessage = () => { const getEmptyStateMessage = () => {
if (activeTab === "mentions") { if (activeTab === "comments") {
return { return {
title: t("no_mentions") || "No mentions", title: t("no_comments") || "No comments",
hint: t("no_mentions_hint") || "You'll see mentions from others here", hint: t("no_comments_hint") || "You'll see mentions and replies here",
}; };
} }
return { return {
@ -822,14 +854,14 @@ export function InboxSidebar({
> >
<TabsList className="w-full h-auto p-0 bg-transparent rounded-none border-b"> <TabsList className="w-full h-auto p-0 bg-transparent rounded-none border-b">
<TabsTrigger <TabsTrigger
value="mentions" value="comments"
className="flex-1 rounded-none border-b-2 border-transparent px-1 py-2 text-xs font-medium data-[state=active]:border-primary data-[state=active]:bg-transparent data-[state=active]:shadow-none" className="flex-1 rounded-none border-b-2 border-transparent px-1 py-2 text-xs font-medium data-[state=active]:border-primary data-[state=active]:bg-transparent data-[state=active]:shadow-none"
> >
<span className="w-full inline-flex items-center justify-center gap-1.5 px-3 py-1.5 rounded-lg hover:bg-muted transition-colors"> <span className="w-full inline-flex items-center justify-center gap-1.5 px-3 py-1.5 rounded-lg hover:bg-muted transition-colors">
<AtSign className="h-4 w-4" /> <MessageSquare className="h-4 w-4" />
<span>{t("mentions") || "Mentions"}</span> <span>{t("comments") || "Comments"}</span>
<span className="inline-flex items-center justify-center min-w-5 h-5 px-1.5 rounded-full bg-primary/20 text-muted-foreground text-xs font-medium"> <span className="inline-flex items-center justify-center min-w-5 h-5 px-1.5 rounded-full bg-primary/20 text-muted-foreground text-xs font-medium">
{formatInboxCount(unreadMentionsCount)} {formatInboxCount(unreadCommentsCount)}
</span> </span>
</span> </span>
</TabsTrigger> </TabsTrigger>
@ -931,8 +963,8 @@ export function InboxSidebar({
</div> </div>
) : ( ) : (
<div className="text-center py-8"> <div className="text-center py-8">
{activeTab === "mentions" ? ( {activeTab === "comments" ? (
<AtSign className="h-12 w-12 mx-auto text-muted-foreground mb-3" /> <MessageSquare className="h-12 w-12 mx-auto text-muted-foreground mb-3" />
) : ( ) : (
<History className="h-12 w-12 mx-auto text-muted-foreground mb-3" /> <History className="h-12 w-12 mx-auto text-muted-foreground mb-3" />
)} )}

View file

@ -24,6 +24,7 @@ interface MobileSidebarProps {
activeChatId?: number | null; activeChatId?: number | null;
onNewChat: () => void; onNewChat: () => void;
onChatSelect: (chat: ChatItem) => void; onChatSelect: (chat: ChatItem) => void;
onChatRename?: (chat: ChatItem) => void;
onChatDelete?: (chat: ChatItem) => void; onChatDelete?: (chat: ChatItem) => void;
onChatArchive?: (chat: ChatItem) => void; onChatArchive?: (chat: ChatItem) => void;
onViewAllSharedChats?: () => void; onViewAllSharedChats?: () => void;
@ -64,6 +65,7 @@ export function MobileSidebar({
activeChatId, activeChatId,
onNewChat, onNewChat,
onChatSelect, onChatSelect,
onChatRename,
onChatDelete, onChatDelete,
onChatArchive, onChatArchive,
onViewAllSharedChats, onViewAllSharedChats,
@ -142,6 +144,7 @@ export function MobileSidebar({
onOpenChange(false); onOpenChange(false);
}} }}
onChatSelect={handleChatSelect} onChatSelect={handleChatSelect}
onChatRename={onChatRename}
onChatDelete={onChatDelete} onChatDelete={onChatDelete}
onChatArchive={onChatArchive} onChatArchive={onChatArchive}
onViewAllSharedChats={onViewAllSharedChats} onViewAllSharedChats={onViewAllSharedChats}

Some files were not shown because too many files have changed in this diff Show more