Merge pull request #612 from MODSetter/dev

feat: migrate to new surfsense deepagent
This commit is contained in:
Rohan Verma 2025-12-23 02:18:22 -08:00 committed by GitHub
commit 075e373de9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
197 changed files with 42216 additions and 32950 deletions

View file

@ -28,12 +28,15 @@ COPY surfsense_web/package.json surfsense_web/pnpm-lock.yaml* ./
COPY surfsense_web/source.config.ts ./
COPY surfsense_web/content ./content
# Install dependencies
RUN pnpm install --frozen-lockfile
# Install dependencies (skip postinstall which requires all source files)
RUN pnpm install --frozen-lockfile --ignore-scripts
# Copy source
COPY surfsense_web/ ./
# Run fumadocs-mdx postinstall now that source files are available
RUN pnpm fumadocs-mdx
# Build args for frontend
ARG NEXT_PUBLIC_FASTAPI_BACKEND_URL=http://localhost:8000
ARG NEXT_PUBLIC_FASTAPI_BACKEND_AUTH_TYPE=LOCAL

View file

@ -0,0 +1,240 @@
"""Migrate old chats to new_chat_threads and remove old tables
Revision ID: 49
Revises: 48
Create Date: 2025-12-21
This migration:
1. Migrates data from old 'chats' table to 'new_chat_threads' and 'new_chat_messages'
2. Drops the 'podcasts' table (podcast data is not migrated as per user request)
3. Drops the 'chats' table
4. Removes the 'chattype' enum
"""
import json
from collections.abc import Sequence
from datetime import datetime
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "49"
down_revision: str | None = "48"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def extract_text_content(content: str | dict | list) -> str:
"""Extract plain text content from various message formats."""
if isinstance(content, str):
return content
if isinstance(content, dict):
# Handle dict with 'text' key
if "text" in content:
return content["text"]
return str(content)
if isinstance(content, list):
# Handle list of parts (e.g., [{"type": "text", "text": "..."}])
texts = []
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
texts.append(part.get("text", ""))
elif isinstance(part, str):
texts.append(part)
return "\n".join(texts) if texts else ""
return ""
def parse_timestamp(ts, fallback):
    """Parse an ISO-8601 timestamp into a datetime.

    Accepts an already-parsed datetime (returned unchanged) or an ISO
    string such as '2025-11-26T22:43:34.399Z' (the trailing 'Z' is
    rewritten to '+00:00' for fromisoformat).  Anything else — None,
    unparseable strings, other types — yields ``fallback``.
    """
    if isinstance(ts, datetime):
        return ts
    if not isinstance(ts, str):
        # Covers None and any other non-string payload.
        return fallback
    try:
        return datetime.fromisoformat(ts.replace("Z", "+00:00"))
    except (ValueError, TypeError):
        return fallback
def upgrade() -> None:
    """Migrate old chats to new_chat_threads and remove old tables.

    Reads every row of the legacy 'chats' table, creates one
    new_chat_threads row per non-empty chat and one new_chat_messages
    row per user/assistant message, then drops 'podcasts', 'chats' and
    the 'chattype' enum.  Any per-chat error is re-raised so the
    migration aborts instead of leaving partial data.
    """
    connection = op.get_bind()

    # Get all old chats, oldest first so new thread IDs roughly follow
    # the original creation order.
    old_chats = connection.execute(
        sa.text("""
            SELECT id, title, messages, search_space_id, created_at
            FROM chats
            ORDER BY created_at ASC
        """)
    ).fetchall()

    print(f"[Migration 49] Found {len(old_chats)} old chats to migrate")

    migrated_count = 0
    for chat_id, title, messages_json, search_space_id, created_at in old_chats:
        try:
            # Parse messages JSON; the legacy column may arrive as a raw
            # JSON string or as an already-decoded list/None.
            if isinstance(messages_json, str):
                messages = json.loads(messages_json)
            else:
                messages = messages_json or []

            # Skip empty chats
            if not messages:
                print(f"[Migration 49] Skipping empty chat {chat_id}")
                continue

            # Create new thread (updated_at starts equal to created_at).
            result = connection.execute(
                sa.text("""
                    INSERT INTO new_chat_threads
                    (title, archived, search_space_id, created_at, updated_at)
                    VALUES (:title, FALSE, :search_space_id, :created_at, :created_at)
                    RETURNING id
                """),
                {
                    "title": title or "Migrated Chat",
                    "search_space_id": search_space_id,
                    "created_at": created_at,
                },
            )
            new_thread_id = result.fetchone()[0]

            # Migrate messages - only user and assistant roles, skip SOURCES/TERMINAL_INFO
            # NOTE: assumes each entry is a dict with 'role'/'content'
            # keys; a non-dict entry would raise and abort the migration.
            message_count = 0
            for msg in messages:
                role_lower = msg.get("role", "").lower()
                # Only migrate user and assistant messages
                if role_lower not in ("user", "assistant"):
                    continue
                # Convert to uppercase for database enum
                role = role_lower.upper()

                # Extract content - handle various formats
                content_raw = msg.get("content", "")
                content_text = extract_text_content(content_raw)

                # Skip empty messages
                if not content_text.strip():
                    continue

                # Parse message timestamp; fall back to the chat's created_at.
                msg_created_at = parse_timestamp(msg.get("createdAt"), created_at)

                # Store content as JSONB array format for assistant-ui compatibility
                content_list = [{"type": "text", "text": content_text}]

                # Use direct SQL with string interpolation for the enum since CAST doesn't work
                # The enum value comes from trusted source (our own code), not user input
                # (role is constrained above to 'USER' or 'ASSISTANT').
                connection.execute(
                    sa.text(f"""
                        INSERT INTO new_chat_messages
                        (thread_id, role, content, created_at)
                        VALUES (:thread_id, '{role}', CAST(:content AS jsonb), :created_at)
                    """),
                    {
                        "thread_id": new_thread_id,
                        "content": json.dumps(content_list),
                        "created_at": msg_created_at,
                    },
                )
                message_count += 1

            print(
                f"[Migration 49] Migrated chat {chat_id} -> thread {new_thread_id} ({message_count} messages)"
            )
            migrated_count += 1
        except Exception as e:
            print(f"[Migration 49] Error migrating chat {chat_id}: {e}")
            # Re-raise to abort migration - we don't want partial data
            raise

    print(f"[Migration 49] Successfully migrated {migrated_count} chats")

    # Drop podcasts table (FK references chats, so drop first)
    print("[Migration 49] Dropping podcasts table...")
    op.drop_table("podcasts")

    # Drop chats table
    print("[Migration 49] Dropping chats table...")
    op.drop_table("chats")

    # Drop chattype enum
    print("[Migration 49] Dropping chattype enum...")
    op.execute(sa.text("DROP TYPE IF EXISTS chattype"))

    print("[Migration 49] Migration complete!")
def downgrade() -> None:
    """Recreate old tables (data cannot be restored).

    Rebuilds the 'chattype' enum and the 'chats' / 'podcasts' schemas as
    they existed before this revision.  The migrated chat rows and the
    dropped podcast rows are NOT restored.
    """
    # Recreate chattype enum (must exist before the chats.type column).
    op.execute(
        sa.text("""
            CREATE TYPE chattype AS ENUM ('QNA')
        """)
    )

    # Recreate chats table
    op.create_table(
        "chats",
        sa.Column("id", sa.Integer(), primary_key=True, index=True),
        sa.Column("type", sa.Enum("QNA", name="chattype"), nullable=False),
        sa.Column("title", sa.String(), nullable=False, index=True),
        sa.Column("initial_connectors", sa.ARRAY(sa.String()), nullable=True),
        sa.Column("messages", sa.JSON(), nullable=False),
        sa.Column("state_version", sa.BigInteger(), nullable=False, default=1),
        sa.Column(
            "search_space_id",
            sa.Integer(),
            sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )

    # Recreate podcasts table second: its chat_id FK needs chats to exist.
    op.create_table(
        "podcasts",
        sa.Column("id", sa.Integer(), primary_key=True, index=True),
        sa.Column("title", sa.String(), nullable=False, index=True),
        sa.Column("podcast_transcript", sa.JSON(), nullable=False, server_default="{}"),
        sa.Column("file_location", sa.String(500), nullable=False, server_default=""),
        sa.Column(
            "chat_id",
            sa.Integer(),
            sa.ForeignKey("chats.id", ondelete="CASCADE"),
            nullable=True,
        ),
        sa.Column("chat_state_version", sa.BigInteger(), nullable=True),
        sa.Column(
            "search_space_id",
            sa.Integer(),
            sa.ForeignKey("searchspaces.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )
    print("[Migration 49 Downgrade] Tables recreated (data not restored)")

View file

@ -0,0 +1,48 @@
"""50_remove_podcast_chat_columns
Revision ID: 50
Revises: 49
Create Date: 2025-12-21
Removes chat_id and chat_state_version columns from podcasts table.
These columns were used for the old chat system podcast linking which
has been replaced by the new-chat content-based podcast generation.
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy import inspect
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "50"
down_revision: str | None = "49"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Upgrade schema - Remove chat_id and chat_state_version from podcasts.

    Inspects the live table first so the migration is a no-op for
    databases where the columns are already gone.
    """
    bind = op.get_bind()
    existing = {column["name"] for column in inspect(bind).get_columns("podcasts")}

    # Drop each legacy column only if it is still present.
    for column_name in ("chat_id", "chat_state_version"):
        if column_name in existing:
            op.drop_column("podcasts", column_name)
def downgrade() -> None:
    """Downgrade schema - Re-add chat_id and chat_state_version to podcasts.

    Column values cannot be restored.  chat_id is added as a plain
    Integer (no FK) because the referenced 'chats' table was dropped in
    revision 49.
    """
    op.add_column(
        "podcasts",
        sa.Column("chat_id", sa.Integer(), nullable=True),
    )
    # chat_state_version was a BigInteger in the original schema (see the
    # table recreation in revision 49), not a String.
    op.add_column(
        "podcasts",
        sa.Column("chat_state_version", sa.BigInteger(), nullable=True),
    )

View file

@ -0,0 +1,114 @@
"""Add NewLLMConfig table for configurable LLM + prompt settings
Revision ID: 51
Revises: 50
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "51"
down_revision: str | None = "50"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """
    Add the new_llm_configs table that combines LLM model settings with prompt configuration.

    This table includes:
    - LLM model configuration (provider, model_name, api_key, etc.)
    - Configurable system instructions
    - Citation toggle

    Both the table and its indexes are created inside IF NOT EXISTS
    guards, so re-running this migration against a database where the
    objects already exist is safe.
    """
    # Create new_llm_configs table only if it doesn't already exist.
    # NOTE: relies on the 'litellmprovider' enum type already existing
    # in the database (created by an earlier revision).
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'new_llm_configs'
            ) THEN
                CREATE TABLE new_llm_configs (
                    id SERIAL PRIMARY KEY,
                    created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                    -- Basic info
                    name VARCHAR(100) NOT NULL,
                    description VARCHAR(500),
                    -- LLM Model Configuration (same as llm_configs, excluding language)
                    provider litellmprovider NOT NULL,
                    custom_provider VARCHAR(100),
                    model_name VARCHAR(100) NOT NULL,
                    api_key TEXT NOT NULL,
                    api_base VARCHAR(500),
                    litellm_params JSONB DEFAULT '{}',
                    -- Prompt Configuration
                    system_instructions TEXT NOT NULL DEFAULT '',
                    use_default_system_instructions BOOLEAN NOT NULL DEFAULT TRUE,
                    citations_enabled BOOLEAN NOT NULL DEFAULT TRUE,
                    -- Default flag
                    is_default BOOLEAN NOT NULL DEFAULT FALSE,
                    -- Foreign key to search space
                    search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE
                );
            END IF;
        END$$;
        """
    )

    # Create indexes if they don't exist (one guard per index so a
    # partially-indexed table is completed rather than skipped).
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'new_llm_configs' AND indexname = 'ix_new_llm_configs_id'
            ) THEN
                CREATE INDEX ix_new_llm_configs_id ON new_llm_configs(id);
            END IF;

            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'new_llm_configs' AND indexname = 'ix_new_llm_configs_created_at'
            ) THEN
                CREATE INDEX ix_new_llm_configs_created_at ON new_llm_configs(created_at);
            END IF;

            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'new_llm_configs' AND indexname = 'ix_new_llm_configs_name'
            ) THEN
                CREATE INDEX ix_new_llm_configs_name ON new_llm_configs(name);
            END IF;

            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'new_llm_configs' AND indexname = 'ix_new_llm_configs_search_space_id'
            ) THEN
                CREATE INDEX ix_new_llm_configs_search_space_id ON new_llm_configs(search_space_id);
            END IF;
        END$$;
        """
    )
def downgrade() -> None:
    """Remove the new_llm_configs table and all of its indexes."""
    # Drop indexes first, then the table; IF EXISTS keeps each statement
    # safe to re-run.
    for index_name in (
        "ix_new_llm_configs_search_space_id",
        "ix_new_llm_configs_name",
        "ix_new_llm_configs_created_at",
        "ix_new_llm_configs_id",
    ):
        op.execute(f"DROP INDEX IF EXISTS {index_name}")

    # Drop table
    op.execute("DROP TABLE IF EXISTS new_llm_configs")

View file

@ -0,0 +1,130 @@
"""Rename LLM preference columns in searchspaces table
Revision ID: 52
Revises: 51
Create Date: 2025-12-22
This migration renames the LLM preference columns:
- fast_llm_id -> agent_llm_id
- long_context_llm_id -> document_summary_llm_id
- strategic_llm_id is removed (data migrated to document_summary_llm_id)
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "52"
down_revision = "51"
branch_labels = None
depends_on = None
def upgrade():
    """Rename LLM preference columns on searchspaces.

    fast_llm_id -> agent_llm_id and long_context_llm_id ->
    document_summary_llm_id; strategic_llm_id values are first copied
    into long_context_llm_id (only where it is NULL) and the column is
    then dropped.  Every step is guarded by an information_schema check
    so the migration tolerates schemas where a column is already
    renamed or missing.
    """
    # First, migrate any strategic_llm_id values to document_summary_llm_id
    # (only if document_summary_llm_id/long_context_llm_id is NULL)
    # Use IF EXISTS check to handle case where column might not exist
    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces' AND column_name = 'strategic_llm_id'
            ) THEN
                UPDATE searchspaces
                SET long_context_llm_id = strategic_llm_id
                WHERE long_context_llm_id IS NULL AND strategic_llm_id IS NOT NULL;
            END IF;
        END$$;
        """
    )

    # Rename columns (only if they exist with old names)
    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces' AND column_name = 'fast_llm_id'
            ) THEN
                ALTER TABLE searchspaces RENAME COLUMN fast_llm_id TO agent_llm_id;
            END IF;
        END$$;
        """
    )

    # Rename long_context_llm_id AFTER the data copy above so no values
    # are lost.
    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces' AND column_name = 'long_context_llm_id'
            ) THEN
                ALTER TABLE searchspaces RENAME COLUMN long_context_llm_id TO document_summary_llm_id;
            END IF;
        END$$;
        """
    )

    # Drop the strategic_llm_id column if it exists
    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces' AND column_name = 'strategic_llm_id'
            ) THEN
                ALTER TABLE searchspaces DROP COLUMN strategic_llm_id;
            END IF;
        END$$;
        """
    )
def downgrade():
    """Restore the pre-rename column names on searchspaces.

    Re-adds strategic_llm_id (values are NOT restored — they were merged
    into long_context_llm_id on upgrade) and renames agent_llm_id /
    document_summary_llm_id back to their old names, each guarded by an
    information_schema existence check.
    """
    # Add back the strategic_llm_id column (empty; original data is lost).
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces' AND column_name = 'strategic_llm_id'
            ) THEN
                ALTER TABLE searchspaces ADD COLUMN strategic_llm_id INTEGER;
            END IF;
        END$$;
        """
    )

    # Rename columns back
    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces' AND column_name = 'agent_llm_id'
            ) THEN
                ALTER TABLE searchspaces RENAME COLUMN agent_llm_id TO fast_llm_id;
            END IF;
        END$$;
        """
    )

    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'searchspaces' AND column_name = 'document_summary_llm_id'
            ) THEN
                ALTER TABLE searchspaces RENAME COLUMN document_summary_llm_id TO long_context_llm_id;
            END IF;
        END$$;
        """
    )

View file

@ -0,0 +1,244 @@
"""Migrate data from old llm_configs to new_llm_configs and cleanup
Revision ID: 53
Revises: 52
Create Date: 2025-12-22
This migration:
1. Migrates data from old llm_configs table to new_llm_configs (preserving user configs)
2. Drops the old llm_configs table (no longer used)
3. Removes the is_default column from new_llm_configs (roles now determine which config to use)
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "53"
down_revision = "52"
branch_labels = None
depends_on = None
def upgrade():
    """Migrate llm_configs rows into new_llm_configs, repoint searchspaces
    references, drop the is_default column and the old llm_configs table.

    All SQL runs inside DO $$ blocks guarded by information_schema
    checks, so the migration is safe on databases where some of these
    steps have already happened.
    """
    # STEP 1: Migrate data from old llm_configs to new_llm_configs
    # This preserves any user-created configurations
    op.execute(
        """
        DO $$
        BEGIN
            -- Only migrate if both tables exist
            IF EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'llm_configs'
            ) AND EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'new_llm_configs'
            ) THEN
                -- Insert old configs into new table (skipping duplicates by name+search_space_id)
                INSERT INTO new_llm_configs (
                    name,
                    description,
                    provider,
                    custom_provider,
                    model_name,
                    api_key,
                    api_base,
                    litellm_params,
                    system_instructions,
                    use_default_system_instructions,
                    citations_enabled,
                    is_default,
                    search_space_id,
                    created_at
                )
                SELECT
                    lc.name,
                    NULL as description, -- Old table didn't have description
                    lc.provider,
                    lc.custom_provider,
                    lc.model_name,
                    lc.api_key,
                    lc.api_base,
                    COALESCE(lc.litellm_params, '{}'::jsonb),
                    '' as system_instructions, -- Use defaults
                    TRUE as use_default_system_instructions,
                    TRUE as citations_enabled,
                    FALSE as is_default,
                    lc.search_space_id,
                    COALESCE(lc.created_at, NOW())
                FROM llm_configs lc
                WHERE lc.search_space_id IS NOT NULL
                AND NOT EXISTS (
                    -- Skip if a config with same name already exists in new_llm_configs for this search space
                    SELECT 1 FROM new_llm_configs nlc
                    WHERE nlc.name = lc.name
                    AND nlc.search_space_id = lc.search_space_id
                );

                -- Log how many configs were migrated
                -- NOTE(review): this counts all candidate rows, including
                -- ones skipped as duplicates above, so the notice may
                -- overstate the number actually inserted.
                RAISE NOTICE 'Migrated % configs from llm_configs to new_llm_configs',
                    (SELECT COUNT(*) FROM llm_configs WHERE search_space_id IS NOT NULL);
            END IF;
        END$$;
        """
    )

    # STEP 2: Update searchspaces to point to new_llm_configs for their agent LLM
    # If a search space had an agent_llm_id pointing to old llm_configs,
    # try to find the corresponding config in new_llm_configs
    # (matched by name + search_space_id, as inserted in STEP 1).
    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'llm_configs'
            ) THEN
                -- Update agent_llm_id to point to migrated config in new_llm_configs
                UPDATE searchspaces ss
                SET agent_llm_id = (
                    SELECT nlc.id
                    FROM new_llm_configs nlc
                    JOIN llm_configs lc ON lc.name = nlc.name AND lc.search_space_id = nlc.search_space_id
                    WHERE lc.id = ss.agent_llm_id
                    AND nlc.search_space_id = ss.id
                    LIMIT 1
                )
                WHERE ss.agent_llm_id IS NOT NULL
                AND ss.agent_llm_id > 0 -- Only positive IDs (not global configs)
                AND EXISTS (
                    SELECT 1 FROM llm_configs lc WHERE lc.id = ss.agent_llm_id
                );

                -- Update document_summary_llm_id similarly
                UPDATE searchspaces ss
                SET document_summary_llm_id = (
                    SELECT nlc.id
                    FROM new_llm_configs nlc
                    JOIN llm_configs lc ON lc.name = nlc.name AND lc.search_space_id = nlc.search_space_id
                    WHERE lc.id = ss.document_summary_llm_id
                    AND nlc.search_space_id = ss.id
                    LIMIT 1
                )
                WHERE ss.document_summary_llm_id IS NOT NULL
                AND ss.document_summary_llm_id > 0 -- Only positive IDs (not global configs)
                AND EXISTS (
                    SELECT 1 FROM llm_configs lc WHERE lc.id = ss.document_summary_llm_id
                );
            END IF;
        END$$;
        """
    )

    # STEP 3: Drop the is_default column from new_llm_configs
    # (role assignments now determine which config to use)
    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'new_llm_configs' AND column_name = 'is_default'
            ) THEN
                ALTER TABLE new_llm_configs DROP COLUMN is_default;
            END IF;
        END$$;
        """
    )

    # STEP 4: Drop the old llm_configs table (data has been migrated)
    op.execute("DROP TABLE IF EXISTS llm_configs CASCADE")
def downgrade():
    """Recreate llm_configs, copy configs back from new_llm_configs,
    and re-add the is_default column.

    The language column is refilled with the literal 'English' because
    the original per-config value was discarded on upgrade.
    """
    # Recreate the old llm_configs table (schema only; guarded so it is
    # not recreated if it already exists).
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'llm_configs'
            ) THEN
                CREATE TABLE llm_configs (
                    id SERIAL PRIMARY KEY,
                    name VARCHAR(100) NOT NULL,
                    provider litellmprovider NOT NULL,
                    custom_provider VARCHAR(100),
                    model_name VARCHAR(100) NOT NULL,
                    api_key TEXT NOT NULL,
                    api_base VARCHAR(500),
                    language VARCHAR(50),
                    litellm_params JSONB DEFAULT '{}',
                    search_space_id INTEGER REFERENCES searchspaces(id) ON DELETE CASCADE,
                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                    updated_at TIMESTAMP WITH TIME ZONE
                );

                -- Create indexes
                CREATE INDEX IF NOT EXISTS ix_llm_configs_id ON llm_configs(id);
                CREATE INDEX IF NOT EXISTS ix_llm_configs_name ON llm_configs(name);
                CREATE INDEX IF NOT EXISTS ix_llm_configs_created_at ON llm_configs(created_at);
            END IF;
        END$$;
        """
    )

    # Migrate data back from new_llm_configs to llm_configs
    # (skipping rows whose name already exists for that search space).
    op.execute(
        """
        DO $$
        BEGIN
            IF EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'new_llm_configs'
            ) THEN
                INSERT INTO llm_configs (
                    name,
                    provider,
                    custom_provider,
                    model_name,
                    api_key,
                    api_base,
                    language,
                    litellm_params,
                    search_space_id,
                    created_at
                )
                SELECT
                    nlc.name,
                    nlc.provider,
                    nlc.custom_provider,
                    nlc.model_name,
                    nlc.api_key,
                    nlc.api_base,
                    'English' as language, -- Default language
                    COALESCE(nlc.litellm_params, '{}'::jsonb),
                    nlc.search_space_id,
                    nlc.created_at
                FROM new_llm_configs nlc
                WHERE nlc.search_space_id IS NOT NULL
                AND NOT EXISTS (
                    SELECT 1 FROM llm_configs lc
                    WHERE lc.name = nlc.name
                    AND lc.search_space_id = nlc.search_space_id
                );
            END IF;
        END$$;
        """
    )

    # Add back the is_default column to new_llm_configs
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = 'new_llm_configs' AND column_name = 'is_default'
            ) THEN
                ALTER TABLE new_llm_configs ADD COLUMN is_default BOOLEAN NOT NULL DEFAULT FALSE;
            END IF;
        END$$;
        """
    )

View file

@ -1,27 +1,80 @@
"""Chat agents module."""
"""
SurfSense New Chat Agent Module.
from app.agents.new_chat.chat_deepagent import (
This module provides the SurfSense deep agent with configurable tools
for knowledge base search, podcast generation, and more.
Directory Structure:
- tools/: All agent tools (knowledge_base, podcast, link_preview, etc.)
- chat_deepagent.py: Main agent factory
- system_prompt.py: System prompts and instructions
- context.py: Context schema for the agent
- checkpointer.py: LangGraph checkpointer setup
- llm_config.py: LLM configuration utilities
- utils.py: Shared utilities
"""
# Agent factory
from .chat_deepagent import create_surfsense_deep_agent
# Context
from .context import SurfSenseContextSchema
# LLM config
from .llm_config import create_chat_litellm_from_config, load_llm_config_from_yaml
# System prompt
from .system_prompt import (
SURFSENSE_CITATION_INSTRUCTIONS,
SURFSENSE_SYSTEM_PROMPT,
SurfSenseContextSchema,
build_surfsense_system_prompt,
create_chat_litellm_from_config,
)
# Tools - registry exports
# Tools - factory exports (for direct use)
# Tools - knowledge base utilities
from .tools import (
BUILTIN_TOOLS,
ToolDefinition,
build_tools,
create_display_image_tool,
create_generate_podcast_tool,
create_link_preview_tool,
create_scrape_webpage_tool,
create_search_knowledge_base_tool,
create_surfsense_deep_agent,
format_documents_for_context,
load_llm_config_from_yaml,
get_all_tool_names,
get_default_enabled_tools,
get_tool_by_name,
search_knowledge_base_async,
)
__all__ = [
# Tools registry
"BUILTIN_TOOLS",
# System prompt
"SURFSENSE_CITATION_INSTRUCTIONS",
"SURFSENSE_SYSTEM_PROMPT",
# Context
"SurfSenseContextSchema",
"ToolDefinition",
"build_surfsense_system_prompt",
"build_tools",
# LLM config
"create_chat_litellm_from_config",
# Tool factories
"create_display_image_tool",
"create_generate_podcast_tool",
"create_link_preview_tool",
"create_scrape_webpage_tool",
"create_search_knowledge_base_tool",
# Agent factory
"create_surfsense_deep_agent",
# Knowledge base utilities
"format_documents_for_context",
"get_all_tool_names",
"get_default_enabled_tools",
"get_tool_by_name",
"load_llm_config_from_yaml",
"search_knowledge_base_async",
]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,94 @@
"""
PostgreSQL-based checkpointer for LangGraph agents.
This module provides a persistent checkpointer using AsyncPostgresSaver
that stores conversation state in the PostgreSQL database.
"""
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.config import config
# Global checkpointer instance (initialized lazily)
_checkpointer: AsyncPostgresSaver | None = None
_checkpointer_context = None # Store the context manager for cleanup
_checkpointer_initialized: bool = False
def get_postgres_connection_string() -> str:
    """Convert the async DATABASE_URL into a sync psycopg connection string.

    ``postgresql+asyncpg://user:pass@host:port/db`` becomes
    ``postgresql://user:pass@host:port/db``.  URLs with an asyncpg
    marker elsewhere in the scheme have the marker stripped; anything
    else is returned unchanged.
    """
    database_url = config.DATABASE_URL

    # Common case: the standard asyncpg scheme prefix.
    if database_url.startswith("postgresql+asyncpg://"):
        return database_url.replace("postgresql+asyncpg://", "postgresql://")

    # Any other scheme carrying the asyncpg driver marker.
    if "+asyncpg" in database_url:
        return database_url.replace("+asyncpg", "")

    return database_url
async def get_checkpointer() -> AsyncPostgresSaver:
    """
    Get or create the global AsyncPostgresSaver instance.

    This function:
    1. Creates the checkpointer if it doesn't exist
    2. Sets up the required database tables on first call
    3. Returns the cached instance on subsequent calls

    NOTE(review): there is no lock around the lazy initialization, so
    two concurrent first calls could each enter the context manager —
    confirm callers serialize startup (e.g. via setup_checkpointer_tables).

    Returns:
        AsyncPostgresSaver: The configured checkpointer instance
    """
    global _checkpointer, _checkpointer_context, _checkpointer_initialized

    if _checkpointer is None:
        conn_string = get_postgres_connection_string()
        # from_conn_string returns an async context manager
        # We need to enter the context to get the actual checkpointer;
        # the context object is kept so close_checkpointer() can exit it.
        _checkpointer_context = AsyncPostgresSaver.from_conn_string(conn_string)
        _checkpointer = await _checkpointer_context.__aenter__()

    # Setup tables on first call (idempotent)
    if not _checkpointer_initialized:
        await _checkpointer.setup()
        _checkpointer_initialized = True

    return _checkpointer
async def setup_checkpointer_tables() -> None:
    """Explicitly set up the checkpointer tables.

    Call during application startup so the tables exist before any
    agent invocation.
    """
    # Resolving the checkpointer runs setup() on first use, which is all
    # we need here — the returned instance is discarded.
    await get_checkpointer()
    print("[Checkpointer] PostgreSQL checkpoint tables ready")
async def close_checkpointer() -> None:
    """Close the checkpointer connection.

    Call during application shutdown; a no-op when the checkpointer was
    never created.
    """
    global _checkpointer, _checkpointer_context, _checkpointer_initialized

    if _checkpointer_context is None:
        return

    # Exit the context manager entered in get_checkpointer(), then reset
    # all module-level state so a later call can re-initialize.
    await _checkpointer_context.__aexit__(None, None, None)
    _checkpointer = None
    _checkpointer_context = None
    _checkpointer_initialized = False
    print("[Checkpointer] PostgreSQL connection closed")

View file

@ -0,0 +1,28 @@
"""
Context schema definitions for SurfSense agents.
This module defines the custom state schema used by the SurfSense deep agent.
"""
from typing import TypedDict
class SurfSenseContextSchema(TypedDict):
    """
    Custom state schema for the SurfSense deep agent.

    Extends the default agent state (messages, todos from
    TodoListMiddleware, files from FilesystemMiddleware) with the field
    needed for knowledge base search.  The database session and
    connector service are runtime-injected when the agent is invoked and
    are deliberately not part of this serialized schema.
    """

    # ID of the user's search space to query against.
    search_space_id: int

View file

@ -0,0 +1,361 @@
"""
LLM configuration utilities for SurfSense agents.
This module provides functions for loading LLM configurations from:
1. YAML files (global configs with negative IDs)
2. Database NewLLMConfig table (user-created configs with positive IDs)
It also provides utilities for creating ChatLiteLLM instances and
managing prompt configurations.
"""
from dataclasses import dataclass
from pathlib import Path
import yaml
from langchain_litellm import ChatLiteLLM
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
# Provider mapping for LiteLLM model string construction.
# Keys are the uppercase provider names stored in configs; values are the
# LiteLLM route prefix used when building "prefix/model_name" strings.
# Several OpenAI-compatible vendors (DEEPSEEK, ALIBABA_QWEN, MOONSHOT,
# ZHIPU) are deliberately routed through "openai".
PROVIDER_MAP = {
    "OPENAI": "openai",
    "ANTHROPIC": "anthropic",
    "GROQ": "groq",
    "COHERE": "cohere",
    "GOOGLE": "gemini",
    "OLLAMA": "ollama",
    "MISTRAL": "mistral",
    "AZURE_OPENAI": "azure",
    "OPENROUTER": "openrouter",
    "XAI": "xai",
    "BEDROCK": "bedrock",
    "VERTEX_AI": "vertex_ai",
    "TOGETHER_AI": "together_ai",
    "FIREWORKS_AI": "fireworks_ai",
    "DEEPSEEK": "openai",
    "ALIBABA_QWEN": "openai",
    "MOONSHOT": "openai",
    "ZHIPU": "openai",
    "REPLICATE": "replicate",
    "PERPLEXITY": "perplexity",
    "ANYSCALE": "anyscale",
    "DEEPINFRA": "deepinfra",
    "CEREBRAS": "cerebras",
    "SAMBANOVA": "sambanova",
    "AI21": "ai21",
    "CLOUDFLARE": "cloudflare",
    "DATABRICKS": "databricks",
    "COMETAPI": "cometapi",
    "HUGGINGFACE": "huggingface",
    "CUSTOM": "custom",
}
@dataclass
class AgentConfig:
"""
Complete configuration for the SurfSense agent.
This combines LLM settings with prompt configuration from NewLLMConfig.
"""
# LLM Model Settings
provider: str
model_name: str
api_key: str
api_base: str | None = None
custom_provider: str | None = None
litellm_params: dict | None = None
# Prompt Configuration
system_instructions: str | None = None
use_default_system_instructions: bool = True
citations_enabled: bool = True
# Metadata
config_id: int | None = None
config_name: str | None = None
@classmethod
def from_new_llm_config(cls, config) -> "AgentConfig":
"""
Create an AgentConfig from a NewLLMConfig database model.
Args:
config: NewLLMConfig database model instance
Returns:
AgentConfig instance
"""
return cls(
provider=config.provider.value
if hasattr(config.provider, "value")
else str(config.provider),
model_name=config.model_name,
api_key=config.api_key,
api_base=config.api_base,
custom_provider=config.custom_provider,
litellm_params=config.litellm_params,
system_instructions=config.system_instructions,
use_default_system_instructions=config.use_default_system_instructions,
citations_enabled=config.citations_enabled,
config_id=config.id,
config_name=config.name,
)
@classmethod
def from_yaml_config(cls, yaml_config: dict) -> "AgentConfig":
"""
Create an AgentConfig from a YAML configuration dictionary.
YAML configs now support the same prompt configuration fields as NewLLMConfig:
- system_instructions: Custom system instructions (empty string uses defaults)
- use_default_system_instructions: Whether to use default instructions
- citations_enabled: Whether citations are enabled
Args:
yaml_config: Configuration dictionary from YAML file
Returns:
AgentConfig instance
"""
# Get system instructions from YAML, default to empty string
system_instructions = yaml_config.get("system_instructions", "")
return cls(
provider=yaml_config.get("provider", "").upper(),
model_name=yaml_config.get("model_name", ""),
api_key=yaml_config.get("api_key", ""),
api_base=yaml_config.get("api_base"),
custom_provider=yaml_config.get("custom_provider"),
litellm_params=yaml_config.get("litellm_params"),
# Prompt configuration from YAML (with defaults for backwards compatibility)
system_instructions=system_instructions if system_instructions else None,
use_default_system_instructions=yaml_config.get(
"use_default_system_instructions", True
),
citations_enabled=yaml_config.get("citations_enabled", True),
config_id=yaml_config.get("id"),
config_name=yaml_config.get("name"),
)
def load_llm_config_from_yaml(llm_config_id: int = -1) -> dict | None:
    """
    Load a specific LLM config from global_llm_config.yaml.

    Falls back to global_llm_config.example.yaml when the main file is
    absent.  Errors (missing file, bad YAML, unknown id) are reported to
    stdout and yield None.

    Args:
        llm_config_id: The id of the config to load (default: -1)

    Returns:
        LLM config dict or None if not found
    """
    # Resolve the config directory relative to this module.
    base_dir = Path(__file__).resolve().parent.parent.parent.parent
    candidates = (
        base_dir / "app" / "config" / "global_llm_config.yaml",
        base_dir / "app" / "config" / "global_llm_config.example.yaml",
    )
    config_file = next((path for path in candidates if path.exists()), None)
    if config_file is None:
        print("Error: No global_llm_config.yaml or example file found")
        return None

    try:
        with open(config_file, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        for cfg in data.get("global_llm_configs", []):
            if isinstance(cfg, dict) and cfg.get("id") == llm_config_id:
                return cfg
        print(f"Error: Global LLM config id {llm_config_id} not found")
        return None
    except Exception as e:
        print(f"Error loading config: {e}")
        return None
async def load_new_llm_config_from_db(
    session: AsyncSession,
    config_id: int,
) -> "AgentConfig | None":
    """
    Load a NewLLMConfig from the database by ID.

    Args:
        session: AsyncSession for database access
        config_id: The ID of the NewLLMConfig to load

    Returns:
        AgentConfig instance, or None when the row is missing or the
        query fails (the error is reported to stdout).
    """
    # Import here to avoid circular imports
    from app.db import NewLLMConfig

    try:
        query = select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
        db_config = (await session.execute(query)).scalars().first()
        if not db_config:
            print(f"Error: NewLLMConfig with id {config_id} not found")
            return None
        return AgentConfig.from_new_llm_config(db_config)
    except Exception as e:
        print(f"Error loading NewLLMConfig from database: {e}")
        return None
async def load_agent_llm_config_for_search_space(
    session: AsyncSession,
    search_space_id: int,
) -> "AgentConfig | None":
    """
    Load the agent LLM configuration for a search space.

    Resolution is driven by the search space's agent_llm_id:
    - Positive ID: NewLLMConfig database table
    - Negative ID: YAML global configs
    - None: falls back to the first global config (id=-1)

    Args:
        session: AsyncSession for database access
        search_space_id: The search space ID

    Returns:
        AgentConfig instance or None if not found
    """
    # Import here to avoid circular imports
    from app.db import SearchSpace

    try:
        # Look up the search space to read its agent_llm_id preference.
        query = select(SearchSpace).filter(SearchSpace.id == search_space_id)
        search_space = (await session.execute(query)).scalars().first()
        if not search_space:
            print(f"Error: SearchSpace with id {search_space_id} not found")
            return None

        # Missing preference -> default to the first global YAML config.
        preferred_id = search_space.agent_llm_id
        config_id = -1 if preferred_id is None else preferred_id

        # Delegate to the unified loader (YAML vs database).
        return await load_agent_config(session, config_id, search_space_id)
    except Exception as e:
        print(f"Error loading agent LLM config for search space {search_space_id}: {e}")
        return None
async def load_agent_config(
    session: AsyncSession,
    config_id: int,
    search_space_id: int | None = None,
) -> "AgentConfig | None":
    """
    Load an agent configuration from YAML or the database.

    Main entry point for configuration loading:
    - Negative IDs: global configs stored in the YAML file
    - Non-negative IDs: NewLLMConfig rows in the database

    Args:
        session: AsyncSession for database access
        config_id: The config ID (negative for YAML, positive for database)
        search_space_id: Optional search space ID for context

    Returns:
        AgentConfig instance or None if not found
    """
    if config_id >= 0:
        # Database-backed configuration (NewLLMConfig table).
        return await load_new_llm_config_from_db(session, config_id)

    # Global configurations live in YAML and use negative IDs.
    yaml_config = load_llm_config_from_yaml(config_id)
    return AgentConfig.from_yaml_config(yaml_config) if yaml_config else None
def create_chat_litellm_from_config(llm_config: dict) -> ChatLiteLLM | None:
    """
    Create a ChatLiteLLM instance from a global LLM config dictionary.

    Args:
        llm_config: LLM configuration dictionary from YAML

    Returns:
        Configured ChatLiteLLM instance with streaming enabled.
        (Despite the annotation, this function never returns None itself;
        a bad config surfaces as an exception from ChatLiteLLM.)
    """
    # Resolve the "<provider-prefix>/<model>" string litellm expects;
    # an explicit custom provider wins over the mapped provider name.
    custom_provider = llm_config.get("custom_provider")
    if custom_provider:
        model_string = f"{custom_provider}/{llm_config['model_name']}"
    else:
        provider = llm_config.get("provider", "").upper()
        prefix = PROVIDER_MAP.get(provider, provider.lower())
        model_string = f"{prefix}/{llm_config['model_name']}"

    # Streaming is always on so tokens can be relayed in real time.
    litellm_kwargs: dict = {
        "model": model_string,
        "api_key": llm_config.get("api_key"),
        "streaming": True,
    }
    if llm_config.get("api_base"):
        litellm_kwargs["api_base"] = llm_config["api_base"]
    # Any extra litellm parameters are passed through verbatim.
    if llm_config.get("litellm_params"):
        litellm_kwargs.update(llm_config["litellm_params"])

    return ChatLiteLLM(**litellm_kwargs)
def create_chat_litellm_from_agent_config(
    agent_config: AgentConfig,
) -> ChatLiteLLM | None:
    """
    Create a ChatLiteLLM instance from an AgentConfig.

    Args:
        agent_config: AgentConfig instance

    Returns:
        Configured ChatLiteLLM instance with streaming enabled.
        (Despite the annotation, this function never returns None itself;
        a bad config surfaces as an exception from ChatLiteLLM.)
    """
    # Resolve the "<provider-prefix>/<model>" string litellm expects;
    # an explicit custom provider wins over the mapped provider name.
    if agent_config.custom_provider:
        model_string = f"{agent_config.custom_provider}/{agent_config.model_name}"
    else:
        prefix = PROVIDER_MAP.get(
            agent_config.provider, agent_config.provider.lower()
        )
        model_string = f"{prefix}/{agent_config.model_name}"

    # Streaming is always on so tokens can be relayed in real time.
    litellm_kwargs: dict = {
        "model": model_string,
        "api_key": agent_config.api_key,
        "streaming": True,
    }
    if agent_config.api_base:
        litellm_kwargs["api_base"] = agent_config.api_base
    # Any extra litellm parameters are passed through verbatim.
    if agent_config.litellm_params:
        litellm_kwargs.update(agent_config.litellm_params)

    return ChatLiteLLM(**litellm_kwargs)

View file

@ -0,0 +1,346 @@
"""
System prompt building for SurfSense agents.
This module provides functions and constants for building the SurfSense system prompt
with configurable user instructions and citation support.
The prompt is composed of three parts:
1. System Instructions (configurable via NewLLMConfig)
2. Tools Instructions (always included, not configurable)
3. Citation Instructions (toggleable via NewLLMConfig.citations_enabled)
"""
from datetime import UTC, datetime
# Default system instructions - can be overridden via NewLLMConfig.system_instructions
# NOTE: the {resolved_today} placeholder is filled in by the build_* helpers in this
# module via str.format(resolved_today=...).
SURFSENSE_SYSTEM_INSTRUCTIONS = """
<system_instruction>
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
Today's date (UTC): {resolved_today}
</system_instruction>
"""
# Tool usage instructions - always included in the prompt, not user-configurable.
# This is a plain string constant (never passed through str.format), so literal
# braces inside it do not need escaping.
SURFSENSE_TOOLS_INSTRUCTIONS = """
<tools>
You have access to the following tools:
1. search_knowledge_base: Search the user's personal knowledge base for relevant information.
- Args:
- query: The search query - be specific and include key terms
- top_k: Number of results to retrieve (default: 10)
- start_date: Optional ISO date/datetime (e.g. "2025-12-12" or "2025-12-12T00:00:00+00:00")
- end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
- connectors_to_search: Optional list of connector enums to search. If omitted, searches all.
- Returns: Formatted string with relevant documents and their content
2. generate_podcast: Generate an audio podcast from provided content.
- Use this when the user asks to create, generate, or make a podcast.
- Trigger phrases: "give me a podcast about", "create a podcast", "generate a podcast", "make a podcast", "turn this into a podcast"
- Args:
- source_content: The text content to convert into a podcast. This MUST be comprehensive and include:
* If discussing the current conversation: Include a detailed summary of the FULL chat history (all user questions and your responses)
* If based on knowledge base search: Include the key findings and insights from the search results
* You can combine both: conversation context + search results for richer podcasts
* The more detailed the source_content, the better the podcast quality
- podcast_title: Optional title for the podcast (default: "SurfSense Podcast")
- user_prompt: Optional instructions for podcast style/format (e.g., "Make it casual and fun")
- Returns: A task_id for tracking. The podcast will be generated in the background.
- IMPORTANT: Only one podcast can be generated at a time. If a podcast is already being generated, the tool will return status "already_generating".
- After calling this tool, inform the user that podcast generation has started and they will see the player when it's ready (takes 3-5 minutes).
3. link_preview: Fetch metadata for a URL to display a rich preview card.
- IMPORTANT: Use this tool WHENEVER the user shares or mentions a URL/link in their message.
- This fetches the page's Open Graph metadata (title, description, thumbnail) to show a preview card.
- NOTE: This tool only fetches metadata, NOT the full page content. It cannot read the article text.
- Trigger scenarios:
* User shares a URL (e.g., "Check out https://example.com")
* User pastes a link in their message
* User asks about a URL or link
- Args:
- url: The URL to fetch metadata for (must be a valid HTTP/HTTPS URL)
- Returns: A rich preview card with title, description, thumbnail, and domain
- The preview card will automatically be displayed in the chat.
4. display_image: Display an image in the chat with metadata.
- Use this tool when you want to show an image to the user.
- This displays the image with an optional title, description, and source attribution.
- Common use cases:
* Showing an image from a URL mentioned in the conversation
* Displaying a diagram, chart, or illustration you're referencing
* Showing visual examples when explaining concepts
- Args:
- src: The URL of the image to display (must be a valid HTTP/HTTPS image URL)
- alt: Alternative text describing the image (for accessibility)
- title: Optional title to display below the image
- description: Optional description providing context about the image
- Returns: An image card with the image, title, and description
- The image will automatically be displayed in the chat.
5. scrape_webpage: Scrape and extract the main content from a webpage.
- Use this when the user wants you to READ and UNDERSTAND the actual content of a webpage.
- IMPORTANT: This is different from link_preview:
* link_preview: Only fetches metadata (title, description, thumbnail) for display
* scrape_webpage: Actually reads the FULL page content so you can analyze/summarize it
- Trigger scenarios:
* "Read this article and summarize it"
* "What does this page say about X?"
* "Summarize this blog post for me"
* "Tell me the key points from this article"
* "What's in this webpage?"
* "Can you analyze this article?"
- Args:
- url: The URL of the webpage to scrape (must be HTTP/HTTPS)
- max_length: Maximum content length to return (default: 50000 chars)
- Returns: The page title, description, full content (in markdown), word count, and metadata
- After scraping, you will have the full article text and can analyze, summarize, or answer questions about it.
- IMAGES: The scraped content may contain image URLs in markdown format like `![alt text](image_url)`.
* When you find relevant/important images in the scraped content, use the `display_image` tool to show them to the user.
* This makes your response more visual and engaging.
* Prioritize showing: diagrams, charts, infographics, key illustrations, or images that help explain the content.
* Don't show every image - just the most relevant 1-3 images that enhance understanding.
</tools>
<tool_call_examples>
- User: "Fetch all my notes and what's in them?"
- Call: `search_knowledge_base(query="*", top_k=50, connectors_to_search=["NOTE"])`
- User: "What did I discuss on Slack last week about the React migration?"
- Call: `search_knowledge_base(query="React migration", connectors_to_search=["SLACK_CONNECTOR"], start_date="YYYY-MM-DD", end_date="YYYY-MM-DD")`
- User: "Give me a podcast about AI trends based on what we discussed"
- First search for relevant content, then call: `generate_podcast(source_content="Based on our conversation and search results: [detailed summary of chat + search findings]", podcast_title="AI Trends Podcast")`
- User: "Create a podcast summary of this conversation"
- Call: `generate_podcast(source_content="Complete conversation summary:\\n\\nUser asked about [topic 1]:\\n[Your detailed response]\\n\\nUser then asked about [topic 2]:\\n[Your detailed response]\\n\\n[Continue for all exchanges in the conversation]", podcast_title="Conversation Summary")`
- User: "Make a podcast about quantum computing"
- First search: `search_knowledge_base(query="quantum computing")`
- Then: `generate_podcast(source_content="Key insights about quantum computing from the knowledge base:\\n\\n[Comprehensive summary of all relevant search results with key facts, concepts, and findings]", podcast_title="Quantum Computing Explained")`
- User: "Check out https://dev.to/some-article"
- Call: `link_preview(url="https://dev.to/some-article")`
- User: "What's this blog post about? https://example.com/blog/post"
- Call: `link_preview(url="https://example.com/blog/post")`
- User: "https://github.com/some/repo"
- Call: `link_preview(url="https://github.com/some/repo")`
- User: "Show me this image: https://example.com/image.png"
- Call: `display_image(src="https://example.com/image.png", alt="User shared image")`
- User: "Can you display a diagram of a neural network?"
- Call: `display_image(src="https://example.com/neural-network.png", alt="Neural network diagram", title="Neural Network Architecture", description="A visual representation of a neural network with input, hidden, and output layers")`
- User: "Read this article and summarize it for me: https://example.com/blog/ai-trends"
- Call: `scrape_webpage(url="https://example.com/blog/ai-trends")`
- After getting the content, provide a summary based on the scraped text
- User: "What does this page say about machine learning? https://docs.example.com/ml-guide"
- Call: `scrape_webpage(url="https://docs.example.com/ml-guide")`
- Then answer the question using the extracted content
- User: "Summarize this blog post: https://medium.com/some-article"
- Call: `scrape_webpage(url="https://medium.com/some-article")`
- Provide a comprehensive summary of the article content
- User: "Read this tutorial and explain it: https://example.com/ml-tutorial"
- First: `scrape_webpage(url="https://example.com/ml-tutorial")`
- Then, if the content contains useful diagrams/images like `![Neural Network Diagram](https://example.com/nn-diagram.png)`:
- Call: `display_image(src="https://example.com/nn-diagram.png", alt="Neural Network Diagram", title="Neural Network Architecture")`
- Then provide your explanation, referencing the displayed image
</tool_call_examples>
"""
# Citation instructions - appended when NewLLMConfig.citations_enabled is True.
# The doubled braces in the metadata_json example below are literal "{"/"}" pairs
# as written in the source; this constant is concatenated, not str.format()-ed.
SURFSENSE_CITATION_INSTRUCTIONS = """
<citation_instructions>
CRITICAL CITATION REQUIREMENTS:
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `<chunk id='...'>` tag inside `<document_content>`.
2. Make sure ALL factual statements from the documents have proper citations.
3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2].
4. You MUST use the exact chunk_id values from the `<chunk id='...'>` attributes. Do not create your own citation numbers.
5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value.
6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags.
7. Do not return citations as clickable links.
8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only.
9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting.
10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `<chunk id='...'>` tags.
11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up.
<document_structure_example>
The documents you receive are structured like this:
<document>
<document_metadata>
<document_id>42</document_id>
<document_type>GITHUB_CONNECTOR</document_type>
<title><![CDATA[Some repo / file / issue title]]></title>
<url><![CDATA[https://example.com]]></url>
<metadata_json><![CDATA[{{"any":"other metadata"}}]]></metadata_json>
</document_metadata>
<document_content>
<chunk id='123'><![CDATA[First chunk text...]]></chunk>
<chunk id='124'><![CDATA[Second chunk text...]]></chunk>
</document_content>
</document>
IMPORTANT: You MUST cite using the chunk ids (e.g. 123, 124). Do NOT cite document_id.
</document_structure_example>
<citation_format>
- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `<chunk id='...'>` tag
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
- No need to return references section. Just citations in answer.
- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format
- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only
- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess
</citation_format>
<citation_examples>
CORRECT citation formats:
- [citation:5]
- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
INCORRECT citation formats (DO NOT use):
- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([citation:5])
- Using hyperlinked text: [link to source 5](https://example.com)
- Using footnote style: ... library¹
- Making up source IDs when source_id is unknown
- Using old IEEE format: [1], [2], [3]
- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5]
</citation_examples>
<citation_output_example>
Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5].
The key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:12]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources.
However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead.
</citation_output_example>
</citation_instructions>
"""
# Anti-citation prompt - used when citations are disabled
# This explicitly tells the model NOT to include citations
# (appended instead of SURFSENSE_CITATION_INSTRUCTIONS when
# NewLLMConfig.citations_enabled is False).
SURFSENSE_NO_CITATION_INSTRUCTIONS = """
<citation_instructions>
IMPORTANT: Citations are DISABLED for this configuration.
DO NOT include any citations in your responses. Specifically:
1. Do NOT use the [citation:chunk_id] format anywhere in your response.
2. Do NOT reference document IDs, chunk IDs, or source IDs.
3. Simply provide the information naturally without any citation markers.
4. Write your response as if you're having a normal conversation, incorporating the information from your knowledge seamlessly.
When answering questions based on documents from the knowledge base:
- Present the information directly and confidently
- Do not mention that information comes from specific documents or chunks
- Integrate facts naturally into your response without attribution markers
Your goal is to provide helpful, informative answers in a clean, readable format without any citation notation.
</citation_instructions>
"""
def build_surfsense_system_prompt(
    today: datetime | None = None,
) -> str:
    """
    Build the SurfSense system prompt with default settings.

    Convenience wrapper that assembles, in order:
    1. Default system instructions (with today's UTC date substituted)
    2. Tools instructions (always included)
    3. Citation instructions (enabled)

    Args:
        today: Optional datetime for today's date (defaults to current UTC date)

    Returns:
        Complete system prompt string
    """
    current = today if today is not None else datetime.now(UTC)
    date_str = current.astimezone(UTC).date().isoformat()
    sections = [
        SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=date_str),
        SURFSENSE_TOOLS_INSTRUCTIONS,
        SURFSENSE_CITATION_INSTRUCTIONS,
    ]
    return "".join(sections)
def build_configurable_system_prompt(
    custom_system_instructions: str | None = None,
    use_default_system_instructions: bool = True,
    citations_enabled: bool = True,
    today: datetime | None = None,
) -> str:
    """
    Build a configurable SurfSense system prompt based on NewLLMConfig settings.

    The prompt is assembled from three parts:
    1. System instructions - custom text or default SURFSENSE_SYSTEM_INSTRUCTIONS
    2. Tools instructions - always included (SURFSENSE_TOOLS_INSTRUCTIONS)
    3. Citation instructions - pro- or anti-citation, per ``citations_enabled``

    Args:
        custom_system_instructions: Custom system instructions. When empty/None
            and use_default_system_instructions is True, the default
            SURFSENSE_SYSTEM_INSTRUCTIONS is used instead.
        use_default_system_instructions: Whether to fall back to the default
            instructions when custom_system_instructions is empty/None.
        citations_enabled: True to append citation instructions, False to
            append the anti-citation instructions.
        today: Optional datetime for today's date (defaults to current UTC date)

    Returns:
        Complete system prompt string
    """
    current = today if today is not None else datetime.now(UTC)
    resolved_today = current.astimezone(UTC).date().isoformat()

    # Part 1: system instructions (custom > default > empty).
    if custom_system_instructions and custom_system_instructions.strip():
        # Custom text may carry its own {resolved_today} placeholder.
        head = custom_system_instructions.format(resolved_today=resolved_today)
    elif use_default_system_instructions:
        head = SURFSENSE_SYSTEM_INSTRUCTIONS.format(resolved_today=resolved_today)
    else:
        # Edge case: caller explicitly opted out of any system instructions.
        head = ""

    # Part 3: citation behavior toggle.
    citation_part = (
        SURFSENSE_CITATION_INSTRUCTIONS
        if citations_enabled
        else SURFSENSE_NO_CITATION_INSTRUCTIONS
    )

    # Part 2 (tools) always sits between the other two.
    return head + SURFSENSE_TOOLS_INSTRUCTIONS + citation_part
def get_default_system_instructions() -> str:
    """
    Get the default system instructions template.

    This is useful for populating the UI with the default value when
    creating a new NewLLMConfig. The returned text still contains the
    {resolved_today} placeholder; it is NOT substituted here.

    Returns:
        Default system instructions string (with {resolved_today} placeholder),
        stripped of surrounding whitespace
    """
    return SURFSENSE_SYSTEM_INSTRUCTIONS.strip()
# Module-level default prompt (default instructions + tools + citations enabled),
# built once at import time; the embedded date is frozen at process start.
SURFSENSE_SYSTEM_PROMPT = build_surfsense_system_prompt()

View file

@ -0,0 +1,52 @@
"""
Tools module for SurfSense deep agent.
This module contains all the tools available to the SurfSense agent.
To add a new tool, see the documentation in registry.py.
Available tools:
- search_knowledge_base: Search the user's personal knowledge base
- generate_podcast: Generate audio podcasts from content
- link_preview: Fetch rich previews for URLs
- display_image: Display images in chat
- scrape_webpage: Extract content from webpages
"""
# Registry exports
# Tool factory exports (for direct use)
from .display_image import create_display_image_tool
from .knowledge_base import (
create_search_knowledge_base_tool,
format_documents_for_context,
search_knowledge_base_async,
)
from .link_preview import create_link_preview_tool
from .podcast import create_generate_podcast_tool
from .registry import (
BUILTIN_TOOLS,
ToolDefinition,
build_tools,
get_all_tool_names,
get_default_enabled_tools,
get_tool_by_name,
)
from .scrape_webpage import create_scrape_webpage_tool
# Public API of the tools package; groups mirror the import sections above.
__all__ = [
    # Registry
    "BUILTIN_TOOLS",
    "ToolDefinition",
    "build_tools",
    # Tool factories
    "create_display_image_tool",
    "create_generate_podcast_tool",
    "create_link_preview_tool",
    "create_scrape_webpage_tool",
    "create_search_knowledge_base_tool",
    # Knowledge base utilities
    "format_documents_for_context",
    "get_all_tool_names",
    "get_default_enabled_tools",
    "get_tool_by_name",
    "search_knowledge_base_async",
]

View file

@ -0,0 +1,105 @@
"""
Display image tool for the SurfSense agent.
This module provides a tool for displaying images in the chat UI
with metadata like title, description, and source attribution.
"""
import hashlib
from typing import Any
from urllib.parse import urlparse
from langchain_core.tools import tool
def extract_domain(url: str) -> str:
    """
    Extract the host portion (minus a leading ``www.``) from a URL.

    Also handles bare host inputs without a scheme (e.g. ``"example.com/x"``),
    which ``urlparse`` would otherwise treat as a path and for which the
    previous implementation returned ``""``.

    Args:
        url: Absolute URL or bare host string.

    Returns:
        The domain (e.g. ``"example.com"``), or ``""`` when no host can be
        determined.
    """
    try:
        parsed = urlparse(url)
        # A bare "example.com/path" has no netloc; re-parse it as a
        # network-path reference ("//example.com/path") to recover the host.
        if not parsed.netloc and "://" not in url:
            parsed = urlparse(f"//{url}")
        domain = parsed.netloc
        # Remove 'www.' prefix if present
        if domain.startswith("www."):
            domain = domain[4:]
        return domain
    except Exception:
        return ""
def generate_image_id(src: str) -> str:
    """Derive a stable, short identifier for an image from its source URL.

    Uses the first 12 hex chars of the MD5 digest (non-cryptographic use:
    deduplication/identity only).
    """
    digest = hashlib.md5(src.encode()).hexdigest()
    return "image-" + digest[:12]
def create_display_image_tool():
    """
    Factory function to create the display_image tool.

    Returns:
        A configured tool function for displaying images.
    """

    @tool
    async def display_image(
        src: str,
        alt: str = "Image",
        title: str | None = None,
        description: str | None = None,
    ) -> dict[str, Any]:
        """
        Display an image in the chat with metadata.

        Use this tool when you want to show an image to the user.
        This displays the image with an optional title, description,
        and source attribution.

        Common use cases:
        - Showing an image from a URL the user mentioned
        - Displaying a diagram or chart you're referencing
        - Showing example images when explaining concepts

        Args:
            src: The URL of the image to display (must be a valid HTTP/HTTPS URL)
            alt: Alternative text describing the image (for accessibility)
            title: Optional title to display below the image
            description: Optional description providing context about the image

        Returns:
            A dictionary containing image metadata for the UI to render:
            - id: Unique identifier for this image
            - assetId: The image URL (for deduplication)
            - src: The image URL
            - alt: Alt text for accessibility
            - title: Image title (if provided)
            - description: Image description (if provided)
            - domain: Source domain
            - ratio: Suggested aspect ratio for the UI ("16:9" or "auto")
        """
        # FIX: normalize the URL *before* deriving the id. Previously the id was
        # computed from the raw input, so "example.com/a.png" and
        # "https://example.com/a.png" produced different ids for the same asset
        # even though assetId (the documented dedup key) was normalized.
        if not src.startswith(("http://", "https://")):
            src = f"https://{src}"

        image_id = generate_image_id(src)
        domain = extract_domain(src)

        # Aspect-ratio heuristic keyed on common image hosts; "auto" lets the
        # UI size the image from its intrinsic dimensions.
        ratio = "16:9"  # Default
        if "unsplash.com" in src or "pexels.com" in src:
            ratio = "16:9"
        elif (
            "imgur.com" in src or "github.com" in src or "githubusercontent.com" in src
        ):
            ratio = "auto"

        return {
            "id": image_id,
            "assetId": src,
            "src": src,
            "alt": alt,
            "title": title,
            "description": description,
            "domain": domain,
            "ratio": ratio,
        }

    return display_image

View file

@ -0,0 +1,607 @@
"""
Knowledge base search tool for the SurfSense agent.
This module provides:
- Connector constants and normalization
- Async knowledge base search across multiple connectors
- Document formatting for LLM context
- Tool factory for creating search_knowledge_base tools
"""
import json
from datetime import datetime
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.connector_service import ConnectorService
# =============================================================================
# Connector Constants and Normalization
# =============================================================================
# Canonical connector values used internally by ConnectorService
_ALL_CONNECTORS: list[str] = [
"EXTENSION",
"FILE",
"SLACK_CONNECTOR",
"NOTION_CONNECTOR",
"YOUTUBE_VIDEO",
"GITHUB_CONNECTOR",
"ELASTICSEARCH_CONNECTOR",
"LINEAR_CONNECTOR",
"JIRA_CONNECTOR",
"CONFLUENCE_CONNECTOR",
"CLICKUP_CONNECTOR",
"GOOGLE_CALENDAR_CONNECTOR",
"GOOGLE_GMAIL_CONNECTOR",
"DISCORD_CONNECTOR",
"AIRTABLE_CONNECTOR",
"TAVILY_API",
"SEARXNG_API",
"LINKUP_API",
"BAIDU_SEARCH_API",
"LUMA_CONNECTOR",
"NOTE",
"BOOKSTACK_CONNECTOR",
"CRAWLED_URL",
]
def _normalize_connectors(connectors_to_search: list[str] | None) -> list[str]:
"""
Normalize connectors provided by the model.
- Accepts user-facing enums like WEBCRAWLER_CONNECTOR and maps them to canonical
ConnectorService types.
- Drops unknown values.
- If None/empty, defaults to searching across all known connectors.
"""
if not connectors_to_search:
return list(_ALL_CONNECTORS)
normalized: list[str] = []
for raw in connectors_to_search:
c = (raw or "").strip().upper()
if not c:
continue
if c == "WEBCRAWLER_CONNECTOR":
c = "CRAWLED_URL"
normalized.append(c)
# de-dupe while preserving order + filter unknown
seen: set[str] = set()
out: list[str] = []
for c in normalized:
if c in seen:
continue
if c not in _ALL_CONNECTORS:
continue
seen.add(c)
out.append(c)
return out if out else list(_ALL_CONNECTORS)
# =============================================================================
# Document Formatting
# =============================================================================
def format_documents_for_context(documents: list[dict[str, Any]]) -> str:
    """
    Format retrieved documents into a readable context string for the LLM.

    Groups incoming results by document and renders the XML structure that the
    citation instructions describe: <document> / <document_metadata> /
    <document_content> with one <chunk id='...'> per chunk.

    Args:
        documents: List of document dictionaries from connector search

    Returns:
        Formatted string with document contents and metadata ("" when empty)
    """
    if not documents:
        return ""
    # Group chunks by document id (preferred) to produce the XML structure.
    #
    # IMPORTANT: ConnectorService returns **document-grouped** results of the form:
    #     {
    #         "document": {...},
    #         "chunks": [{"chunk_id": 123, "content": "..."}, ...],
    #         "source": "NOTION_CONNECTOR" | "FILE" | ...
    #     }
    #
    # We must preserve chunk_id so citations like [citation:123] are possible.
    grouped: dict[str, dict[str, Any]] = {}
    for doc in documents:
        # Defensive extraction: every access tolerates non-dict shapes and
        # missing keys, since result shapes vary across connectors.
        document_info = (doc.get("document") or {}) if isinstance(doc, dict) else {}
        metadata = (
            (document_info.get("metadata") or {})
            if isinstance(document_info, dict)
            else {}
        )
        if not metadata and isinstance(doc, dict):
            # Some result shapes may place metadata at the top level.
            metadata = doc.get("metadata") or {}
        source = (
            (doc.get("source") if isinstance(doc, dict) else None)
            or metadata.get("document_type")
            or "UNKNOWN"
        )
        # Document identity (prefer document_id; otherwise fall back to type+title+url)
        document_id_val = document_info.get("id")
        title = (
            document_info.get("title") or metadata.get("title") or "Untitled Document"
        )
        url = (
            metadata.get("url")
            or metadata.get("source")
            or metadata.get("page_url")
            or ""
        )
        # Dedup key: stringified id when present, else a composite fallback.
        doc_key = (
            str(document_id_val)
            if document_id_val is not None
            else f"{source}::{title}::{url}"
        )
        # First occurrence of a document wins for metadata; later entries only
        # contribute chunks.
        if doc_key not in grouped:
            grouped[doc_key] = {
                "document_id": document_id_val
                if document_id_val is not None
                else doc_key,
                "document_type": metadata.get("document_type") or source,
                "title": title,
                "url": url,
                "metadata": metadata,
                "chunks": [],
            }
        # Prefer document-grouped chunks if available
        chunks_list = doc.get("chunks") if isinstance(doc, dict) else None
        if isinstance(chunks_list, list) and chunks_list:
            for ch in chunks_list:
                if not isinstance(ch, dict):
                    continue
                # NOTE: `or` means a falsy chunk_id (e.g. 0) falls through to "id".
                chunk_id = ch.get("chunk_id") or ch.get("id")
                content = (ch.get("content") or "").strip()
                if not content:
                    continue
                grouped[doc_key]["chunks"].append(
                    {"chunk_id": chunk_id, "content": content}
                )
            continue
        # Fallback: treat this as a flat chunk-like object
        if not isinstance(doc, dict):
            continue
        chunk_id = doc.get("chunk_id") or doc.get("id")
        content = (doc.get("content") or "").strip()
        if not content:
            continue
        grouped[doc_key]["chunks"].append({"chunk_id": chunk_id, "content": content})
    # Render XML expected by citation instructions
    # NOTE(review): content is wrapped in CDATA without escaping - text that
    # itself contains "]]>" would break the structure; confirm upstream data
    # cannot contain that sequence.
    parts: list[str] = []
    for g in grouped.values():
        metadata_json = json.dumps(g["metadata"], ensure_ascii=False)
        parts.append("<document>")
        parts.append("<document_metadata>")
        parts.append(f"  <document_id>{g['document_id']}</document_id>")
        parts.append(f"  <document_type>{g['document_type']}</document_type>")
        parts.append(f"  <title><![CDATA[{g['title']}]]></title>")
        parts.append(f"  <url><![CDATA[{g['url']}]]></url>")
        parts.append(f"  <metadata_json><![CDATA[{metadata_json}]]></metadata_json>")
        parts.append("</document_metadata>")
        parts.append("")
        parts.append("<document_content>")
        for ch in g["chunks"]:
            ch_content = ch["content"]
            ch_id = ch["chunk_id"]
            if ch_id is None:
                # Chunk without an id: still included, but uncitable.
                parts.append(f"  <chunk><![CDATA[{ch_content}]]></chunk>")
            else:
                parts.append(f"  <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>")
        parts.append("</document_content>")
        parts.append("</document>")
        parts.append("")
    return "\n".join(parts).strip()
# =============================================================================
# Knowledge Base Search
# =============================================================================
async def search_knowledge_base_async(
    query: str,
    search_space_id: int,
    db_session: AsyncSession,
    connector_service: ConnectorService,
    connectors_to_search: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> str:
    """
    Search the user's knowledge base for relevant documents.
    This is the async implementation that searches across multiple connectors.
    Args:
        query: The search query
        search_space_id: The user's search space ID
        db_session: Database session
        connector_service: Initialized connector service
        connectors_to_search: Optional list of connector types to search. If omitted, searches all.
        top_k: Number of results per connector
        start_date: Optional start datetime (UTC) for filtering documents
        end_date: Optional end datetime (UTC) for filtering documents
    Returns:
        Formatted string with search results
    """
    # Dispatch table replacing the previous ~200-line if/elif chain.
    # Maps each connector enum to its ConnectorService method name and call style:
    #   "dated"  -> takes top_k plus start_date/end_date filters
    #   "simple" -> takes top_k only (external search APIs without date filters)
    #   "linkup" -> special case: takes mode="standard" instead of top_k
    #               (keep behavior aligned with researcher: default "standard")
    dispatch: dict[str, tuple[str, str]] = {
        "YOUTUBE_VIDEO": ("search_youtube", "dated"),
        "EXTENSION": ("search_extension", "dated"),
        "CRAWLED_URL": ("search_crawled_urls", "dated"),
        "FILE": ("search_files", "dated"),
        "SLACK_CONNECTOR": ("search_slack", "dated"),
        "NOTION_CONNECTOR": ("search_notion", "dated"),
        "GITHUB_CONNECTOR": ("search_github", "dated"),
        "LINEAR_CONNECTOR": ("search_linear", "dated"),
        "TAVILY_API": ("search_tavily", "simple"),
        "SEARXNG_API": ("search_searxng", "simple"),
        "LINKUP_API": ("search_linkup", "linkup"),
        "BAIDU_SEARCH_API": ("search_baidu", "simple"),
        "DISCORD_CONNECTOR": ("search_discord", "dated"),
        "JIRA_CONNECTOR": ("search_jira", "dated"),
        "GOOGLE_CALENDAR_CONNECTOR": ("search_google_calendar", "dated"),
        "AIRTABLE_CONNECTOR": ("search_airtable", "dated"),
        "GOOGLE_GMAIL_CONNECTOR": ("search_google_gmail", "dated"),
        "CONFLUENCE_CONNECTOR": ("search_confluence", "dated"),
        "CLICKUP_CONNECTOR": ("search_clickup", "dated"),
        "LUMA_CONNECTOR": ("search_luma", "dated"),
        "ELASTICSEARCH_CONNECTOR": ("search_elasticsearch", "dated"),
        "NOTE": ("search_notes", "dated"),
        "BOOKSTACK_CONNECTOR": ("search_bookstack", "dated"),
    }
    all_documents = []
    # Resolve date range (default last 2 years)
    from app.agents.new_chat.utils import resolve_date_range
    resolved_start_date, resolved_end_date = resolve_date_range(
        start_date=start_date,
        end_date=end_date,
    )
    connectors = _normalize_connectors(connectors_to_search)
    for connector in connectors:
        entry = dispatch.get(connector)
        if entry is None:
            # Unknown enum: skip silently, matching the old if/elif fallthrough.
            continue
        method_name, style = entry
        try:
            search_method = getattr(connector_service, method_name)
            if style == "dated":
                _, chunks = await search_method(
                    user_query=query,
                    search_space_id=search_space_id,
                    top_k=top_k,
                    start_date=resolved_start_date,
                    end_date=resolved_end_date,
                )
            elif style == "linkup":
                _, chunks = await search_method(
                    user_query=query,
                    search_space_id=search_space_id,
                    mode="standard",
                )
            else:  # "simple"
                _, chunks = await search_method(
                    user_query=query,
                    search_space_id=search_space_id,
                    top_k=top_k,
                )
            all_documents.extend(chunks)
        except Exception as e:
            # Best-effort aggregation: one failing connector must not abort
            # the whole search.
            print(f"Error searching connector {connector}: {e}")
            continue
    # Deduplicate by document id and content hash. hash() is stable within a
    # single process run, which is all this per-request dedup needs.
    seen_doc_ids: set[Any] = set()
    seen_hashes: set[int] = set()
    deduplicated: list[dict[str, Any]] = []
    for doc in all_documents:
        doc_id = (doc.get("document", {}) or {}).get("id")
        content = (doc.get("content", "") or "").strip()
        content_hash = hash(content)
        if (doc_id and doc_id in seen_doc_ids) or content_hash in seen_hashes:
            continue
        if doc_id:
            seen_doc_ids.add(doc_id)
        seen_hashes.add(content_hash)
        deduplicated.append(doc)
    return format_documents_for_context(deduplicated)
def create_search_knowledge_base_tool(
    search_space_id: int,
    db_session: AsyncSession,
    connector_service: ConnectorService,
):
    """
    Factory function to create the search_knowledge_base tool with injected dependencies.
    Args:
        search_space_id: The user's search space ID
        db_session: Database session
        connector_service: Initialized connector service
    Returns:
        A configured tool function
    """
    # NOTE: the docstring of `search_knowledge_base` below doubles as the tool
    # description shown to the LLM — wording changes alter agent behavior, so
    # edit it deliberately, not cosmetically.
    @tool
    async def search_knowledge_base(
        query: str,
        top_k: int = 10,
        start_date: str | None = None,
        end_date: str | None = None,
        connectors_to_search: list[str] | None = None,
    ) -> str:
        """
        Search the user's personal knowledge base for relevant information.
        Use this tool to find documents, notes, files, web pages, and other content
        that may help answer the user's question.
        IMPORTANT:
        - If the user requests a specific source type (e.g. "my notes", "Slack messages"),
        pass `connectors_to_search=[...]` using the enums below.
        - If `connectors_to_search` is omitted/empty, the system will search broadly.
        ## Available connector enums for `connectors_to_search`
        - EXTENSION: "Web content saved via SurfSense browser extension" (personal browsing history)
        - FILE: "User-uploaded documents (PDFs, Word, etc.)" (personal files)
        - NOTE: "SurfSense Notes" (notes created inside SurfSense)
        - SLACK_CONNECTOR: "Slack conversations and shared content" (personal workspace communications)
        - NOTION_CONNECTOR: "Notion workspace pages and databases" (personal knowledge management)
        - YOUTUBE_VIDEO: "YouTube video transcripts and metadata" (personally saved videos)
        - GITHUB_CONNECTOR: "GitHub repository content and issues" (personal repositories and interactions)
        - ELASTICSEARCH_CONNECTOR: "Elasticsearch indexed documents and data" (personal Elasticsearch instances and custom data sources)
        - LINEAR_CONNECTOR: "Linear project issues and discussions" (personal project management)
        - JIRA_CONNECTOR: "Jira project issues, tickets, and comments" (personal project tracking)
        - CONFLUENCE_CONNECTOR: "Confluence pages and comments" (personal project documentation)
        - CLICKUP_CONNECTOR: "ClickUp tasks and project data" (personal task management)
        - GOOGLE_CALENDAR_CONNECTOR: "Google Calendar events, meetings, and schedules" (personal calendar and time management)
        - GOOGLE_GMAIL_CONNECTOR: "Google Gmail emails and conversations" (personal emails and communications)
        - DISCORD_CONNECTOR: "Discord server conversations and shared content" (personal community communications)
        - AIRTABLE_CONNECTOR: "Airtable records, tables, and database content" (personal data management and organization)
        - TAVILY_API: "Tavily search API results" (personalized search results)
        - SEARXNG_API: "SearxNG search API results" (personalized search results)
        - LINKUP_API: "Linkup search API results" (personalized search results)
        - BAIDU_SEARCH_API: "Baidu search API results" (personalized search results)
        - LUMA_CONNECTOR: "Luma events"
        - WEBCRAWLER_CONNECTOR: "Webpages indexed by SurfSense" (personally selected websites)
        - BOOKSTACK_CONNECTOR: "BookStack pages" (personal documentation)
        NOTE: `WEBCRAWLER_CONNECTOR` is mapped internally to the canonical document type `CRAWLED_URL`.
        Args:
            query: The search query - be specific and include key terms
            top_k: Number of results to retrieve (default: 10)
            start_date: Optional ISO date/datetime (e.g. "2025-12-12" or "2025-12-12T00:00:00+00:00")
            end_date: Optional ISO date/datetime (e.g. "2025-12-19" or "2025-12-19T23:59:59+00:00")
            connectors_to_search: Optional list of connector enums to search. If omitted, searches all.
        Returns:
            Formatted string with relevant documents and their content
        """
        from app.agents.new_chat.utils import parse_date_or_datetime
        # Dates arrive from the model as ISO strings; convert them to datetime
        # (or leave None) before delegating to the async implementation.
        parsed_start: datetime | None = None
        parsed_end: datetime | None = None
        if start_date:
            parsed_start = parse_date_or_datetime(start_date)
        if end_date:
            parsed_end = parse_date_or_datetime(end_date)
        # All heavy lifting (connector fan-out, dedup, formatting) happens in
        # search_knowledge_base_async; this closure only binds dependencies.
        return await search_knowledge_base_async(
            query=query,
            search_space_id=search_space_id,
            db_session=db_session,
            connector_service=connector_service,
            connectors_to_search=connectors_to_search,
            top_k=top_k,
            start_date=parsed_start,
            end_date=parsed_end,
        )
    return search_knowledge_base

View file

@ -0,0 +1,295 @@
"""
Link preview tool for the SurfSense agent.
This module provides a tool for fetching URL metadata (title, description,
Open Graph image, etc.) to display rich link previews in the chat UI.
"""
import hashlib
import re
from typing import Any
from urllib.parse import urlparse
import httpx
from langchain_core.tools import tool
def extract_domain(url: str) -> str:
    """Return the host portion of *url* without a leading ``www.``; '' on failure."""
    try:
        host = urlparse(url).netloc
    except Exception:
        return ""
    # Strip the common 'www.' prefix so previews show a clean domain.
    return host[4:] if host.startswith("www.") else host
def extract_og_content(html: str, property_name: str) -> str | None:
"""Extract Open Graph meta content from HTML."""
# Try og:property first
pattern = rf'<meta[^>]+property=["\']og:{property_name}["\'][^>]+content=["\']([^"\']+)["\']'
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
# Try content before property
pattern = rf'<meta[^>]+content=["\']([^"\']+)["\'][^>]+property=["\']og:{property_name}["\']'
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
return None
def extract_twitter_content(html: str, name: str) -> str | None:
"""Extract Twitter Card meta content from HTML."""
pattern = (
rf'<meta[^>]+name=["\']twitter:{name}["\'][^>]+content=["\']([^"\']+)["\']'
)
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
# Try content before name
pattern = (
rf'<meta[^>]+content=["\']([^"\']+)["\'][^>]+name=["\']twitter:{name}["\']'
)
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
return None
def extract_meta_description(html: str) -> str | None:
"""Extract meta description from HTML."""
pattern = r'<meta[^>]+name=["\']description["\'][^>]+content=["\']([^"\']+)["\']'
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
# Try content before name
pattern = r'<meta[^>]+content=["\']([^"\']+)["\'][^>]+name=["\']description["\']'
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1)
return None
def extract_title(html: str) -> str | None:
    """Best-effort page title: og:title, then twitter:title, then the <title> tag."""
    for candidate in (
        extract_og_content(html, "title"),
        extract_twitter_content(html, "title"),
    ):
        if candidate:
            return candidate
    tag = re.search(r"<title[^>]*>([^<]+)</title>", html, re.IGNORECASE)
    return tag.group(1).strip() if tag else None
def extract_description(html: str) -> str | None:
    """Best-effort description: og:description, twitter:description, then <meta name="description">."""
    # Each extractor returns a non-empty string or None, so `or`-chaining
    # picks the first available source in priority order.
    return (
        extract_og_content(html, "description")
        or extract_twitter_content(html, "description")
        or extract_meta_description(html)
    )
def extract_image(html: str) -> str | None:
    """Best-effort preview image URL: og:image first, then twitter:image."""
    # Each extractor returns a non-empty string or None, so `or`-chaining
    # yields the first hit (or None when neither tag is present).
    return extract_og_content(html, "image") or extract_twitter_content(html, "image")
def generate_preview_id(url: str) -> str:
    """Stable short ID for a link preview: first 12 hex chars of the URL's MD5."""
    digest = hashlib.md5(url.encode()).hexdigest()
    return "link-preview-" + digest[:12]
def create_link_preview_tool():
    """
    Factory function to create the link_preview tool.
    Returns:
        A configured tool function for fetching link previews.
    """
    @tool
    async def link_preview(url: str) -> dict[str, Any]:
        """
        Fetch metadata for a URL to display a rich link preview.
        Use this tool when the user shares a URL or asks about a specific webpage.
        This tool fetches the page's Open Graph metadata (title, description, image)
        to display a nice preview card in the chat.
        Common triggers include:
        - User shares a URL in the chat
        - User asks "What's this link about?" or similar
        - User says "Show me a preview of this page"
        - User wants to preview an article or webpage
        Args:
            url: The URL to fetch metadata for. Must be a valid HTTP/HTTPS URL.
        Returns:
            A dictionary containing:
            - id: Unique identifier for this preview
            - assetId: The URL itself (for deduplication)
            - kind: "link" (type of media card)
            - href: The URL to open when clicked
            - title: Page title
            - description: Page description (if available)
            - thumb: Thumbnail/preview image URL (if available)
            - domain: The domain name
            - error: Error message (if fetch failed)
        """
        # BUGFIX: normalize the URL *before* deriving the preview id and domain.
        # Previously these were computed first, so a scheme-less input like
        # "example.com" produced an empty netloc from urlparse and every
        # domain-based fallback below degraded to "".
        if not url.startswith(("http://", "https://")):
            url = f"https://{url}"
        preview_id = generate_preview_id(url)
        domain = extract_domain(url)

        def _unescape(text: str) -> str:
            # Minimal HTML-entity decoding for the handful of entities that
            # commonly appear in meta-tag content.
            return (
                text.replace("&amp;", "&")
                .replace("&lt;", "<")
                .replace("&gt;", ">")
                .replace("&quot;", '"')
                .replace("&#39;", "'")
                .replace("&apos;", "'")
            )

        try:
            async with httpx.AsyncClient(
                timeout=10.0,
                follow_redirects=True,
                headers={
                    "User-Agent": "Mozilla/5.0 (compatible; SurfSenseBot/1.0; +https://surfsense.net)",
                    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
                    "Accept-Language": "en-US,en;q=0.5",
                },
            ) as client:
                response = await client.get(url)
                response.raise_for_status()
                # Only HTML pages carry the meta tags we parse; anything else
                # (PDF, image, ...) gets a basic card built from the URL.
                content_type = response.headers.get("content-type", "")
                if "text/html" not in content_type.lower():
                    return {
                        "id": preview_id,
                        "assetId": url,
                        "kind": "link",
                        "href": url,
                        "title": url.split("/")[-1] or domain,
                        "description": f"File from {domain}",
                        "domain": domain,
                    }
                html = response.text
                # Extract metadata (og:* first, twitter:* and plain tags as fallback).
                title = extract_title(html) or domain
                description = extract_description(html)
                image = extract_image(html)
                # Resolve protocol-relative and root-relative image URLs to absolute.
                if image and not image.startswith(("http://", "https://")):
                    if image.startswith("//"):
                        image = f"https:{image}"
                    elif image.startswith("/"):
                        parsed = urlparse(url)
                        image = f"{parsed.scheme}://{parsed.netloc}{image}"
                # Clean up title and description (unescape HTML entities).
                if title:
                    title = _unescape(title)
                if description:
                    description = _unescape(description)
                    # Keep preview cards compact.
                    if len(description) > 200:
                        description = description[:197] + "..."
                return {
                    "id": preview_id,
                    "assetId": url,
                    "kind": "link",
                    "href": url,
                    "title": title,
                    "description": description,
                    "thumb": image,
                    "domain": domain,
                }
        except httpx.TimeoutException:
            return {
                "id": preview_id,
                "assetId": url,
                "kind": "link",
                "href": url,
                "title": domain or "Link",
                "domain": domain,
                "error": "Request timed out",
            }
        except httpx.HTTPStatusError as e:
            return {
                "id": preview_id,
                "assetId": url,
                "kind": "link",
                "href": url,
                "title": domain or "Link",
                "domain": domain,
                "error": f"HTTP {e.response.status_code}",
            }
        except Exception as e:
            error_message = str(e)
            print(f"[link_preview] Error fetching {url}: {error_message}")
            return {
                "id": preview_id,
                "assetId": url,
                "kind": "link",
                "href": url,
                "title": domain or "Link",
                "domain": domain,
                "error": f"Failed to fetch: {error_message[:50]}",
            }
    return link_preview

View file

@ -0,0 +1,173 @@
"""
Podcast generation tool for the SurfSense agent.
This module provides a factory function for creating the generate_podcast tool
that submits a Celery task for background podcast generation. The frontend
polls for completion and auto-updates when the podcast is ready.
Duplicate request prevention:
- Only one podcast can be generated at a time per search space
- Uses Redis to track active podcast tasks
- Returns a friendly message if a podcast is already being generated
"""
import os
from typing import Any
import redis
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
# Redis connection for tracking active podcast tasks
# Uses the same Redis instance as Celery
REDIS_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
# Lazily-created singleton connection; populated on first use by get_redis_client().
_redis_client: redis.Redis | None = None
def get_redis_client() -> redis.Redis:
    """Return the module-wide Redis client, creating it on first use."""
    global _redis_client
    client = _redis_client
    if client is None:
        # decode_responses=True so get() returns str task ids, not bytes.
        client = redis.from_url(REDIS_URL, decode_responses=True)
        _redis_client = client
    return client
def get_active_podcast_key(search_space_id: int) -> str:
    """Redis key under which the active podcast task for a search space is tracked."""
    return "podcast:active:" + str(search_space_id)
def get_active_podcast_task(search_space_id: int) -> str | None:
    """Return the active podcast task id for this search space, or None.

    Fails open: if Redis is unreachable, None is returned so a new
    generation request is allowed rather than blocked.
    """
    try:
        return get_redis_client().get(get_active_podcast_key(search_space_id))
    except Exception:
        return None
def set_active_podcast_task(search_space_id: int, task_id: str) -> None:
    """Record *task_id* as the active podcast task for this search space.

    The key carries a 30-minute expiry as a safety net so a crashed worker
    cannot block podcast generation forever.
    """
    key = get_active_podcast_key(search_space_id)
    try:
        get_redis_client().setex(key, 1800, task_id)
    except Exception as e:
        # Best effort: tracking failure should not abort the podcast itself.
        print(f"[generate_podcast] Warning: Could not set active task in Redis: {e}")
def clear_active_podcast_task(search_space_id: int) -> None:
    """Drop the active-task marker so a new podcast can be started."""
    key = get_active_podcast_key(search_space_id)
    try:
        get_redis_client().delete(key)
    except Exception as e:
        # Best effort: the 30-minute TTL on the key covers us if this fails.
        print(f"[generate_podcast] Warning: Could not clear active task in Redis: {e}")
def create_generate_podcast_tool(
    search_space_id: int,
    db_session: AsyncSession,
):
    """
    Factory function to create the generate_podcast tool with injected dependencies.
    Args:
        search_space_id: The user's search space ID
        db_session: Database session (not used - Celery creates its own)
    Returns:
        A configured tool function for generating podcasts
    """
    # NOTE: the docstring of `generate_podcast` below is the tool description
    # sent to the LLM — wording changes alter agent behavior.
    @tool
    async def generate_podcast(
        source_content: str,
        podcast_title: str = "SurfSense Podcast",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """
        Generate a podcast from the provided content.
        Use this tool when the user asks to create, generate, or make a podcast.
        Common triggers include phrases like:
        - "Give me a podcast about this"
        - "Create a podcast from this conversation"
        - "Generate a podcast summary"
        - "Make a podcast about..."
        - "Turn this into a podcast"
        The tool will start generating a podcast in the background.
        The podcast will be available once generation completes.
        IMPORTANT: Only one podcast can be generated at a time. If a podcast
        is already being generated, this tool will return a message asking
        the user to wait.
        Args:
            source_content: The text content to convert into a podcast.
                            This can be a summary, research findings, or any text
                            the user wants transformed into an audio podcast.
            podcast_title: Title for the podcast (default: "SurfSense Podcast")
            user_prompt: Optional instructions for podcast style, tone, or format.
                         For example: "Make it casual and fun" or "Focus on the key insights"
        Returns:
            A dictionary containing:
            - status: "processing" (task submitted), "already_generating", or "error"
            - task_id: The Celery task ID for polling status (if processing)
            - title: The podcast title
            - message: Status message for the user
        """
        try:
            # Check if a podcast is already being generated for this search space
            # (marker set below via set_active_podcast_task; it carries a
            # 30-minute TTL, so a stale marker eventually self-clears).
            active_task_id = get_active_podcast_task(search_space_id)
            if active_task_id:
                print(
                    f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}"
                )
                return {
                    "status": "already_generating",
                    "task_id": active_task_id,
                    "title": podcast_title,
                    "message": "A podcast is already being generated. Please wait for it to complete before requesting another one.",
                }
            # Import Celery task here to avoid circular imports
            from app.tasks.celery_tasks.podcast_tasks import (
                generate_content_podcast_task,
            )
            # Submit Celery task for background processing; the frontend polls
            # the returned task_id for completion.
            task = generate_content_podcast_task.delay(
                source_content=source_content,
                search_space_id=search_space_id,
                podcast_title=podcast_title,
                user_prompt=user_prompt,
            )
            # Mark this task as active so concurrent requests are rejected.
            set_active_podcast_task(search_space_id, task.id)
            print(f"[generate_podcast] Submitted Celery task: {task.id}")
            # Return immediately with task_id for polling
            return {
                "status": "processing",
                "task_id": task.id,
                "title": podcast_title,
                "message": "Podcast generation started. This may take a few minutes.",
            }
        except Exception as e:
            # Surface submission failures to the agent instead of raising.
            error_message = str(e)
            print(f"[generate_podcast] Error submitting task: {error_message}")
            return {
                "status": "error",
                "error": error_message,
                "title": podcast_title,
                "task_id": None,
            }
    return generate_podcast

View file

@ -0,0 +1,230 @@
"""
Tools registry for SurfSense deep agent.
This module provides a registry pattern for managing tools in the SurfSense agent.
It makes it easy for OSS contributors to add new tools by:
1. Creating a tool factory function in a new file in this directory
2. Registering the tool in the BUILTIN_TOOLS list below
Example of adding a new tool:
------------------------------
1. Create your tool file (e.g., `tools/my_tool.py`):
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
def create_my_tool(search_space_id: int, db_session: AsyncSession):
@tool
async def my_tool(param: str) -> dict:
'''My tool description.'''
# Your implementation
return {"result": "success"}
return my_tool
2. Import and register in this file:
from .my_tool import create_my_tool
# Add to BUILTIN_TOOLS list:
ToolDefinition(
name="my_tool",
description="Description of what your tool does",
factory=lambda deps: create_my_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
),
requires=["search_space_id", "db_session"],
),
"""
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from langchain_core.tools import BaseTool
from .display_image import create_display_image_tool
from .knowledge_base import create_search_knowledge_base_tool
from .link_preview import create_link_preview_tool
from .podcast import create_generate_podcast_tool
from .scrape_webpage import create_scrape_webpage_tool
# =============================================================================
# Tool Definition
# =============================================================================
@dataclass
class ToolDefinition:
    """
    Definition of a tool that can be added to the agent.
    Attributes:
        name: Unique identifier for the tool
        description: Human-readable description of what the tool does
        factory: Callable that creates the tool. Receives a dict of dependencies.
        requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
        enabled_by_default: Whether the tool is enabled when no explicit config is provided
    """
    name: str
    description: str
    # Receives the full dependency dict from build_tools; must return a BaseTool.
    factory: Callable[[dict[str, Any]], BaseTool]
    # Keys that must exist in the dependency dict; build_tools raises otherwise.
    requires: list[str] = field(default_factory=list)
    enabled_by_default: bool = True
# =============================================================================
# Built-in Tools Registry
# =============================================================================
# Registry of all built-in tools
# Contributors: Add your new tools here!
# Registry of all built-in tools
# Contributors: Add your new tools here!
# NOTE: build_tools() iterates this list in order, so registry order is the
# order in which tools are handed to the agent.
BUILTIN_TOOLS: list[ToolDefinition] = [
    # Core tool - searches the user's knowledge base
    ToolDefinition(
        name="search_knowledge_base",
        description="Search the user's personal knowledge base for relevant information",
        factory=lambda deps: create_search_knowledge_base_tool(
            search_space_id=deps["search_space_id"],
            db_session=deps["db_session"],
            connector_service=deps["connector_service"],
        ),
        requires=["search_space_id", "db_session", "connector_service"],
    ),
    # Podcast generation tool
    ToolDefinition(
        name="generate_podcast",
        description="Generate an audio podcast from provided content",
        factory=lambda deps: create_generate_podcast_tool(
            search_space_id=deps["search_space_id"],
            db_session=deps["db_session"],
        ),
        requires=["search_space_id", "db_session"],
    ),
    # Link preview tool - fetches Open Graph metadata for URLs
    ToolDefinition(
        name="link_preview",
        description="Fetch metadata for a URL to display a rich preview card",
        factory=lambda deps: create_link_preview_tool(),
        requires=[],
    ),
    # Display image tool - shows images in the chat
    ToolDefinition(
        name="display_image",
        description="Display an image in the chat with metadata",
        factory=lambda deps: create_display_image_tool(),
        requires=[],
    ),
    # Web scraping tool - extracts content from webpages
    ToolDefinition(
        name="scrape_webpage",
        description="Scrape and extract the main content from a webpage",
        factory=lambda deps: create_scrape_webpage_tool(
            firecrawl_api_key=deps.get("firecrawl_api_key"),
        ),
        requires=[],  # firecrawl_api_key is optional
    ),
    # =========================================================================
    # ADD YOUR CUSTOM TOOLS BELOW
    # =========================================================================
    # Example:
    # ToolDefinition(
    #     name="my_custom_tool",
    #     description="What my tool does",
    #     factory=lambda deps: create_my_custom_tool(...),
    #     requires=["search_space_id"],
    # ),
]
# =============================================================================
# Registry Functions
# =============================================================================
def get_tool_by_name(name: str) -> ToolDefinition | None:
    """Look up a registered tool definition by its unique name, or None."""
    return next((definition for definition in BUILTIN_TOOLS if definition.name == name), None)
def get_all_tool_names() -> list[str]:
    """Names of every registered tool, enabled by default or not."""
    return [definition.name for definition in BUILTIN_TOOLS]
def get_default_enabled_tools() -> list[str]:
    """Names of the registered tools flagged as enabled by default."""
    return [definition.name for definition in BUILTIN_TOOLS if definition.enabled_by_default]
def build_tools(
    dependencies: dict[str, Any],
    enabled_tools: list[str] | None = None,
    disabled_tools: list[str] | None = None,
    additional_tools: list[BaseTool] | None = None,
) -> list[BaseTool]:
    """
    Assemble the agent's tool list from the registry.

    Args:
        dependencies: Dict containing all possible dependencies:
            - search_space_id: The search space ID
            - db_session: Database session
            - connector_service: Connector service instance
            - firecrawl_api_key: Optional Firecrawl API key
        enabled_tools: Explicit allow-list of tool names. Defaults to the
            registry's enabled-by-default set when None.
        disabled_tools: Tool names to remove after the allow-list is applied.
        additional_tools: Extra, pre-built tools appended verbatim (e.g.
            custom tools not in the registry).

    Returns:
        Configured tool instances ready for the agent.

    Raises:
        ValueError: If a selected tool is missing a required dependency.

    Example:
        tools = build_tools(deps)                                   # defaults
        tools = build_tools(deps, enabled_tools=["link_preview"])   # allow-list
        tools = build_tools(deps, disabled_tools=["generate_podcast"])
        tools = build_tools(deps, additional_tools=[my_custom_tool])
    """
    # Start from the explicit allow-list, or the registry defaults.
    selected = (
        set(enabled_tools) if enabled_tools is not None else set(get_default_enabled_tools())
    )
    # Remove anything explicitly disabled.
    selected.difference_update(disabled_tools or [])
    built: list[BaseTool] = []
    for definition in BUILTIN_TOOLS:
        if definition.name not in selected:
            continue
        # Fail fast if a required dependency was not supplied.
        missing = [dep for dep in definition.requires if dep not in dependencies]
        if missing:
            raise ValueError(
                f"Tool '{definition.name}' requires dependencies: {missing}"
            )
        built.append(definition.factory(dependencies))
    # Custom tools are appended after the registry-built ones.
    built.extend(additional_tools or [])
    return built

View file

@ -0,0 +1,198 @@
"""
Web scraping tool for the SurfSense agent.
This module provides a tool for scraping and extracting content from webpages
using the existing WebCrawlerConnector. The scraped content can be used by
the agent to answer questions about web pages.
"""
import hashlib
from typing import Any
from urllib.parse import urlparse
from langchain_core.tools import tool
from app.connectors.webcrawler_connector import WebCrawlerConnector
def extract_domain(url: str) -> str:
    """Return the host portion of *url* without a leading ``www.``; '' on failure."""
    try:
        host = urlparse(url).netloc
    except Exception:
        return ""
    # Strip the common 'www.' prefix for a cleaner display domain.
    return host[4:] if host.startswith("www.") else host
def generate_scrape_id(url: str) -> str:
    """Stable short ID for a scraped page: first 12 hex chars of the URL's MD5."""
    digest = hashlib.md5(url.encode()).hexdigest()
    return "scrape-" + digest[:12]
def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
    """
    Cap *content* at *max_length* characters.

    Prefers to cut at the last sentence end ('.') or paragraph break ('\\n\\n')
    inside the window, but only when that boundary lies in the final 20% of
    the window; otherwise cuts hard at max_length. A truncation marker is
    appended whenever anything was removed.

    Returns:
        Tuple of (possibly-truncated content, whether truncation happened)
    """
    if len(content) <= max_length:
        return content, False
    window = content[:max_length]
    # Pick the later of the two natural boundaries; -1 means neither exists.
    boundary = max(window.rfind("."), window.rfind("\n\n"))
    if boundary > max_length * 0.8:
        window = content[: boundary + 1]
    return window + "\n\n[Content truncated...]", True
def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
    """
    Factory function to create the scrape_webpage tool.

    Args:
        firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
            Falls back to Chromium/Trafilatura if not provided.

    Returns:
        A configured async LangChain tool for scraping webpages.
    """

    @tool
    async def scrape_webpage(
        url: str,
        max_length: int = 50000,
    ) -> dict[str, Any]:
        """
        Scrape and extract the main content from a webpage.

        Use this tool when the user wants you to read, summarize, or answer
        questions about a specific webpage's content. This tool actually
        fetches and reads the full page content.

        Common triggers:
        - "Read this article and summarize it"
        - "What does this page say about X?"
        - "Summarize this blog post for me"
        - "Tell me the key points from this article"
        - "What's in this webpage?"

        Args:
            url: The URL of the webpage to scrape (must be HTTP/HTTPS)
            max_length: Maximum content length to return (default: 50000 chars)

        Returns:
            A dictionary containing:
            - id: Unique identifier for this scrape
            - assetId: The URL (for deduplication)
            - kind: "article" (type of content)
            - href: The URL to open when clicked
            - title: Page title
            - description: Brief description or excerpt
            - content: The extracted main content (markdown format)
            - domain: The domain name
            - word_count: Approximate word count
            - was_truncated: Whether content was truncated
            - error: Error message (if scraping failed)
        """
        # Normalize the URL *before* deriving the id and domain so that
        # "example.com" and "https://example.com" yield the same id and a
        # non-empty domain. (Deriving them from the raw URL left `domain`
        # empty for scheme-less input and made `id` disagree with `assetId`.)
        if not url.startswith(("http://", "https://")):
            url = f"https://{url}"

        scrape_id = generate_scrape_id(url)
        domain = extract_domain(url)

        try:
            # Create the webcrawler connector; it selects Firecrawl or the
            # local crawler depending on whether an API key is configured.
            connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)

            # Crawl the URL; markdown keeps the extracted content LLM-friendly.
            result, error = await connector.crawl_url(url, formats=["markdown"])

            if error:
                return {
                    "id": scrape_id,
                    "assetId": url,
                    "kind": "article",
                    "href": url,
                    "title": domain or "Webpage",
                    "domain": domain,
                    "error": error,
                }

            if not result:
                return {
                    "id": scrape_id,
                    "assetId": url,
                    "kind": "article",
                    "href": url,
                    "title": domain or "Webpage",
                    "domain": domain,
                    "error": "No content returned from crawler",
                }

            # Extract content and metadata from the crawl result.
            content = result.get("content", "")
            metadata = result.get("metadata", {})

            # Title: prefer page metadata, then the domain, then the URL tail.
            title = metadata.get("title", "")
            if not title:
                title = domain or url.split("/")[-1] or "Webpage"

            # Description: prefer metadata; otherwise use the first paragraph,
            # capped at 300 characters.
            description = metadata.get("description", "")
            if not description and content:
                first_para = content.split("\n\n")[0] if content else ""
                description = (
                    first_para[:300] + "..." if len(first_para) > 300 else first_para
                )

            # Truncate content if needed; the word count is computed over the
            # (possibly truncated) text that is actually returned.
            content, was_truncated = truncate_content(content, max_length)
            word_count = len(content.split())

            return {
                "id": scrape_id,
                "assetId": url,
                "kind": "article",
                "href": url,
                "title": title,
                "description": description,
                "content": content,
                "domain": domain,
                "word_count": word_count,
                "was_truncated": was_truncated,
                "crawler_type": result.get("crawler_type", "unknown"),
                "author": metadata.get("author"),
                "date": metadata.get("date"),
            }

        except Exception as e:
            error_message = str(e)
            print(f"[scrape_webpage] Error scraping {url}: {error_message}")
            return {
                "id": scrape_id,
                "assetId": url,
                "kind": "article",
                "href": url,
                "title": domain or "Webpage",
                "domain": domain,
                "error": f"Failed to scrape: {error_message[:100]}",
            }

    return scrape_webpage

View file

@ -0,0 +1,63 @@
"""
Utility functions for SurfSense agents.
This module provides shared utility functions used across the new_chat agent modules.
"""
from datetime import UTC, datetime, timedelta
def parse_date_or_datetime(value: str) -> datetime:
"""
Parse either an ISO date (YYYY-MM-DD) or ISO datetime into an aware UTC datetime.
- If `value` is a date, interpret it as start-of-day in UTC.
- If `value` is a datetime without timezone, assume UTC.
Args:
value: ISO date or datetime string
Returns:
Aware datetime object in UTC
Raises:
ValueError: If the date string is empty or invalid
"""
raw = (value or "").strip()
if not raw:
raise ValueError("Empty date string")
# Date-only
if "T" not in raw:
d = datetime.fromisoformat(raw).date()
return datetime(d.year, d.month, d.day, tzinfo=UTC)
# Datetime (may be naive)
dt = datetime.fromisoformat(raw)
if dt.tzinfo is None:
return dt.replace(tzinfo=UTC)
return dt.astimezone(UTC)
def resolve_date_range(
start_date: datetime | None,
end_date: datetime | None,
) -> tuple[datetime, datetime]:
"""
Resolve a date range, defaulting to the last 2 years if not provided.
Ensures start_date <= end_date.
Args:
start_date: Optional start datetime (UTC)
end_date: Optional end datetime (UTC)
Returns:
Tuple of (resolved_start_date, resolved_end_date) in UTC
"""
resolved_end = end_date or datetime.now(UTC)
resolved_start = start_date or (resolved_end - timedelta(days=730))
if resolved_start > resolved_end:
resolved_start, resolved_end = resolved_end, resolved_start
return resolved_start, resolved_end

View file

@ -16,7 +16,6 @@ class Configuration:
# create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/)
# and when you invoke the graph
podcast_title: str
user_id: str
search_space_id: int
user_prompt: str | None = None

View file

@ -12,7 +12,7 @@ from litellm import aspeech
from app.config import config as app_config
from app.services.kokoro_tts_service import get_kokoro_tts_service
from app.services.llm_service import get_user_long_context_llm
from app.services.llm_service import get_document_summary_llm
from .configuration import Configuration
from .prompts import get_podcast_generation_prompt
@ -27,14 +27,15 @@ async def create_podcast_transcript(
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
user_id = configuration.user_id
search_space_id = configuration.search_space_id
user_prompt = configuration.user_prompt
# Get user's long context LLM
llm = await get_user_long_context_llm(state.db_session, user_id, search_space_id)
# Get search space's document summary LLM
llm = await get_document_summary_llm(state.db_session, search_space_id)
if not llm:
error_message = f"No long context LLM configured for user {user_id} in search space {search_space_id}"
error_message = (
f"No document summary LLM configured for search space {search_space_id}"
)
print(error_message)
raise RuntimeError(error_message)

View file

@ -1,30 +0,0 @@
"""Define the configurable parameters for the agent."""
from __future__ import annotations
from dataclasses import dataclass, fields
from langchain_core.runnables import RunnableConfig
@dataclass(kw_only=True)
class Configuration:
"""The configuration for the agent."""
# Input parameters provided at invocation
user_query: str
connectors_to_search: list[str]
user_id: str
search_space_id: int
document_ids_to_add_in_context: list[int]
language: str | None = None
top_k: int = 10
@classmethod
def from_runnable_config(
cls, config: RunnableConfig | None = None
) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object."""
configurable = (config.get("configurable") or {}) if config else {}
_fields = {f.name for f in fields(cls) if f.init}
return cls(**{k: v for k, v in configurable.items() if k in _fields})

View file

@ -1,47 +0,0 @@
from langgraph.graph import StateGraph
from .configuration import Configuration
from .nodes import (
generate_further_questions,
handle_qna_workflow,
reformulate_user_query,
)
from .state import State
def build_graph():
"""
Build and return the LangGraph workflow.
This function constructs the researcher agent graph for Q&A workflow.
The workflow follows a simple path:
1. Reformulate user query based on chat history
2. Handle QNA workflow (fetch documents and generate answer)
3. Generate follow-up questions
Returns:
A compiled LangGraph workflow
"""
# Define a new graph with state class
workflow = StateGraph(State, config_schema=Configuration)
# Add nodes to the graph
workflow.add_node("reformulate_user_query", reformulate_user_query)
workflow.add_node("handle_qna_workflow", handle_qna_workflow)
workflow.add_node("generate_further_questions", generate_further_questions)
# Define the edges - simple linear flow for QNA
workflow.add_edge("__start__", "reformulate_user_query")
workflow.add_edge("reformulate_user_query", "handle_qna_workflow")
workflow.add_edge("handle_qna_workflow", "generate_further_questions")
workflow.add_edge("generate_further_questions", "__end__")
# Compile the workflow into an executable graph
graph = workflow.compile()
graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith
return graph
# Compile the graph once when the module is loaded
graph = build_graph()

File diff suppressed because it is too large Load diff

View file

@ -1,140 +0,0 @@
import datetime
def _build_language_instruction(language: str | None = None):
"""Build language instruction for prompts."""
if language:
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
return ""
def get_further_questions_system_prompt():
return f"""
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
<further_questions_system>
You are an expert research assistant specializing in generating contextually relevant follow-up questions. Your task is to analyze the chat history and available documents to suggest further questions that would naturally extend the conversation and provide additional value to the user.
<input>
- chat_history: Provided in XML format within <chat_history> tags, containing <user> and <assistant> message pairs that show the chronological conversation flow. This provides context about what has already been discussed.
- available_documents: Provided in XML format within <documents> tags, containing individual <document> elements with <document_metadata> and <document_content> sections. Each document contains multiple `<chunk id='...'>...</chunk>` blocks inside <document_content>. This helps understand what information is accessible for answering potential follow-up questions.
</input>
<output_format>
A JSON object with the following structure:
{{
"further_questions": [
{{
"id": 0,
"question": "further qn 1"
}},
{{
"id": 1,
"question": "further qn 2"
}}
]
}}
</output_format>
<instructions>
1. **Analyze Chat History:** Review the entire conversation flow to understand:
* The main topics and themes discussed
* The user's interests and areas of focus
* Questions that have been asked and answered
* Any gaps or areas that could be explored further
* The depth level of the current discussion
2. **Evaluate Available Documents:** Consider the documents in context to identify:
* Additional information that hasn't been explored yet
* Related topics that could be of interest
* Specific details or data points that could warrant deeper investigation
* Cross-references or connections between different documents
3. **Generate Relevant Follow-up Questions:** Create 3-5 further questions that:
* Are directly related to the ongoing conversation but explore new angles
* Can be reasonably answered using the available documents or knowledge base
* Progress the conversation forward rather than repeating previous topics
* Match the user's apparent level of interest and expertise
* Are specific and actionable rather than overly broad
* Consider practical applications, comparisons, deeper analysis, or related concepts
4. **Ensure Question Quality:** Each question should:
* Be clear and well-formulated
* Provide genuine value to the user
* Be distinct from other suggested questions
* Be answerable within the current context
* Encourage meaningful exploration of the topic
5. **Prioritize and Order:** Arrange questions by relevance and natural progression:
* Most directly related to the current discussion first
* Questions that build upon previous answers
* Questions that explore practical applications or implications
* Questions that introduce related but new concepts
6. **Adhere Strictly to Output Format:** Ensure the final output is a valid JSON object with:
* Correct field names (`further_questions`, `id`, `question`)
* Sequential numbering starting from 0
* Proper data types and JSON formatting
</instructions>
<examples>
Input:
<chat_history>
<user>What are the best machine learning algorithms for text classification?</user>
<assistant>For text classification, several algorithms work well depending on your specific needs:
**Traditional Methods:**
- **Support Vector Machines (SVM)** - Excellent for high-dimensional text data
- **Naive Bayes** - Simple, fast, and works well with small datasets
- **Logistic Regression** - Good baseline with interpretable results
**Modern Deep Learning:**
- **Neural Networks** - More complex but can capture intricate patterns
- **Transformer models** - State-of-the-art for most text classification tasks
The choice depends on your dataset size, computational resources, and accuracy requirements.</assistant>
</chat_history>
<documents>
<document>
<metadata>
<source_id>101</source_id>
<source_type>FILE</source_type>
</metadata>
<content>
# Machine Learning for Text Classification: A Comprehensive Guide
## Performance Comparison
Recent studies show that transformer-based models achieve 95%+ accuracy on most text classification benchmarks, while traditional methods like SVM typically achieve 85-90% accuracy.
## Dataset Considerations
- Small datasets (< 1000 samples): Naive Bayes, SVM
- Large datasets (> 10,000 samples): Neural networks, transformers
- Imbalanced datasets: Require special handling with techniques like SMOTE
</content>
</document>
</documents>
Output:
{{
"further_questions": [
{{
"id": 0,
"question": "What are the key differences in performance between traditional algorithms like SVM and modern deep learning approaches for text classification?"
}},
{{
"id": 1,
"question": "How do you handle imbalanced datasets when training text classification models?"
}},
{{
"id": 2,
"question": "What preprocessing techniques are most effective for improving text classification accuracy?"
}},
{{
"id": 3,
"question": "Are there specific domains or use cases where certain classification algorithms perform better than others?"
}}
]
}}
</examples>
</further_questions_system>
"""

View file

@ -1,5 +0,0 @@
"""QnA Agent."""
from .graph import graph
__all__ = ["graph"]

View file

@ -1,31 +0,0 @@
"""Define the configurable parameters for the agent."""
from __future__ import annotations
from dataclasses import dataclass, fields
from typing import Any
from langchain_core.runnables import RunnableConfig
@dataclass(kw_only=True)
class Configuration:
"""The configuration for the Q&A agent."""
# Configuration parameters for the Q&A agent
user_query: str # The user's question to answer
reformulated_query: str # The reformulated query
relevant_documents: list[
Any
] # Documents provided directly to the agent for answering
search_space_id: int # Search space identifier
language: str | None = None # Language for responses
@classmethod
def from_runnable_config(
cls, config: RunnableConfig | None = None
) -> Configuration:
"""Create a Configuration instance from a RunnableConfig object."""
configurable = (config.get("configurable") or {}) if config else {}
_fields = {f.name for f in fields(cls) if f.init}
return cls(**{k: v for k, v in configurable.items() if k in _fields})

View file

@ -1,201 +0,0 @@
"""Default system prompts for Q&A agent.
The prompt system is modular with 3 parts:
- Part 1 (Base): Core instructions for answering questions (no citations)
- Part 2 (Citations): Citation-specific instructions and formatting rules
- Part 3 (Custom): User's custom instructions (empty by default)
Combinations:
- Part 1 only: Answers without citations
- Part 1 + Part 2: Answers with citations
- Part 1 + Part 2 + Part 3: Answers with citations and custom instructions
"""
# Part 1: Base system prompt for answering without citations
DEFAULT_QNA_BASE_PROMPT = """Today's date: {date}
You are SurfSense, an advanced AI research assistant that provides detailed, well-researched answers to user questions by synthesizing information from multiple personal knowledge sources.{language_instruction}
{chat_history_section}
<knowledge_sources>
- EXTENSION: "Web content saved via SurfSense browser extension" (personal browsing history)
- FILE: "User-uploaded documents (PDFs, Word, etc.)" (personal files)
- SLACK_CONNECTOR: "Slack conversations and shared content" (personal workspace communications)
- NOTION_CONNECTOR: "Notion workspace pages and databases" (personal knowledge management)
- YOUTUBE_VIDEO: "YouTube video transcripts and metadata" (personally saved videos)
- GITHUB_CONNECTOR: "GitHub repository content and issues" (personal repositories and interactions)
- ELASTICSEARCH_CONNECTOR: "Elasticsearch indexed documents and data" (personal Elasticsearch instances and custom data sources)
- LINEAR_CONNECTOR: "Linear project issues and discussions" (personal project management)
- JIRA_CONNECTOR: "Jira project issues, tickets, and comments" (personal project tracking)
- CONFLUENCE_CONNECTOR: "Confluence pages and comments" (personal project documentation)
- CLICKUP_CONNECTOR: "ClickUp tasks and project data" (personal task management)
- GOOGLE_CALENDAR_CONNECTOR: "Google Calendar events, meetings, and schedules" (personal calendar and time management)
- GOOGLE_GMAIL_CONNECTOR: "Google Gmail emails and conversations" (personal emails and communications)
- DISCORD_CONNECTOR: "Discord server conversations and shared content" (personal community communications)
- AIRTABLE_CONNECTOR: "Airtable records, tables, and database content" (personal data management and organization)
- TAVILY_API: "Tavily search API results" (personalized search results)
- LINKUP_API: "Linkup search API results" (personalized search results)
- LUMA_CONNECTOR: "Luma events"
- WEBCRAWLER_CONNECTOR: "Webpages indexed by SurfSense" (personally selected websites)
</knowledge_sources>
<instructions>
1. Review the chat history to understand the conversation context and any previous topics discussed.
2. Carefully analyze all provided documents in the <document> sections.
3. Extract relevant information that directly addresses the user's question.
4. Provide a comprehensive, detailed answer using information from the user's personal knowledge sources.
5. Structure your answer logically and conversationally, as if having a detailed discussion with the user.
6. Use your own words to synthesize and connect ideas from the documents.
7. If documents contain conflicting information, acknowledge this and present both perspectives.
8. If the user's question cannot be fully answered with the provided documents, clearly state what information is missing.
9. Provide actionable insights and practical information when relevant to the user's question.
10. Use the chat history to maintain conversation continuity and refer to previous discussions when relevant.
11. Remember that all knowledge sources contain personal information - provide answers that reflect this personal context.
12. Be conversational and engaging while maintaining accuracy.
</instructions>
<format>
- Write in a clear, conversational tone suitable for detailed Q&A discussions
- Provide comprehensive answers that thoroughly address the user's question
- Use appropriate paragraphs and structure for readability
- ALWAYS provide personalized answers that reflect the user's own knowledge and context
- Be thorough and detailed in your explanations while remaining focused on the user's specific question
- If asking follow-up questions would be helpful, suggest them at the end of your response
</format>
<user_query_instructions>
When you see a user query, focus exclusively on providing a detailed, comprehensive answer using information from the provided documents, which contain the user's personal knowledge and data.
Make sure your response:
1. Considers the chat history for context and conversation continuity
2. Directly and thoroughly answers the user's question with personalized information from their own knowledge sources
3. Is conversational, engaging, and detailed
4. Acknowledges the personal nature of the information being provided
5. Offers follow-up suggestions when appropriate
</user_query_instructions>
"""
# Part 2: Citation-specific instructions to add citation capabilities
DEFAULT_QNA_CITATION_INSTRUCTIONS = """
<citation_instructions>
CRITICAL CITATION REQUIREMENTS:
1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `<chunk id='...'>` tag inside `<document_content>`.
2. Make sure ALL factual statements from the documents have proper citations.
3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2].
4. You MUST use the exact chunk_id values from the `<chunk id='...'>` attributes. Do not create your own citation numbers.
5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value.
6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags.
7. Do not return citations as clickable links.
8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only.
9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting.
10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `<chunk id='...'>` tags.
11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up.
<document_structure_example>
The documents you receive are structured like this:
<document>
<document_metadata>
<document_id>42</document_id>
<document_type>GITHUB_CONNECTOR</document_type>
<title><![CDATA[Some repo / file / issue title]]></title>
<url><![CDATA[https://example.com]]></url>
<metadata_json><![CDATA[{{"any":"other metadata"}}]]></metadata_json>
</document_metadata>
<document_content>
<chunk id='123'><![CDATA[First chunk text...]]></chunk>
<chunk id='124'><![CDATA[Second chunk text...]]></chunk>
</document_content>
</document>
IMPORTANT: You MUST cite using the chunk ids (e.g. 123, 124). Do NOT cite document_id.
</document_structure_example>
<citation_format>
- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `<chunk id='...'>` tag
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
- No need to return references section. Just citations in answer.
- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format
- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only
- NEVER make up chunk IDs if you are unsure about the chunk_id. It is better to omit the citation than to guess
</citation_format>
<citation_examples>
CORRECT citation formats:
- [citation:5]
- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3]
INCORRECT citation formats (DO NOT use):
- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([citation:5])
- Using hyperlinked text: [link to source 5](https://example.com)
- Using footnote style: ... library¹
- Making up source IDs when source_id is unknown
- Using old IEEE format: [1], [2], [3]
- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5]
</citation_examples>
<citation_output_example>
Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5].
The key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:12]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources.
However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead.
</citation_output_example>
</citation_instructions>
"""
# Part 3: User's custom instructions (empty by default, can be set by user from UI)
DEFAULT_QNA_CUSTOM_INSTRUCTIONS = ""
# Full prompt with all parts combined (for backward compatibility and migration)
DEFAULT_QNA_CITATION_PROMPT = (
DEFAULT_QNA_BASE_PROMPT
+ DEFAULT_QNA_CITATION_INSTRUCTIONS
+ DEFAULT_QNA_CUSTOM_INSTRUCTIONS
)
DEFAULT_QNA_NO_DOCUMENTS_PROMPT = """Today's date: {date}
You are SurfSense, an advanced AI research assistant that provides helpful, detailed answers to user questions in a conversational manner.{language_instruction}
{chat_history_section}
<context>
The user has asked a question but there are no specific documents from their personal knowledge base available to answer it. You should provide a helpful response based on:
1. The conversation history and context
2. Your general knowledge and expertise
3. Understanding of the user's needs and interests based on our conversation
</context>
<instructions>
1. Provide a comprehensive, helpful answer to the user's question
2. Draw upon the conversation history to understand context and the user's specific needs
3. Use your general knowledge to provide accurate, detailed information
4. Be conversational and engaging, as if having a detailed discussion with the user
5. Acknowledge when you're drawing from general knowledge rather than their personal sources
6. Provide actionable insights and practical information when relevant
7. Structure your answer logically and clearly
8. If the question would benefit from personalized information from their knowledge base, gently suggest they might want to add relevant content to SurfSense
9. Be honest about limitations while still being maximally helpful
10. Maintain the helpful, knowledgeable tone that users expect from SurfSense
</instructions>
<format>
- Write in a clear, conversational tone suitable for detailed Q&A discussions
- Provide comprehensive answers that thoroughly address the user's question
- Use appropriate paragraphs and structure for readability
- No citations are needed since you're using general knowledge
- Be thorough and detailed in your explanations while remaining focused on the user's specific question
- If asking follow-up questions would be helpful, suggest them at the end of your response
- When appropriate, mention that adding relevant content to their SurfSense knowledge base could provide more personalized answers
</format>
<user_query_instructions>
When answering the user's question without access to their personal documents:
1. Review the chat history to understand conversation context and maintain continuity
2. Provide the most helpful and comprehensive answer possible using general knowledge
3. Be conversational and engaging
4. Draw upon conversation history for context
5. Be clear that you're providing general information
6. Suggest ways the user could get more personalized answers by expanding their knowledge base when relevant
</user_query_instructions>
"""

View file

@ -1,21 +0,0 @@
from langgraph.graph import StateGraph
from .configuration import Configuration
from .nodes import answer_question, rerank_documents
from .state import State
# Define a new graph
workflow = StateGraph(State, config_schema=Configuration)
# Add the nodes to the graph
workflow.add_node("rerank_documents", rerank_documents)
workflow.add_node("answer_question", answer_question)
# Connect the nodes
workflow.add_edge("__start__", "rerank_documents")
workflow.add_edge("rerank_documents", "answer_question")
workflow.add_edge("answer_question", "__end__")
# Compile the workflow into an executable graph
graph = workflow.compile()
graph.name = "SurfSense QnA Agent" # This defines the custom name in LangSmith

View file

@ -1,297 +0,0 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy import select
from app.db import SearchSpace
from app.services.reranker_service import RerankerService
from ..utils import (
calculate_token_count,
format_documents_section,
langchain_chat_history_to_str,
optimize_documents_for_token_limit,
)
from .configuration import Configuration
from .default_prompts import (
DEFAULT_QNA_BASE_PROMPT,
DEFAULT_QNA_CITATION_INSTRUCTIONS,
DEFAULT_QNA_NO_DOCUMENTS_PROMPT,
)
from .state import State
def _build_language_instruction(language: str | None = None):
"""Build language instruction for prompts."""
if language:
return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}."
return ""
def _build_chat_history_section(chat_history: str | None = None):
"""Build chat history section for prompts."""
if chat_history:
return f"""
<chat_history>
{chat_history if chat_history else "NO CHAT HISTORY PROVIDED"}
</chat_history>
"""
return """
<chat_history>
NO CHAT HISTORY PROVIDED
</chat_history>
"""
def _format_system_prompt(
prompt_template: str,
chat_history: str | None = None,
language: str | None = None,
):
"""Format a system prompt template with dynamic values."""
date = datetime.datetime.now().strftime("%Y-%m-%d")
language_instruction = _build_language_instruction(language)
chat_history_section = _build_chat_history_section(chat_history)
return prompt_template.format(
date=date,
language_instruction=language_instruction,
chat_history_section=chat_history_section,
)
async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]:
"""
Rerank the documents based on relevance to the user's question.
This node takes the relevant documents provided in the configuration,
reranks them using the reranker service based on the user's query,
and updates the state with the reranked documents.
Documents are now document-grouped with a `chunks` list. Reranking is done
using the concatenated `content` field, and the full structure (including
`chunks`) is preserved for proper citation formatting.
If reranking is disabled, returns the original documents without processing.
Returns:
Dict containing the reranked documents.
"""
# Get configuration and relevant documents
configuration = Configuration.from_runnable_config(config)
documents = configuration.relevant_documents
user_query = configuration.user_query
reformulated_query = configuration.reformulated_query
# If no documents were provided, return empty list
if not documents or len(documents) == 0:
return {"reranked_documents": []}
# Get reranker service from app config
reranker_service = RerankerService.get_reranker_instance()
# If reranking is not enabled, sort by existing score and return
if not reranker_service:
print("Reranking is disabled. Sorting documents by existing score.")
sorted_documents = sorted(
documents, key=lambda x: x.get("score", 0), reverse=True
)
return {"reranked_documents": sorted_documents}
# Perform reranking
try:
# Pass documents directly to reranker - it will use:
# - "content" (concatenated chunk text) for scoring
# - "chunk_id" (primary chunk id) for matching
# The full document structure including "chunks" is preserved
reranked_docs = reranker_service.rerank_documents(
user_query + "\n" + reformulated_query, documents
)
# Sort by score in descending order
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}")
return {"reranked_documents": reranked_docs}
except Exception as e:
print(f"Error during reranking: {e!s}")
# Fall back to original documents if reranking fails
return {"reranked_documents": documents}
async def answer_question(
    state: State, config: RunnableConfig, writer: StreamWriter
) -> dict[str, Any]:
    """
    Answer the user's question using the provided documents with real-time streaming.

    This node takes the relevant documents provided in the configuration and uses
    an LLM to generate a comprehensive answer to the user's question with
    proper citations. The citations follow [citation:chunk_id] format using chunk IDs from the
    `<chunk id='...'>` tags in the provided documents. If no documents are provided, it will use chat history to generate
    an answer.

    The response is streamed token-by-token for real-time updates to the frontend.

    Returns:
        Dict containing the final answer in the "final_answer" key.

    Raises:
        RuntimeError: If the search space does not exist or has no fast LLM
            configured.
    """
    # Local import — NOTE(review): presumably deferred to avoid a circular
    # import at module load time; confirm before hoisting to module scope.
    from app.services.llm_service import get_fast_llm

    # Get configuration and relevant documents from configuration
    configuration = Configuration.from_runnable_config(config)
    documents = state.reranked_documents
    user_query = configuration.user_query
    search_space_id = configuration.search_space_id
    language = configuration.language

    # Get streaming service from state
    streaming_service = state.streaming_service

    # Fetch search space to get QnA configuration
    result = await state.db_session.execute(
        select(SearchSpace).where(SearchSpace.id == search_space_id)
    )
    search_space = result.scalar_one_or_none()

    if not search_space:
        error_message = f"Search space {search_space_id} not found"
        print(error_message)
        raise RuntimeError(error_message)

    # Get QnA configuration from search space
    citations_enabled = search_space.citations_enabled
    custom_instructions_text = search_space.qna_custom_instructions or ""

    # Use constants for base prompt and citation instructions
    qna_base_prompt = DEFAULT_QNA_BASE_PROMPT
    # Citation instructions are injected only when the search space opts in.
    qna_citation_instructions = (
        DEFAULT_QNA_CITATION_INSTRUCTIONS if citations_enabled else ""
    )
    # Custom instructions, when present, are wrapped in a dedicated tag block.
    qna_custom_instructions = (
        f"\n<special_important_custom_instructions>\n{custom_instructions_text}\n</special_important_custom_instructions>"
        if custom_instructions_text
        else ""
    )

    # Get search space's fast LLM
    llm = await get_fast_llm(state.db_session, search_space_id)
    if not llm:
        error_message = f"No fast LLM configured for search space {search_space_id}"
        print(error_message)
        raise RuntimeError(error_message)

    # Determine if we have documents and optimize for token limits
    has_documents_initially = documents and len(documents) > 0
    chat_history_str = langchain_chat_history_to_str(state.chat_history)

    if has_documents_initially:
        # Compose the full citation prompt: base + citation instructions + custom instructions
        full_citation_prompt_template = (
            qna_base_prompt + qna_citation_instructions + qna_custom_instructions
        )

        # Create base message template for token calculation (without documents)
        base_human_message_template = f"""
User's question:
<user_query>
{user_query}
</user_query>
Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner.
"""

        # Use initial system prompt for token calculation
        initial_system_prompt = _format_system_prompt(
            full_citation_prompt_template, chat_history_str, language
        )

        base_messages = [
            SystemMessage(content=initial_system_prompt),
            HumanMessage(content=base_human_message_template),
        ]

        # Optimize documents to fit within token limits; may drop trailing
        # documents (or all of them) when the context window is too small.
        optimized_documents, has_optimized_documents = (
            optimize_documents_for_token_limit(documents, base_messages, llm.model)
        )

        # Update state based on optimization result
        documents = optimized_documents
        has_documents = has_optimized_documents
    else:
        has_documents = False

    # Choose system prompt based on final document availability
    # With documents: use base + citation instructions + custom instructions
    # Without documents: use the default no-documents prompt from constants
    if has_documents:
        full_citation_prompt_template = (
            qna_base_prompt + qna_citation_instructions + qna_custom_instructions
        )
        system_prompt = _format_system_prompt(
            full_citation_prompt_template, chat_history_str, language
        )
    else:
        system_prompt = _format_system_prompt(
            DEFAULT_QNA_NO_DOCUMENTS_PROMPT + qna_custom_instructions,
            chat_history_str,
            language,
        )

    # Generate documents section (empty string when answering without sources)
    documents_text = (
        format_documents_section(
            documents, "Source material from your personal knowledge base"
        )
        if has_documents
        else ""
    )

    # Create final human message content
    instruction_text = (
        "Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner."
        if has_documents
        else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner."
    )

    human_message_content = f"""
{documents_text}
User's question:
<user_query>
{user_query}
</user_query>
{instruction_text}
"""

    # Create final messages for the LLM
    messages_with_chat_history = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=human_message_content),
    ]

    # Log final token count
    total_tokens = calculate_token_count(messages_with_chat_history, llm.model)
    print(f"Final token count: {total_tokens}")

    # Stream the LLM response token by token
    final_answer = ""
    async for chunk in llm.astream(messages_with_chat_history):
        # Extract the content from the chunk
        if hasattr(chunk, "content") and chunk.content:
            token = chunk.content
            final_answer += token
            # Stream the token to the frontend via custom stream
            if streaming_service:
                writer({"yield_value": streaming_service.format_text_chunk(token)})

    return {"final_answer": final_answer}

View file

@ -1,32 +0,0 @@
"""Define the state structures for the agent."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.streaming_service import StreamingService
@dataclass
class State:
"""Defines the dynamic state for the Q&A agent during execution.
This state tracks the database session, chat history, and the outputs
generated by the agent's nodes during question answering.
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
for more information.
"""
# Runtime context
db_session: AsyncSession
# Streaming service for real-time token streaming
streaming_service: StreamingService | None = None
chat_history: list[Any] | None = field(default_factory=list)
# OUTPUT: Populated by agent nodes
reranked_documents: list[Any] | None = None
final_answer: str | None = None

View file

@ -1,38 +0,0 @@
"""Define the state structures for the agent."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.streaming_service import StreamingService
@dataclass
class State:
"""Defines the dynamic state for the agent during execution.
This state tracks the database session and the outputs generated by the agent's nodes.
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
for more information.
"""
# Runtime context (not part of actual graph state)
db_session: AsyncSession
# Streaming service
streaming_service: StreamingService
chat_history: list[Any] | None = field(default_factory=list)
reformulated_query: str | None = field(default=None)
further_questions: Any | None = field(default=None)
# Temporary field to hold reranked documents from sub-agents for further question generation
reranked_documents: list[Any] | None = field(default=None)
# OUTPUT: Populated by agent nodes
# Using field to explicitly mark as part of state
final_written_report: str | None = field(default=None)

View file

@ -1,292 +0,0 @@
import json
from typing import Any, NamedTuple
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from litellm import get_model_info, token_counter
class DocumentTokenInfo(NamedTuple):
    """Information about a document and its token cost."""

    # Position of the document in the original candidate list.
    index: int
    # The original document payload (dict with "document", "chunks", etc.).
    document: dict[str, Any]
    # XML citation text produced by format_document_for_citation().
    formatted_content: str
    # Token cost of formatted_content as measured by litellm's token_counter.
    token_count: int
def get_connector_emoji(connector_name: str) -> str:
    """Return the emoji used to represent a connector type in the UI.

    Unknown connector types fall back to a generic magnifying glass.
    """
    fallback = "🔎"
    emoji_by_connector = {
        "NOTE": "📝",
        "FILE": "📄",
        "EXTENSION": "🧩",
        "YOUTUBE_VIDEO": "📹",
        "WEBCRAWLER_CONNECTOR": "🌐",
        "SLACK_CONNECTOR": "💬",
        "NOTION_CONNECTOR": "📘",
        "GITHUB_CONNECTOR": "🐙",
        "LINEAR_CONNECTOR": "📊",
        "JIRA_CONNECTOR": "🎫",
        "DISCORD_CONNECTOR": "🗨️",
        "TAVILY_API": "🔍",
        "LINKUP_API": "🔗",
        "BAIDU_SEARCH_API": "🇨🇳",
        "GOOGLE_CALENDAR_CONNECTOR": "📅",
        "AIRTABLE_CONNECTOR": "🗃️",
        "LUMA_CONNECTOR": "",
        "ELASTICSEARCH_CONNECTOR": "",
        "BOOKSTACK_CONNECTOR": "📚",
    }
    return emoji_by_connector.get(connector_name, fallback)
def get_connector_friendly_name(connector_name: str) -> str:
    """Map a technical connector ID to the human-readable name shown in the UI.

    Unrecognized connector IDs are returned unchanged.
    """
    friendly_by_connector = {
        "NOTE": "Notes",
        "FILE": "Files",
        "EXTENSION": "Browser Extension",
        "YOUTUBE_VIDEO": "YouTube",
        "WEBCRAWLER_CONNECTOR": "Web Pages",
        "SLACK_CONNECTOR": "Slack",
        "NOTION_CONNECTOR": "Notion",
        "GITHUB_CONNECTOR": "GitHub",
        "LINEAR_CONNECTOR": "Linear",
        "JIRA_CONNECTOR": "Jira",
        "CONFLUENCE_CONNECTOR": "Confluence",
        "GOOGLE_CALENDAR_CONNECTOR": "Google Calendar",
        "DISCORD_CONNECTOR": "Discord",
        "TAVILY_API": "Tavily Search",
        "LINKUP_API": "Linkup Search",
        "BAIDU_SEARCH_API": "Baidu Search",
        "AIRTABLE_CONNECTOR": "Airtable",
        "LUMA_CONNECTOR": "Luma",
        "ELASTICSEARCH_CONNECTOR": "Elasticsearch",
        "BOOKSTACK_CONNECTOR": "BookStack",
    }
    return friendly_by_connector.get(connector_name, connector_name)
def convert_langchain_messages_to_dict(
    messages: list[BaseMessage],
) -> list[dict[str, str]]:
    """Translate LangChain messages into the dict shape token_counter expects.

    Message types map as system->system, human->user, ai->assistant; anything
    else (including messages with no `type` attribute) defaults to "user".
    """
    type_to_role = {"system": "system", "human": "user", "ai": "assistant"}
    return [
        {
            "role": type_to_role.get(getattr(message, "type", None), "user"),
            "content": str(message.content),
        }
        for message in messages
    ]
def format_document_for_citation(document: dict[str, Any]) -> str:
    """Format a single document for citation in the new document+chunks XML format.

    IMPORTANT:
    - Citations must reference real DB chunk IDs: `[citation:<chunk_id>]`
    - Document metadata is included under <document_metadata>, but citations are NOT document_id-based.

    Args:
        document: Document-grouped search result. Keys read here:
            "document" (dict with id/title/document_type/metadata),
            "chunks" (list of {"chunk_id", "content"}), and a flat
            "content" used as a fallback when no chunks are present.

    Returns:
        An XML <document> string containing a metadata section and one
        <chunk> element per chunk, with free text wrapped in CDATA.
    """

    def _to_cdata(value: Any) -> str:
        # Wrap arbitrary text in a CDATA section; None becomes empty text.
        text = "" if value is None else str(value)
        # Safely nest CDATA even if the content includes "]]>"
        # (the terminator is split across two adjacent CDATA sections).
        return "<![CDATA[" + text.replace("]]>", "]]]]><![CDATA[>") + "]]>"

    doc_info = document.get("document", {}) or {}
    metadata = doc_info.get("metadata", {}) or {}
    doc_id = doc_info.get("id", "")
    title = doc_info.get("title", "")
    document_type = doc_info.get("document_type", "CRAWLED_URL")
    # First non-empty URL-like field wins; different connectors store the
    # source location under different metadata keys.
    url = (
        metadata.get("url")
        or metadata.get("source")
        or metadata.get("page_url")
        or metadata.get("VisitedWebPageURL")
        or ""
    )
    metadata_json = json.dumps(metadata, ensure_ascii=False)

    chunks = document.get("chunks") or []
    if not chunks:
        # Fallback: treat `content` as a single chunk (no chunk_id available for citation)
        chunks = [{"chunk_id": "", "content": document.get("content", "")}]

    chunks_xml = "\n".join(
        [
            f"<chunk id='{chunk.get('chunk_id', '')}'>{_to_cdata(chunk.get('content', ''))}</chunk>"
            for chunk in chunks
        ]
    )

    return f"""<document>
<document_metadata>
<document_id>{doc_id}</document_id>
<document_type>{document_type}</document_type>
<title>{_to_cdata(title)}</title>
<url>{_to_cdata(url)}</url>
<metadata_json>{_to_cdata(metadata_json)}</metadata_json>
</document_metadata>
<document_content>
{chunks_xml}
</document_content>
</document>"""
def format_documents_section(
    documents: list[dict[str, Any]], section_title: str = "Source material"
) -> str:
    """Format multiple documents into a complete documents section.

    Args:
        documents: Document-grouped results; each is rendered via
            format_document_for_citation().
        section_title: Heading text placed before the <documents> block.

    Returns:
        The full section string, or "" when there are no documents.
    """
    if not documents:
        return ""

    formatted_docs = [format_document_for_citation(doc) for doc in documents]
    # chr(10) is "\n" — used because f-string expressions could not contain
    # backslashes before Python 3.12.
    return f"""{section_title}:
<documents>
{chr(10).join(formatted_docs)}
</documents>"""
def calculate_document_token_costs(
    documents: list[dict[str, Any]], model: str
) -> list[DocumentTokenInfo]:
    """Pre-calculate the token cost of each document's citation format.

    Each document is rendered to its XML citation form and measured with
    litellm's token_counter for the given model.
    """
    costs: list[DocumentTokenInfo] = []
    for position, document in enumerate(documents):
        citation_text = format_document_for_citation(document)
        # Measure the token footprint of the fully formatted document
        tokens = token_counter(
            messages=[{"role": "user", "content": citation_text}], model=model
        )
        costs.append(
            DocumentTokenInfo(
                index=position,
                document=document,
                formatted_content=citation_text,
                token_count=tokens,
            )
        )
    return costs
def find_optimal_documents_with_binary_search(
    document_tokens: list[DocumentTokenInfo], available_tokens: int
) -> list[DocumentTokenInfo]:
    """Find the longest prefix of documents whose combined token cost fits.

    Documents stay in their original order; the result is the largest
    prefix whose cumulative token count is <= available_tokens. Prefix
    sums are non-decreasing, so a single O(n) accumulation pass followed
    by an O(log n) bisect replaces the original binary search that
    re-summed the candidate prefix on every probe (O(n log n) additions).

    Args:
        document_tokens: Per-document token cost info, in priority order.
        available_tokens: Token budget available for documents.

    Returns:
        The longest affordable prefix of document_tokens (possibly empty).
    """
    from bisect import bisect_right
    from itertools import accumulate

    if not document_tokens or available_tokens <= 0:
        return []

    # prefix_sums[k-1] is the total cost of keeping the first k documents.
    prefix_sums = list(accumulate(info.token_count for info in document_tokens))
    # Number of prefix sums within budget == number of documents to keep.
    cutoff = bisect_right(prefix_sums, available_tokens)
    return document_tokens[:cutoff]
def get_model_context_window(model_name: str) -> int:
    """Look up a model's maximum input-token window via litellm.

    Reads "max_input_tokens" from litellm's model info. When the model is
    unknown (or the lookup fails for any reason), a conservative 4096-token
    default is returned instead.
    """
    fallback_tokens = 4096
    try:
        info = get_model_info(model_name)
        return info.get("max_input_tokens", fallback_tokens)
    except Exception as exc:
        print(
            f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {exc}"
        )
        return fallback_tokens
def optimize_documents_for_token_limit(
    documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str
) -> tuple[list[dict[str, Any]], bool]:
    """
    Optimize documents to fit within token limits using binary search.

    Keeps the longest prefix of `documents` whose formatted citation text
    fits in the model's context window after subtracting the token cost of
    the base messages. Earlier documents are preferred, so callers should
    pass them in priority order.

    Args:
        documents: List of documents with content and metadata
        base_messages: Base messages without documents (chat history + system + human message template)
        model_name: Model name for token counting (required)

    Returns:
        Tuple of (optimized_documents, has_documents_remaining)
    """
    if not documents:
        return [], False

    model = model_name
    context_window = get_model_context_window(model)

    # Calculate base token cost (everything that will be sent besides docs)
    base_messages_dict = convert_langchain_messages_to_dict(base_messages)
    base_tokens = token_counter(messages=base_messages_dict, model=model)
    available_tokens_for_docs = context_window - base_tokens

    print(
        f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}"
    )

    if available_tokens_for_docs <= 0:
        print("No tokens available for documents after base content and output buffer")
        return [], False

    # Calculate token costs for all documents
    document_token_info = calculate_document_token_costs(documents, model)

    # Find optimal number of documents using binary search
    optimal_doc_info = find_optimal_documents_with_binary_search(
        document_token_info, available_tokens_for_docs
    )

    # Extract the original document objects
    optimized_documents = [doc_info.document for doc_info in optimal_doc_info]
    has_documents_remaining = len(optimized_documents) > 0

    print(
        f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents"
    )

    return optimized_documents, has_documents_remaining
def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int:
    """Count the tokens the given LangChain messages consume for a model."""
    return token_counter(
        messages=convert_langchain_messages_to_dict(messages), model=model_name
    )
def langchain_chat_history_to_str(chat_history: list[BaseMessage]) -> str:
    """
    Render chat history as XML-style tagged lines (<user>/<assistant>/<system>).

    Messages of any other type are skipped.
    """
    rendered: list[str] = []
    for message in chat_history:
        if isinstance(message, HumanMessage):
            tag = "user"
        elif isinstance(message, AIMessage):
            tag = "assistant"
        elif isinstance(message, SystemMessage):
            tag = "system"
        else:
            continue
        rendered.append(f"<{tag}>{message.content}</{tag}>\n")
    return "".join(rendered)

View file

@ -5,6 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.ext.asyncio import AsyncSession
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
from app.agents.new_chat.checkpointer import (
close_checkpointer,
setup_checkpointer_tables,
)
from app.config import config
from app.db import User, create_db_and_tables, get_async_session
from app.routes import router as crud_router
@ -16,7 +20,11 @@ from app.users import SECRET, auth_backend, current_active_user, fastapi_users
async def lifespan(app: FastAPI):
# Not needed if you setup a migration system like Alembic
await create_db_and_tables()
# Setup LangGraph checkpointer tables for conversation persistence
await setup_checkpointer_tables()
yield
# Cleanup: close checkpointer connection on shutdown
await close_checkpointer()
def registration_allowed():

View file

@ -35,12 +35,6 @@ def load_global_llm_configs():
# Try main config file first
global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.yaml"
# Fall back to example file for testing
# if not global_config_file.exists():
# global_config_file = BASE_DIR / "app" / "config" / "global_llm_config.example.yaml"
# if global_config_file.exists():
# print("Info: Using global_llm_config.example.yaml (copy to global_llm_config.yaml for production)")
if not global_config_file.exists():
# No global configs available
return []

View file

@ -9,72 +9,101 @@
#
# These configurations will be available to all users as a convenient option
# Users can choose to use these global configs or add their own
#
# Structure matches NewLLMConfig:
# - LLM model configuration (provider, model_name, api_key, etc.)
# - Prompt configuration (system_instructions, citations_enabled)
global_llm_configs:
# Example: OpenAI GPT-4 Turbo
# Example: OpenAI GPT-4 Turbo with citations enabled
- id: -1
name: "Global GPT-4 Turbo"
description: "OpenAI's GPT-4 Turbo with default prompts and citations"
provider: "OPENAI"
model_name: "gpt-4-turbo-preview"
api_key: "sk-your-openai-api-key-here"
api_base: ""
language: "English"
litellm_params:
temperature: 0.7
max_tokens: 4000
# Prompt Configuration
system_instructions: "" # Empty = use default SURFSENSE_SYSTEM_INSTRUCTIONS
use_default_system_instructions: true
citations_enabled: true
# Example: Anthropic Claude 3 Opus
- id: -2
name: "Global Claude 3 Opus"
description: "Anthropic's most capable model with citations"
provider: "ANTHROPIC"
model_name: "claude-3-opus-20240229"
api_key: "sk-ant-your-anthropic-api-key-here"
api_base: ""
language: "English"
litellm_params:
temperature: 0.7
max_tokens: 4000
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Example: Fast model - GPT-3.5 Turbo
# Example: Fast model - GPT-3.5 Turbo (citations disabled for speed)
- id: -3
name: "Global GPT-3.5 Turbo"
name: "Global GPT-3.5 Turbo (Fast)"
description: "Fast responses without citations for quick queries"
provider: "OPENAI"
model_name: "gpt-3.5-turbo"
api_key: "sk-your-openai-api-key-here"
api_base: ""
language: "English"
litellm_params:
temperature: 0.5
max_tokens: 2000
system_instructions: ""
use_default_system_instructions: true
citations_enabled: false # Disabled for faster responses
# Example: Chinese LLM - DeepSeek
# Example: Chinese LLM - DeepSeek with custom instructions
- id: -4
name: "Global DeepSeek Chat"
name: "Global DeepSeek Chat (Chinese)"
description: "DeepSeek optimized for Chinese language responses"
provider: "DEEPSEEK"
model_name: "deepseek-chat"
api_key: "your-deepseek-api-key-here"
api_base: "https://api.deepseek.com/v1"
language: "Chinese"
litellm_params:
temperature: 0.7
max_tokens: 4000
# Custom system instructions for Chinese responses
system_instructions: |
<system_instruction>
You are SurfSense, a reasoning and acting AI agent designed to answer user questions using the user's personal knowledge base.
Today's date (UTC): {resolved_today}
IMPORTANT: Please respond in Chinese (简体中文) unless the user specifically requests another language.
</system_instruction>
use_default_system_instructions: false
citations_enabled: true
# Example: Groq - Fast inference
- id: -5
name: "Global Groq Llama 3"
description: "Ultra-fast Llama 3 70B via Groq"
provider: "GROQ"
model_name: "llama3-70b-8192"
api_key: "your-groq-api-key-here"
api_base: ""
language: "English"
litellm_params:
temperature: 0.7
max_tokens: 8000
system_instructions: ""
use_default_system_instructions: true
citations_enabled: true
# Notes:
# - Use negative IDs to distinguish global configs from user configs
# - Use negative IDs to distinguish global configs from user configs (NewLLMConfig in DB)
# - IDs should be unique and sequential (e.g., -1, -2, -3, etc.)
# - The 'api_key' field will not be exposed to users via API
# - Users can select these configs for their long_context, fast, or strategic LLM roles
# - system_instructions: Custom prompt or empty string to use defaults
# - use_default_system_instructions: true = use SURFSENSE_SYSTEM_INSTRUCTIONS when system_instructions is empty
# - citations_enabled: true = include citation instructions, false = include anti-citation instructions
# - All standard LiteLLM providers are supported

View file

@ -9,7 +9,6 @@ from sqlalchemy import (
ARRAY,
JSON,
TIMESTAMP,
BigInteger,
Boolean,
Column,
Enum as SQLAlchemyEnum,
@ -77,10 +76,6 @@ class SearchSourceConnectorType(str, Enum):
BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR"
class ChatType(str, Enum):
QNA = "QNA"
class LiteLLMProvider(str, Enum):
"""
Enum for LLM providers supported by LiteLLM.
@ -317,19 +312,70 @@ class BaseModel(Base):
id = Column(Integer, primary_key=True, index=True)
class Chat(BaseModel, TimestampMixin):
__tablename__ = "chats"
class NewChatMessageRole(str, Enum):
"""Role enum for new chat messages."""
type = Column(SQLAlchemyEnum(ChatType), nullable=False)
title = Column(String, nullable=False, index=True)
initial_connectors = Column(ARRAY(String), nullable=True)
messages = Column(JSON, nullable=False)
state_version = Column(BigInteger, nullable=False, default=1)
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
class NewChatThread(BaseModel, TimestampMixin):
"""
Thread model for the new chat feature using assistant-ui.
Each thread represents a conversation with message history.
LangGraph checkpointer uses thread_id for state persistence.
"""
__tablename__ = "new_chat_threads"
title = Column(String(500), nullable=False, default="New Chat", index=True)
archived = Column(Boolean, nullable=False, default=False)
updated_at = Column(
TIMESTAMP(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
index=True,
)
# Foreign keys
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="chats")
# Relationships
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
messages = relationship(
"NewChatMessage",
back_populates="thread",
order_by="NewChatMessage.created_at",
cascade="all, delete-orphan",
)
class NewChatMessage(BaseModel, TimestampMixin):
"""
Message model for the new chat feature.
Stores individual messages in assistant-ui format.
"""
__tablename__ = "new_chat_messages"
role = Column(SQLAlchemyEnum(NewChatMessageRole), nullable=False)
# Content stored as JSONB to support rich content (text, tool calls, etc.)
content = Column(JSONB, nullable=False)
# Foreign key to thread
thread_id = Column(
Integer,
ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
# Relationship
thread = relationship("NewChatThread", back_populates="messages")
class Document(BaseModel, TimestampMixin):
@ -377,15 +423,13 @@ class Chunk(BaseModel, TimestampMixin):
class Podcast(BaseModel, TimestampMixin):
"""Podcast model for storing generated podcasts."""
__tablename__ = "podcasts"
title = Column(String, nullable=False, index=True)
podcast_transcript = Column(JSON, nullable=False, default={})
file_location = Column(String(500), nullable=False, default="")
chat_id = Column(
Integer, ForeignKey("chats.id", ondelete="CASCADE"), nullable=True
) # If generated from a chat, this will be the chat id, else null ( can be from a document or a chat )
chat_state_version = Column(BigInteger, nullable=True)
title = Column(String(500), nullable=False)
podcast_transcript = Column(JSONB, nullable=True) # List of transcript entries
file_location = Column(Text, nullable=True) # Path to the audio file
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
@ -408,9 +452,10 @@ class SearchSpace(BaseModel, TimestampMixin):
# Search space-level LLM preferences (shared by all members)
# Note: These can be negative IDs for global configs (from YAML) or positive IDs for custom configs (from DB)
long_context_llm_id = Column(Integer, nullable=True)
fast_llm_id = Column(Integer, nullable=True)
strategic_llm_id = Column(Integer, nullable=True)
agent_llm_id = Column(Integer, nullable=True) # For agent/chat operations
document_summary_llm_id = Column(
Integer, nullable=True
) # For document summarization
user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
@ -423,16 +468,16 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="Document.id",
cascade="all, delete-orphan",
)
new_chat_threads = relationship(
"NewChatThread",
back_populates="search_space",
order_by="NewChatThread.updated_at.desc()",
cascade="all, delete-orphan",
)
podcasts = relationship(
"Podcast",
back_populates="search_space",
order_by="Podcast.id",
cascade="all, delete-orphan",
)
chats = relationship(
"Chat",
back_populates="search_space",
order_by="Chat.id",
order_by="Podcast.id.desc()",
cascade="all, delete-orphan",
)
logs = relationship(
@ -447,10 +492,10 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="SearchSourceConnector.id",
cascade="all, delete-orphan",
)
llm_configs = relationship(
"LLMConfig",
new_llm_configs = relationship(
"NewLLMConfig",
back_populates="search_space",
order_by="LLMConfig.id",
order_by="NewLLMConfig.id",
cascade="all, delete-orphan",
)
@ -509,10 +554,24 @@ class SearchSourceConnector(BaseModel, TimestampMixin):
)
class LLMConfig(BaseModel, TimestampMixin):
__tablename__ = "llm_configs"
class NewLLMConfig(BaseModel, TimestampMixin):
"""
New LLM configuration table that combines model settings with prompt configuration.
This table provides:
- LLM model configuration (provider, model_name, api_key, etc.)
- Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
- Citation toggle (enable/disable citation instructions)
Note: SURFSENSE_TOOLS_INSTRUCTIONS is always used and not configurable.
"""
__tablename__ = "new_llm_configs"
name = Column(String(100), nullable=False, index=True)
description = Column(String(500), nullable=True)
# === LLM Model Configuration (from original LLMConfig, excluding 'language') ===
# Provider from the enum
provider = Column(SQLAlchemyEnum(LiteLLMProvider), nullable=False)
# Custom provider name when provider is CUSTOM
@ -522,16 +581,29 @@ class LLMConfig(BaseModel, TimestampMixin):
# API Key should be encrypted before storing
api_key = Column(String, nullable=False)
api_base = Column(String(500), nullable=True)
language = Column(String(50), nullable=True, default="English")
# For any other parameters that litellm supports
litellm_params = Column(JSON, nullable=True, default={})
# === Prompt Configuration ===
# Configurable system instructions (defaults to SURFSENSE_SYSTEM_INSTRUCTIONS)
# Users can customize this from the UI
system_instructions = Column(
Text,
nullable=False,
default="", # Empty string means use default SURFSENSE_SYSTEM_INSTRUCTIONS
)
# Whether to use the default system instructions when system_instructions is empty
use_default_system_instructions = Column(Boolean, nullable=False, default=True)
# Citation toggle - when enabled, SURFSENSE_CITATION_INSTRUCTIONS is injected
# When disabled, an anti-citation prompt is injected instead
citations_enabled = Column(Boolean, nullable=False, default=True)
# === Relationships ===
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="llm_configs")
search_space = relationship("SearchSpace", back_populates="new_llm_configs")
class Log(BaseModel, TimestampMixin):

View file

@ -1,190 +0,0 @@
prompts:
# Developer-focused prompts
- key: ethereum_developer
value: "Imagine you are an experienced Ethereum developer tasked with creating a smart contract for a blockchain messenger. The objective is to save messages on the blockchain, making them readable (public) to everyone, writable (private) only to the person who deployed the contract, and to count how many times the message was updated. Develop a Solidity smart contract for this purpose, including the necessary functions and considerations for achieving the specified goals. Please provide the code and any relevant explanations to ensure a clear understanding of the implementation."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: linux_terminal
value: "I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: javascript_console
value: "I want you to act as a javascript console. I will type commands and you will reply with what the javascript console should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. do not write explanations. do not type commands unless I instruct you to do so. when i need to tell you something in english, i will do so by putting text inside curly brackets {like this}."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: fullstack_developer
value: "I want you to act as a software developer. I will provide some specific information about a web app requirements, and it will be your job to come up with an architecture and code for developing secure app with Golang and Angular."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: regex_generator
value: "I want you to act as a regex generator. Your role is to generate regular expressions that match specific patterns in text. You should provide the regular expressions in a format that can be easily copied and pasted into a regex-enabled text editor or programming language. Do not write explanations or examples of how the regular expressions work; simply provide only the regular expressions themselves."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: senior_frontend_developer
value: "I want you to act as a Senior Frontend developer. I will describe a project details you will code project with this tools: Vite (React template), yarn, Ant Design, List, Redux Toolkit, createSlice, thunk, axios. You should merge files in single index.js file and nothing else. Do not write explanations."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: code_reviewer
value: "I want you to act as a Code reviewer who is experienced developer in the given code language. I will provide you with the code block or methods or code file along with the code language name, and I would like you to review the code and share the feedback, suggestions and alternative recommended approaches. Please write explanations behind the feedback or suggestions or alternative approaches."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: machine_learning_engineer
value: "I want you to act as a machine learning engineer. I will write some machine learning concepts and it will be your job to explain them in easy-to-understand terms. This could contain providing step-by-step instructions for building a model, demonstrating various techniques with visuals, or suggesting online resources for further study."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: sql_terminal
value: "I want you to act as a SQL terminal in front of an example database. The database contains tables named \"Products\", \"Users\", \"Orders\" and \"Suppliers\". I will type queries and you will reply with what the terminal would show. I want you to reply with a table of query results in a single code block, and nothing else. Do not write explanations. Do not type commands unless I instruct you to do so. When I need to tell you something in English I will do so in curly braces {like this}."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: python_interpreter
value: "Act as a Python interpreter. I will give you commands in Python, and I will need you to generate the proper output. Only say the output. But if there is none, say nothing, and don't give me an explanation. If I need to say something, I will do so through comments."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: devops_engineer
value: "You are a Senior DevOps engineer working at a Big Company. Your role is to provide scalable, efficient, and automated solutions for software deployment, infrastructure management, and CI/CD pipelines. Suggest the best DevOps practices, including infrastructure setup, deployment strategies, automation tools, and cost-effective scaling solutions."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
- key: cyber_security_specialist
value: "I want you to act as a cyber security specialist. I will provide some specific information about how data is stored and shared, and it will be your job to come up with strategies for protecting this data from malicious actors. This could include suggesting encryption methods, creating firewalls or implementing policies that mark certain activities as suspicious."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "developer"
# General productivity prompts
- key: english_translator
value: "I want you to act as an English translator, spelling corrector and improver. I will speak to you in any language and you will detect the language, translate it and answer in the corrected and improved version of my text, in English. I want you to replace my simplified A0-level words and sentences with more beautiful and elegant, upper level English words and sentences. Keep the meaning same, but make them more literary. I want you to only reply the correction, the improvements and nothing else, do not write explanations."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "general"
- key: proofreader
value: "I want you to act as a proofreader. I will provide you texts and I would like you to review them for any spelling, grammar, or punctuation errors. Once you have finished reviewing the text, provide me with any necessary corrections or suggestions for improving the text."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "general"
- key: note_taking_assistant
value: "I want you to act as a note-taking assistant for a lecture. Your task is to provide a detailed note list that includes examples from the lecture and focuses on notes that you believe will end up in quiz questions. Additionally, please make a separate list for notes that have numbers and data in them and another separated list for the examples that included in this lecture. The notes should be concise and easy to read."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "general"
- key: essay_writer
value: "I want you to act as an essay writer. You will need to research a given topic, formulate a thesis statement, and create a persuasive piece of work that is both informative and engaging."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "general"
- key: career_counselor
value: "I want you to act as a career counselor. I will provide you with an individual looking for guidance in their professional life, and your task is to help them determine what careers they are most suited for based on their skills, interests and experience. You should also conduct research into the various options available, explain the job market trends in different industries and advice on which qualifications would be beneficial for pursuing particular fields."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "general"
- key: life_coach
value: "I want you to act as a life coach. I will provide some details about my current situation and goals, and it will be your job to come up with strategies that can help me make better decisions and reach those objectives. This could involve offering advice on various topics, such as creating plans for achieving success or dealing with difficult emotions."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "general"
- key: motivational_coach
value: "I want you to act as a motivational coach. I will provide you with some information about someone's goals and challenges, and it will be your job to come up with strategies that can help this person achieve their goals. This could involve providing positive affirmations, giving helpful advice or suggesting activities they can do to reach their end goal."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "general"
- key: travel_guide
value: "I want you to act as a travel guide. I will write you my location and you will suggest a place to visit near my location. In some cases, I will also give you the type of places I will visit. You will also suggest me places of similar type that are close to my first location."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "general"
# Creative prompts
- key: storyteller
value: "I want you to act as a storyteller. You will come up with entertaining stories that are engaging, imaginative and captivating for the audience. It can be fairy tales, educational stories or any other type of stories which has the potential to capture people's attention and imagination. Depending on the target audience, you may choose specific themes or topics for your storytelling session e.g., if it's children then you can talk about animals; If it's adults then history-based tales might engage them better etc."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "creative"
- key: screenwriter
value: "I want you to act as a screenwriter. You will develop an engaging and creative script for either a feature length film, or a Web Series that can captivate its viewers. Start with coming up with interesting characters, the setting of the story, dialogues between the characters etc. Once your character development is complete - create an exciting storyline filled with twists and turns that keeps the viewers in suspense until the end."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "creative"
- key: novelist
value: "I want you to act as a novelist. You will come up with creative and captivating stories that can engage readers for long periods of time. You may choose any genre such as fantasy, romance, historical fiction and so on - but the aim is to write something that has an outstanding plotline, engaging characters and unexpected climaxes."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "creative"
- key: poet
value: "I want you to act as a poet. You will create poems that evoke emotions and have the power to stir people's soul. Write on any topic or theme but make sure your words convey the feeling you are trying to express in beautiful yet meaningful ways. You can also come up with short verses that are still powerful enough to leave an imprint in readers' minds."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "creative"
- key: rapper
value: "I want you to act as a rapper. You will come up with powerful and meaningful lyrics, beats and rhythm that can 'wow' the audience. Your lyrics should have an intriguing meaning and message which people can relate to. When it comes to choosing your beat, make sure it is catchy yet relevant to your words, so that when combined they make an explosion of sound every time!"
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "creative"
- key: composer
value: "I want you to act as a composer. I will provide the lyrics to a song and you will create music for it. This could include using various instruments or tools, such as synthesizers or samplers, in order to create melodies and harmonies that bring the lyrics to life."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "creative"
# Educational prompts
- key: math_teacher
value: "I want you to act as a math teacher. I will provide some mathematical equations or concepts, and it will be your job to explain them in easy-to-understand terms. This could include providing step-by-step instructions for solving a problem, demonstrating various techniques with visuals or suggesting online resources for further study."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "educational"
- key: philosophy_teacher
value: "I want you to act as a philosophy teacher. I will provide some topics related to the study of philosophy, and it will be your job to explain these concepts in an easy-to-understand manner. This could include providing examples, posing questions or breaking down complex ideas into smaller pieces that are easier to comprehend."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "educational"
- key: historian
value: "I want you to act as a historian. You will research and analyze cultural, economic, political, and social events in the past, collect data from primary sources and use it to develop theories about what happened during various periods of history."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "educational"
- key: debater
value: "I want you to act as a debater. I will provide you with some topics related to current events and your task is to research both sides of the debates, present valid arguments for each side, refute opposing points of view, and draw persuasive conclusions based on evidence. Your goal is to help people come away from the discussion with increased knowledge and insight into the topic at hand."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "educational"
- key: explainer_with_analogies
value: "I want you to act as an explainer who uses analogies to clarify complex topics. When I give you a subject (technical, philosophical or scientific), you'll follow this structure: 1. Ask me 1-2 quick questions to assess my current level of understanding. 2. Based on my answer, create three analogies to explain the topic: one that a 10-year-old would understand, one for a high-school student, and one for a college-level person. 3. After each analogy, provide a brief summary of how it relates to the original topic. 4. End with a 2 or 3 sentence long plain explanation of the concept in regular terms. Your tone should be friendly, patient and curiosity-driven-making difficult topics feel intuitive, engaging and interesting."
author: "awesome-chatgpt-prompts"
link: "https://github.com/f/awesome-chatgpt-prompts"
category: "educational"

View file

@ -3,7 +3,6 @@ from fastapi import APIRouter
from .airtable_add_connector_route import (
router as airtable_add_connector_router,
)
from .chats_routes import router as chats_router
from .documents_routes import router as documents_router
from .editor_routes import router as editor_router
from .google_calendar_add_connector_route import (
@ -12,9 +11,10 @@ from .google_calendar_add_connector_route import (
from .google_gmail_add_connector_route import (
router as google_gmail_add_connector_router,
)
from .llm_config_routes import router as llm_config_router
from .logs_routes import router as logs_router
from .luma_add_connector_route import router as luma_add_connector_router
from .new_chat_routes import router as new_chat_router
from .new_llm_config_routes import router as new_llm_config_router
from .notes_routes import router as notes_router
from .podcasts_routes import router as podcasts_router
from .rbac_routes import router as rbac_router
@ -28,12 +28,12 @@ router.include_router(rbac_router) # RBAC routes for roles, members, invites
router.include_router(editor_router)
router.include_router(documents_router)
router.include_router(notes_router)
router.include_router(podcasts_router)
router.include_router(chats_router)
router.include_router(new_chat_router) # Chat with assistant-ui persistence
router.include_router(podcasts_router) # Podcast task status and audio
router.include_router(search_source_connectors_router)
router.include_router(google_calendar_add_connector_router)
router.include_router(google_gmail_add_connector_router)
router.include_router(airtable_add_connector_router)
router.include_router(luma_add_connector_router)
router.include_router(llm_config_router)
router.include_router(new_llm_config_router) # LLM configs with prompt configuration
router.include_router(logs_router)

View file

@ -1,616 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessage, HumanMessage
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.db import (
Chat,
Permission,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import (
AISDKChatRequest,
ChatCreate,
ChatRead,
ChatReadWithoutMessages,
ChatUpdate,
NewChatRequest,
)
from app.services.new_streaming_service import VercelStreamingService
from app.tasks.chat.stream_connector_search_results import (
stream_connector_search_results,
)
from app.tasks.chat.stream_new_chat import stream_new_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.validators import (
validate_connectors,
validate_document_ids,
validate_messages,
validate_research_mode,
validate_search_space_id,
validate_top_k,
)
# Module-level router collecting all chat endpoints; included by the package aggregator.
router = APIRouter()
@router.post("/chat")
async def handle_chat_data(
request: AISDKChatRequest,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Stream connector-search results for a legacy (AI SDK v1) chat request.

Validates the incoming message list and request data, enforces
CHATS_CREATE permission on the target search space, resolves a response
language from the search space's LLM preferences, rebuilds the chat
history as LangChain messages, and returns a streaming response tagged
with the `x-vercel-ai-data-stream: v1` header.

Raises:
    HTTPException 400: last message is not from the user.
    HTTPException 403: permission check failed (any HTTPException from
        the guarded block is collapsed into this generic 403).
"""
# Validate and sanitize all input data
messages = validate_messages(request.messages)
if messages[-1]["role"] != "user":
raise HTTPException(
status_code=400, detail="Last message must be a user message"
)
# The final user message is the query; everything before it is history.
user_query = messages[-1]["content"]
# Extract and validate data from request
request_data = request.data or {}
search_space_id = validate_search_space_id(request_data.get("search_space_id"))
research_mode = validate_research_mode(request_data.get("research_mode"))
selected_connectors = validate_connectors(request_data.get("selected_connectors"))
document_ids_to_add_in_context = validate_document_ids(
request_data.get("document_ids_to_add_in_context")
)
top_k = validate_top_k(request_data.get("top_k"))
# print("RESQUEST DATA:", request_data)
# print("SELECTED CONNECTORS:", selected_connectors)
# Check if the user has chat access to the search space
try:
await check_permission(
session,
user,
search_space_id,
Permission.CHATS_CREATE.value,
"You don't have permission to use chat in this search space",
)
# Get search space with LLM configs (preferences are now stored at search space level)
search_space_result = await session.execute(
select(SearchSpace)
.options(selectinload(SearchSpace.llm_configs))
.filter(SearchSpace.id == search_space_id)
)
search_space = search_space_result.scalars().first()
language = None
llm_configs = []  # Initialize to empty list
if search_space and search_space.llm_configs:
llm_configs = search_space.llm_configs
# Get language from configured LLM preferences
# LLM preferences are now stored on the SearchSpace model
from app.config import config as app_config
# Preference order: fast, long-context, then strategic LLM.
for llm_id in [
search_space.fast_llm_id,
search_space.long_context_llm_id,
search_space.strategic_llm_id,
]:
if llm_id is not None:
# Check if it's a global config (negative ID)
if llm_id < 0:
# Look in global configs
for global_cfg in app_config.GLOBAL_LLM_CONFIGS:
if global_cfg.get("id") == llm_id:
language = global_cfg.get("language")
if language:
break
else:
# Look in custom configs
for llm_config in llm_configs:
if llm_config.id == llm_id and getattr(
llm_config, "language", None
):
language = llm_config.language
break
if language:
break
# Fallback: first custom config's language, if any.
if not language and llm_configs:
first_llm_config = llm_configs[0]
language = getattr(first_llm_config, "language", None)
except HTTPException:
# NOTE(review): this collapses every HTTPException raised above
# (including specific permission messages) into a generic 403.
raise HTTPException(
status_code=403, detail="You don't have access to this search space"
) from None
# Rebuild prior turns as LangChain messages for the agent.
langchain_chat_history = []
for message in messages[:-1]:
if message["role"] == "user":
langchain_chat_history.append(HumanMessage(content=message["content"]))
elif message["role"] == "assistant":
langchain_chat_history.append(AIMessage(content=message["content"]))
response = StreamingResponse(
stream_connector_search_results(
user_query,
user.id,
search_space_id,
session,
research_mode,
selected_connectors,
langchain_chat_history,
document_ids_to_add_in_context,
language,
top_k,
)
)
# Marker header consumed by the Vercel AI SDK data-stream client.
response.headers["x-vercel-ai-data-stream"] = "v1"
return response
@router.post("/new_chat")
async def handle_new_chat(
    request: NewChatRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Stream a SurfSense deep-agent reply using the Vercel AI SDK SSE protocol.

    Validates the query, enforces CHATS_CREATE permission on the search
    space, resolves the LLM config from search-space preferences (falling
    back to the default global config, id -1), and streams the agent
    response as `text/event-stream`.

    Raises:
        HTTPException 400: the user query is empty or whitespace-only.
        HTTPException 403: the permission check fails.
    """
    # Reject blank queries before doing any work.
    if not request.user_query or not request.user_query.strip():
        raise HTTPException(status_code=400, detail="User query cannot be empty")

    # Permission gate; any failure becomes a generic 403.
    try:
        await check_permission(
            session,
            user,
            request.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to use chat in this search space",
        )
    except HTTPException:
        raise HTTPException(
            status_code=403, detail="You don't have access to this search space"
        ) from None

    # Start from the default global config (-1); prefer the search space's
    # strategic LLM, then its fast LLM. Lookup errors keep the default
    # (best-effort by design).
    llm_config_id = -1
    try:
        lookup = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        space = lookup.scalars().first()
        if space:
            if space.strategic_llm_id is not None:
                llm_config_id = space.strategic_llm_id
            elif space.fast_llm_id is not None:
                llm_config_id = space.fast_llm_id
    except Exception:
        pass

    # chat_id doubles as LangGraph's thread_id, so chat history is
    # managed automatically on the agent side.
    streaming_response = StreamingResponse(
        stream_new_chat(
            user_query=request.user_query.strip(),
            user_id=user.id,
            search_space_id=request.search_space_id,
            chat_id=request.chat_id,
            session=session,
            llm_config_id=llm_config_id,
        ),
        media_type="text/event-stream",
    )
    # Attach the headers the Vercel AI SDK expects on the SSE response.
    for header_name, header_value in VercelStreamingService.get_response_headers().items():
        streaming_response.headers[header_name] = header_value
    return streaming_response
@router.post("/chats", response_model=ChatRead)
async def create_chat(
    chat: ChatCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Persist a new chat in its search space.

    Requires CHATS_CREATE permission. Database failures are translated to
    HTTP errors: constraint violations -> 400, operational faults -> 503,
    anything else -> 500 (all after rolling back the session).
    """
    try:
        await check_permission(
            session,
            user,
            chat.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to create chats in this search space",
        )
        new_chat = Chat(**chat.model_dump())
        session.add(new_chat)
        await session.commit()
        # Refresh so server-generated fields (id, timestamps) are populated.
        await session.refresh(new_chat)
        return new_chat
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while creating the chat.",
        ) from None
@router.get("/chats", response_model=list[ChatReadWithoutMessages])
async def read_chats(
    skip: int = 0,
    limit: int = 100,
    search_space_id: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List chats the user can see, newest first, without their messages.

    With `search_space_id` set, requires CHATS_READ on that space; without
    it, returns chats from every search space the user is a member of.
    """
    # Pagination guards.
    if skip < 0:
        raise HTTPException(
            status_code=400, detail="skip must be a non-negative integer"
        )
    if limit <= 0 or limit > 1000:  # Reasonable upper limit
        raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
    if search_space_id is not None and search_space_id <= 0:
        raise HTTPException(
            status_code=400, detail="search_space_id must be a positive integer"
        )
    # Projection deliberately omits the messages payload.
    columns = (
        Chat.id,
        Chat.type,
        Chat.title,
        Chat.initial_connectors,
        Chat.search_space_id,
        Chat.created_at,
        Chat.state_version,
    )
    try:
        if search_space_id is not None:
            await check_permission(
                session,
                user,
                search_space_id,
                Permission.CHATS_READ.value,
                "You don't have permission to read chats in this search space",
            )
            query = (
                select(*columns)
                .filter(Chat.search_space_id == search_space_id)
                .order_by(Chat.created_at.desc())
            )
        else:
            # Membership join restricts results to the user's search spaces.
            query = (
                select(*columns)
                .join(SearchSpace)
                .join(SearchSpaceMembership)
                .filter(SearchSpaceMembership.user_id == user.id)
                .order_by(Chat.created_at.desc())
            )
        result = await session.execute(query.offset(skip).limit(limit))
        return result.all()
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        raise HTTPException(
            status_code=500, detail="An unexpected error occurred while fetching chats."
        ) from None
@router.get("/chats/search", response_model=list[ChatReadWithoutMessages])
async def search_chats(
    title: str,
    skip: int = 0,
    limit: int = 100,
    search_space_id: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Search chats whose title contains a substring (case-insensitive).

    Args:
        title: Substring matched against chat titles via ILIKE. Required.
        skip: Items to skip from the start. Default 0.
        limit: Maximum items to return (1-1000). Default 100.
        search_space_id: Optional filter to one search space; requires
            CHATS_READ there. Otherwise results cover all spaces the
            user belongs to.

    Returns:
        Matching chats, most recent first, without message payloads.
    """
    # Pagination guards.
    if skip < 0:
        raise HTTPException(
            status_code=400, detail="skip must be a non-negative integer"
        )
    if limit <= 0 or limit > 1000:
        raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
    if search_space_id is not None and search_space_id <= 0:
        raise HTTPException(
            status_code=400, detail="search_space_id must be a positive integer"
        )
    # Projection deliberately omits the messages payload.
    columns = (
        Chat.id,
        Chat.type,
        Chat.title,
        Chat.initial_connectors,
        Chat.search_space_id,
        Chat.created_at,
        Chat.state_version,
    )
    try:
        if search_space_id is not None:
            await check_permission(
                session,
                user,
                search_space_id,
                Permission.CHATS_READ.value,
                "You don't have permission to read chats in this search space",
            )
            query = (
                select(*columns)
                .filter(Chat.search_space_id == search_space_id)
                .order_by(Chat.created_at.desc())
            )
        else:
            # Membership join restricts results to the user's search spaces.
            query = (
                select(*columns)
                .join(SearchSpace)
                .join(SearchSpaceMembership)
                .filter(SearchSpaceMembership.user_id == user.id)
                .order_by(Chat.created_at.desc())
            )
        # Case-insensitive title match.
        query = query.filter(Chat.title.ilike(f"%{title}%"))
        result = await session.execute(query.offset(skip).limit(limit))
        return result.all()
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while searching chats.",
        ) from None
@router.get("/chats/{chat_id}", response_model=ChatRead)
async def read_chat(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Fetch one chat by id.

    Requires CHATS_READ permission on the chat's own search space.
    Raises 404 when no chat with the given id exists.
    """
    try:
        lookup = await session.execute(select(Chat).filter(Chat.id == chat_id))
        chat = lookup.scalars().first()
        if chat is None:
            raise HTTPException(
                status_code=404,
                detail="Chat not found",
            )
        # Permission is evaluated against the search space the chat lives in.
        await check_permission(
            session,
            user,
            chat.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        return chat
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while fetching the chat.",
        ) from None
@router.put("/chats/{chat_id}", response_model=ChatRead)
async def update_chat(
    chat_id: int,
    chat_update: ChatUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Apply a partial update to a chat.

    Requires CHATS_UPDATE permission on the chat's search space. Only
    fields explicitly set on the payload are applied; updating `messages`
    also bumps `state_version` to the new message count.
    """
    try:
        lookup = await session.execute(select(Chat).filter(Chat.id == chat_id))
        target = lookup.scalars().first()
        if target is None:
            raise HTTPException(status_code=404, detail="Chat not found")
        await check_permission(
            session,
            user,
            target.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        changes = chat_update.model_dump(exclude_unset=True)
        for field, new_value in changes.items():
            # state_version mirrors the length of the message list.
            if field == "messages":
                target.state_version = len(changes["messages"])
            setattr(target, field, new_value)
        await session.commit()
        await session.refresh(target)
        return target
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while updating the chat.",
        ) from None
@router.delete("/chats/{chat_id}", response_model=dict)
async def delete_chat(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete a chat by id.

    Looks the chat up, verifies the caller holds CHATS_DELETE for the
    owning search space, then removes the row and commits.

    Raises:
        HTTPException: 404 when no such chat exists, 400 when dependent
            rows block the delete, 503 on transient database failures,
            500 otherwise.
    """
    try:
        lookup = await session.execute(select(Chat).filter(Chat.id == chat_id))
        chat_row = lookup.scalars().first()
        if chat_row is None:
            raise HTTPException(status_code=404, detail="Chat not found")
        # The caller must be allowed to delete chats in this search space.
        await check_permission(
            session,
            user,
            chat_row.search_space_id,
            Permission.CHATS_DELETE.value,
            "You don't have permission to delete chats in this search space",
        )
        await session.delete(chat_row)
        await session.commit()
        return {"message": "Chat deleted successfully"}
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400, detail="Cannot delete chat due to existing dependencies."
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while deleting the chat.",
        ) from None

View file

@ -1,576 +0,0 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import (
LLMConfig,
Permission,
SearchSpace,
User,
get_async_session,
)
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from app.services.llm_service import validate_llm_config
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
logger = logging.getLogger(__name__)
class LLMPreferencesUpdate(BaseModel):
    """Schema for updating search space LLM preferences.

    Each field is an LLM config id; negative ids refer to global
    (admin-provisioned) configs, positive ids to search-space configs.
    Omitted fields are left unchanged (exclude_unset semantics).
    """

    # Config used for long-context tasks (None clears the preference)
    long_context_llm_id: int | None = None
    # Config used for fast / low-latency tasks
    fast_llm_id: int | None = None
    # Config used for strategic / high-quality reasoning tasks
    strategic_llm_id: int | None = None
class LLMPreferencesRead(BaseModel):
    """Schema for reading search space LLM preferences.

    Contains both the raw config ids and the resolved config objects
    (None when no preference is set or the id cannot be resolved).
    Global configs are returned as LLMConfigRead-compatible dicts with
    the API key masked.
    """

    long_context_llm_id: int | None = None
    fast_llm_id: int | None = None
    strategic_llm_id: int | None = None
    # Resolved configs corresponding to the ids above
    long_context_llm: LLMConfigRead | None = None
    fast_llm: LLMConfigRead | None = None
    strategic_llm: LLMConfigRead | None = None
class GlobalLLMConfigRead(BaseModel):
    """Schema for reading global LLM configs (without API key).

    Mirrors the entries of config.GLOBAL_LLM_CONFIGS as exposed by
    /global-llm-configs; the api_key field is intentionally absent so
    admin credentials are never serialized to clients.
    """

    # Global config ids are negative by convention (see preference routes)
    id: int
    name: str
    provider: str
    custom_provider: str | None = None
    model_name: str
    api_base: str | None = None
    language: str | None = None
    # Extra parameters forwarded to litellm as-is
    litellm_params: dict | None = None
    # Always True for this schema; distinguishes from per-space configs
    is_global: bool = True
# Global LLM Config endpoints
@router.get("/global-llm-configs", response_model=list[GlobalLLMConfigRead])
async def get_global_llm_configs(
    user: User = Depends(current_active_user),
):
    """
    List the system-wide LLM configurations.

    These entries come from config.GLOBAL_LLM_CONFIGS (set up by the
    administrator) and are visible to every authenticated user. Each
    entry is reduced to a safe projection — the API key is never
    included in the response.
    """
    try:
        # Project each raw config onto the whitelisted, key-free shape.
        return [
            {
                "id": cfg.get("id"),
                "name": cfg.get("name"),
                "provider": cfg.get("provider"),
                "custom_provider": cfg.get("custom_provider"),
                "model_name": cfg.get("model_name"),
                "api_base": cfg.get("api_base"),
                "language": cfg.get("language"),
                "litellm_params": cfg.get("litellm_params", {}),
                "is_global": True,
            }
            for cfg in config.GLOBAL_LLM_CONFIGS
        ]
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch global LLM configs: {e!s}"
        ) from e
@router.post("/llm-configs", response_model=LLMConfigRead)
async def create_llm_config(
    llm_config: LLMConfigCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a new LLM configuration for a search space.

    The configuration is validated with a live test call before being
    persisted, so invalid credentials/models are rejected with a 400
    rather than stored.

    Requires LLM_CONFIGS_CREATE permission.

    Raises:
        HTTPException: 403 from the permission check, 400 when the
            test call fails validation, 500 on any other failure.
    """
    try:
        # Verify user has permission to create LLM configs
        await check_permission(
            session,
            user,
            llm_config.search_space_id,
            Permission.LLM_CONFIGS_CREATE.value,
            "You don't have permission to create LLM configurations in this search space",
        )
        # Validate the LLM configuration by making a test API call
        is_valid, error_message = await validate_llm_config(
            provider=llm_config.provider.value,
            model_name=llm_config.model_name,
            api_key=llm_config.api_key,
            api_base=llm_config.api_base,
            custom_provider=llm_config.custom_provider,
            litellm_params=llm_config.litellm_params,
        )
        if not is_valid:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid LLM configuration: {error_message}",
            )
        # Persist only after validation succeeded.
        db_llm_config = LLMConfig(**llm_config.model_dump())
        session.add(db_llm_config)
        await session.commit()
        await session.refresh(db_llm_config)
        return db_llm_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to create LLM configuration: {e!s}"
        ) from e
@router.get("/llm-configs", response_model=list[LLMConfigRead])
async def read_llm_configs(
    search_space_id: int,
    skip: int = 0,
    limit: int = 200,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get all LLM configurations for a search space, paginated.

    Args:
        search_space_id: Search space whose configs are listed.
        skip: Number of rows to skip (pagination offset).
        limit: Maximum rows to return (default 200).

    Requires LLM_CONFIGS_READ permission.
    """
    try:
        # Verify user has permission to read LLM configs
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.LLM_CONFIGS_READ.value,
            "You don't have permission to view LLM configurations in this search space",
        )
        result = await session.execute(
            select(LLMConfig)
            .filter(LLMConfig.search_space_id == search_space_id)
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
        ) from e
@router.get("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def read_llm_config(
    llm_config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get a specific LLM configuration by ID.

    The config is fetched first so its search_space_id can drive the
    permission check.

    Requires LLM_CONFIGS_READ permission.

    Raises:
        HTTPException: 404 when the config does not exist, 403 from the
            permission check, 500 on any other failure.
    """
    try:
        # Get the LLM config
        result = await session.execute(
            select(LLMConfig).filter(LLMConfig.id == llm_config_id)
        )
        llm_config = result.scalars().first()
        if not llm_config:
            raise HTTPException(status_code=404, detail="LLM configuration not found")
        # Verify user has permission to read LLM configs
        await check_permission(
            session,
            user,
            llm_config.search_space_id,
            Permission.LLM_CONFIGS_READ.value,
            "You don't have permission to view LLM configurations in this search space",
        )
        return llm_config
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch LLM configuration: {e!s}"
        ) from e
@router.put("/llm-configs/{llm_config_id}", response_model=LLMConfigRead)
async def update_llm_config(
    llm_config_id: int,
    llm_config_update: LLMConfigUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update an existing LLM configuration.

    The merged (existing + patched) configuration is validated with a
    live test call before any field is written, so a failing update
    leaves the stored config untouched.

    Requires LLM_CONFIGS_UPDATE permission.

    Raises:
        HTTPException: 404 when the config does not exist, 400 when the
            merged config fails validation, 500 on any other failure.
    """
    try:
        # Get the LLM config
        result = await session.execute(
            select(LLMConfig).filter(LLMConfig.id == llm_config_id)
        )
        db_llm_config = result.scalars().first()
        if not db_llm_config:
            raise HTTPException(status_code=404, detail="LLM configuration not found")
        # Verify user has permission to update LLM configs
        await check_permission(
            session,
            user,
            db_llm_config.search_space_id,
            Permission.LLM_CONFIGS_UPDATE.value,
            "You don't have permission to update LLM configurations in this search space",
        )
        update_data = llm_config_update.model_dump(exclude_unset=True)
        # model_dump() may hand back the provider as an enum member while the
        # stored fallback is passed as a plain string; normalize so
        # validate_llm_config always receives the string value, matching the
        # create path (which passes llm_config.provider.value).
        raw_provider = update_data.get("provider", db_llm_config.provider)
        provider_value = getattr(raw_provider, "value", raw_provider)
        # Apply updates to a temporary copy for validation
        temp_config = {
            "provider": provider_value,
            "model_name": update_data.get("model_name", db_llm_config.model_name),
            "api_key": update_data.get("api_key", db_llm_config.api_key),
            "api_base": update_data.get("api_base", db_llm_config.api_base),
            "custom_provider": update_data.get(
                "custom_provider", db_llm_config.custom_provider
            ),
            "litellm_params": update_data.get(
                "litellm_params", db_llm_config.litellm_params
            ),
        }
        # Validate the updated configuration
        is_valid, error_message = await validate_llm_config(
            provider=temp_config["provider"],
            model_name=temp_config["model_name"],
            api_key=temp_config["api_key"],
            api_base=temp_config["api_base"],
            custom_provider=temp_config["custom_provider"],
            litellm_params=temp_config["litellm_params"],
        )
        if not is_valid:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid LLM configuration: {error_message}",
            )
        # Apply updates to the database object only after validation passed.
        for key, value in update_data.items():
            setattr(db_llm_config, key, value)
        await session.commit()
        await session.refresh(db_llm_config)
        return db_llm_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to update LLM configuration: {e!s}"
        ) from e
@router.delete("/llm-configs/{llm_config_id}", response_model=dict)
async def delete_llm_config(
    llm_config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete an LLM configuration.

    Requires LLM_CONFIGS_DELETE permission.

    Raises:
        HTTPException: 404 when the config does not exist, 403 from the
            permission check, 500 on any other failure.
    """
    try:
        # Get the LLM config
        result = await session.execute(
            select(LLMConfig).filter(LLMConfig.id == llm_config_id)
        )
        db_llm_config = result.scalars().first()
        if not db_llm_config:
            raise HTTPException(status_code=404, detail="LLM configuration not found")
        # Verify user has permission to delete LLM configs
        await check_permission(
            session,
            user,
            db_llm_config.search_space_id,
            Permission.LLM_CONFIGS_DELETE.value,
            "You don't have permission to delete LLM configurations in this search space",
        )
        await session.delete(db_llm_config)
        await session.commit()
        return {"message": "LLM configuration deleted successfully"}
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to delete LLM configuration: {e!s}"
        ) from e
# Search Space LLM Preferences endpoints
@router.get(
    "/search-spaces/{search_space_id}/llm-preferences",
    response_model=LLMPreferencesRead,
)
async def get_llm_preferences(
    search_space_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get the LLM preferences for a specific search space.

    LLM preferences are shared by all members of the search space.
    Each preferred id is resolved to either a DB-backed per-space config
    (positive id) or a masked global config dict (negative id).

    Requires LLM_CONFIGS_READ permission.
    """
    try:
        # Verify user has permission to read LLM configs
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.LLM_CONFIGS_READ.value,
            "You don't have permission to view LLM preferences in this search space",
        )
        # Get the search space
        result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == search_space_id)
        )
        search_space = result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")
        # Helper function to get config (global or custom).
        # Returns None, an ORM LLMConfig, or an LLMConfigRead-compatible dict.
        async def get_config_for_id(config_id):
            if config_id is None:
                return None
            # Check if it's a global config (negative ID)
            if config_id < 0:
                for cfg in config.GLOBAL_LLM_CONFIGS:
                    if cfg.get("id") == config_id:
                        # Return as LLMConfigRead-compatible dict
                        return {
                            "id": cfg.get("id"),
                            "name": cfg.get("name"),
                            "provider": cfg.get("provider"),
                            "custom_provider": cfg.get("custom_provider"),
                            "model_name": cfg.get("model_name"),
                            "api_key": "***GLOBAL***",  # Don't expose the actual key
                            "api_base": cfg.get("api_base"),
                            "language": cfg.get("language"),
                            "litellm_params": cfg.get("litellm_params"),
                            "created_at": None,
                            "search_space_id": search_space_id,
                        }
                # Unknown global id resolves to no config rather than an error
                return None
            # It's a custom config, fetch from database
            result = await session.execute(
                select(LLMConfig).filter(LLMConfig.id == config_id)
            )
            return result.scalars().first()
        # Get the configs (from DB for custom, or constructed for global)
        long_context_llm = await get_config_for_id(search_space.long_context_llm_id)
        fast_llm = await get_config_for_id(search_space.fast_llm_id)
        strategic_llm = await get_config_for_id(search_space.strategic_llm_id)
        return {
            "long_context_llm_id": search_space.long_context_llm_id,
            "fast_llm_id": search_space.fast_llm_id,
            "strategic_llm_id": search_space.strategic_llm_id,
            "long_context_llm": long_context_llm,
            "fast_llm": fast_llm,
            "strategic_llm": strategic_llm,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
        ) from e
@router.put(
    "/search-spaces/{search_space_id}/llm-preferences",
    response_model=LLMPreferencesRead,
)
async def update_llm_preferences(
    search_space_id: int,
    preferences: LLMPreferencesUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update the LLM preferences for a specific search space.

    LLM preferences are shared by all members of the search space.
    Every id in the payload is validated first (global ids must exist in
    config.GLOBAL_LLM_CONFIGS; positive ids must belong to this search
    space) before any preference is written. Mixed explicit languages
    across the selected configs only produce a warning, not an error.

    Requires SETTINGS_UPDATE permission (only users with settings access can change).
    """
    try:
        # Verify user has permission to update settings (not just LLM configs)
        # This ensures only users with settings access can change shared LLM preferences
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.SETTINGS_UPDATE.value,
            "You don't have permission to update LLM preferences in this search space",
        )
        # Get the search space
        result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == search_space_id)
        )
        search_space = result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")
        # Validate that all provided LLM config IDs belong to the search space
        update_data = preferences.model_dump(exclude_unset=True)
        # Store language from configs to validate consistency
        languages = set()
        for _key, llm_config_id in update_data.items():
            if llm_config_id is not None:
                # Check if this is a global config (negative ID)
                if llm_config_id < 0:
                    # Validate global config exists
                    global_config = None
                    for cfg in config.GLOBAL_LLM_CONFIGS:
                        if cfg.get("id") == llm_config_id:
                            global_config = cfg
                            break
                    if not global_config:
                        raise HTTPException(
                            status_code=404,
                            detail=f"Global LLM configuration {llm_config_id} not found",
                        )
                    # Collect language for consistency check (if explicitly set)
                    lang = global_config.get("language")
                    if lang and lang.strip():  # Only add non-empty languages
                        languages.add(lang.strip())
                else:
                    # Verify the LLM config belongs to the search space (custom config)
                    result = await session.execute(
                        select(LLMConfig).filter(
                            LLMConfig.id == llm_config_id,
                            LLMConfig.search_space_id == search_space_id,
                        )
                    )
                    llm_config = result.scalars().first()
                    if not llm_config:
                        raise HTTPException(
                            status_code=404,
                            detail=f"LLM configuration {llm_config_id} not found in this search space",
                        )
                    # Collect language for consistency check (if explicitly set)
                    if llm_config.language and llm_config.language.strip():
                        languages.add(llm_config.language.strip())
        # Language consistency check - only warn if there are multiple explicit languages
        # Allow mixing configs with and without language settings
        if len(languages) > 1:
            # Log warning but allow the operation
            logger.warning(
                f"Multiple languages detected in LLM selection for search_space {search_space_id}: {languages}. "
                "This may affect response quality."
            )
        # Update search space LLM preferences
        for key, value in update_data.items():
            setattr(search_space, key, value)
        await session.commit()
        await session.refresh(search_space)
        # Helper function to get config (global or custom).
        # Returns None, an ORM LLMConfig, or an LLMConfigRead-compatible dict.
        async def get_config_for_id(config_id):
            if config_id is None:
                return None
            # Check if it's a global config (negative ID)
            if config_id < 0:
                for cfg in config.GLOBAL_LLM_CONFIGS:
                    if cfg.get("id") == config_id:
                        # Return as LLMConfigRead-compatible dict
                        return {
                            "id": cfg.get("id"),
                            "name": cfg.get("name"),
                            "provider": cfg.get("provider"),
                            "custom_provider": cfg.get("custom_provider"),
                            "model_name": cfg.get("model_name"),
                            "api_key": "***GLOBAL***",  # Don't expose the actual key
                            "api_base": cfg.get("api_base"),
                            "language": cfg.get("language"),
                            "litellm_params": cfg.get("litellm_params"),
                            "created_at": None,
                            "search_space_id": search_space_id,
                        }
                return None
            # It's a custom config, fetch from database
            result = await session.execute(
                select(LLMConfig).filter(LLMConfig.id == config_id)
            )
            return result.scalars().first()
        # Get the configs (from DB for custom, or constructed for global)
        long_context_llm = await get_config_for_id(search_space.long_context_llm_id)
        fast_llm = await get_config_for_id(search_space.fast_llm_id)
        strategic_llm = await get_config_for_id(search_space.strategic_llm_id)
        # Return updated preferences
        return {
            "long_context_llm_id": search_space.long_context_llm_id,
            "fast_llm_id": search_space.fast_llm_id,
            "strategic_llm_id": search_space.strategic_llm_id,
            "long_context_llm": long_context_llm,
            "fast_llm": fast_llm,
            "strategic_llm": strategic_llm,
        }
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
        ) from e

View file

@ -0,0 +1,905 @@
"""
Routes for the new chat feature with assistant-ui integration.
These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- GET /threads - List threads for sidebar (ThreadListPrimitive)
- POST /threads - Create a new thread
- GET /threads/{thread_id} - Get thread with messages (load)
- PUT /threads/{thread_id} - Update thread (rename, archive)
- DELETE /threads/{thread_id} - Delete thread
- POST /threads/{thread_id}/messages - Append message
- POST /attachments/process - Process attachments for chat context
"""
import contextlib
import os
import tempfile
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.db import (
NewChatMessage,
NewChatMessageRole,
NewChatThread,
Permission,
SearchSpace,
User,
get_async_session,
)
from app.schemas.new_chat import (
NewChatMessageAppend,
NewChatMessageRead,
NewChatRequest,
NewChatThreadCreate,
NewChatThreadRead,
NewChatThreadUpdate,
NewChatThreadWithMessages,
ThreadHistoryLoadResponse,
ThreadListItem,
ThreadListResponse,
)
from app.tasks.chat.stream_new_chat import stream_new_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
# =============================================================================
# Thread Endpoints
# =============================================================================
@router.get("/threads", response_model=ThreadListResponse)
async def list_threads(
    search_space_id: int,
    limit: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List all threads for the current user in a search space.

    Returns threads and archived_threads for ThreadListPrimitive,
    ordered most-recently-updated first.

    Args:
        search_space_id: The search space to list threads for
        limit: Optional limit on number of threads to return (applies to active threads only)

    Requires CHATS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Fetch every thread in the space, newest activity first.
        rows = await session.execute(
            select(NewChatThread)
            .filter(NewChatThread.search_space_id == search_space_id)
            .order_by(NewChatThread.updated_at.desc())
        )
        # Map ORM rows onto the lightweight sidebar items once, then
        # partition by archived flag (relative ordering is preserved).
        items = [
            ThreadListItem(
                id=row.id,
                title=row.title,
                archived=row.archived,
                created_at=row.created_at,
                updated_at=row.updated_at,
            )
            for row in rows.scalars().all()
        ]
        archived_threads = [item for item in items if item.archived]
        threads = [item for item in items if not item.archived]
        # The limit only constrains active threads; archived are always full.
        if limit is not None and limit > 0:
            threads = threads[:limit]
        return ThreadListResponse(threads=threads, archived_threads=archived_threads)
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching threads: {e!s}",
        ) from None
@router.get("/threads/search", response_model=list[ThreadListItem])
async def search_threads(
    search_space_id: int,
    title: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Search threads by title in a search space.

    Args:
        search_space_id: The search space to search in
        title: The search query (case-insensitive partial match)

    Requires CHATS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Case-insensitive substring match on the title, newest first.
        stmt = (
            select(NewChatThread)
            .filter(
                NewChatThread.search_space_id == search_space_id,
                NewChatThread.title.ilike(f"%{title}%"),
            )
            .order_by(NewChatThread.updated_at.desc())
        )
        rows = await session.execute(stmt)
        matches = []
        for row in rows.scalars().all():
            matches.append(
                ThreadListItem(
                    id=row.id,
                    title=row.title,
                    archived=row.archived,
                    created_at=row.created_at,
                    updated_at=row.updated_at,
                )
            )
        return matches
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while searching threads: {e!s}",
        ) from None
@router.post("/threads", response_model=NewChatThreadRead)
async def create_thread(
    thread: NewChatThreadCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a new chat thread.

    The thread's updated_at is initialized to the current UTC time so it
    sorts to the top of the sidebar immediately.

    Requires CHATS_CREATE permission.

    Raises:
        HTTPException: 403 from the permission check, 400 on constraint
            violations, 503 on transient database errors, 500 otherwise.
    """
    try:
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to create chats in this search space",
        )
        now = datetime.now(UTC)
        db_thread = NewChatThread(
            title=thread.title,
            archived=thread.archived,
            search_space_id=thread.search_space_id,
            updated_at=now,
        )
        session.add(db_thread)
        await session.commit()
        await session.refresh(db_thread)
        return db_thread
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while creating the thread: {e!s}",
        ) from None
@router.get("/threads/{thread_id}", response_model=ThreadHistoryLoadResponse)
async def get_thread_messages(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get a thread with all its messages.

    This is used by ThreadHistoryAdapter.load() to restore conversation.
    Messages are eagerly loaded (selectinload) and re-serialized into
    the assistant-ui message shape.

    Requires CHATS_READ permission.
    """
    try:
        # Get thread with messages
        result = await session.execute(
            select(NewChatThread)
            .options(selectinload(NewChatThread.messages))
            .filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        # Check permission and ownership
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Return messages in the format expected by assistant-ui
        messages = [
            NewChatMessageRead(
                id=msg.id,
                thread_id=msg.thread_id,
                role=msg.role,
                content=msg.content,
                created_at=msg.created_at,
            )
            for msg in thread.messages
        ]
        return ThreadHistoryLoadResponse(messages=messages)
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching the thread: {e!s}",
        ) from None
@router.get("/threads/{thread_id}/full", response_model=NewChatThreadWithMessages)
async def get_thread_full(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get full thread details with all messages.

    Unlike the history-load endpoint, this returns the thread object
    itself (with its messages eagerly loaded) rather than a
    message-only payload.

    Requires CHATS_READ permission.
    """
    try:
        # Eager-load messages so the response model can serialize them.
        lookup = await session.execute(
            select(NewChatThread)
            .options(selectinload(NewChatThread.messages))
            .filter(NewChatThread.id == thread_id)
        )
        db_thread = lookup.scalars().first()
        if db_thread is None:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            db_thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        return db_thread
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching the thread: {e!s}",
        ) from None
@router.put("/threads/{thread_id}", response_model=NewChatThreadRead)
async def update_thread(
    thread_id: int,
    thread_update: NewChatThreadUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update a thread (title, archived status).

    Used for renaming and archiving threads. Only fields present in the
    payload are applied; updated_at is always bumped to now (UTC).

    Requires CHATS_UPDATE permission.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        db_thread = result.scalars().first()
        if not db_thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            db_thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        # Update fields
        update_data = thread_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_thread, key, value)
        # Any edit counts as activity for sidebar ordering.
        db_thread.updated_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(db_thread)
        return db_thread
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while updating the thread: {e!s}",
        ) from None
@router.delete("/threads/{thread_id}", response_model=dict)
async def delete_thread(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete a thread and all its messages.

    Message removal relies on the ORM/DB cascade configured on the
    thread-message relationship — TODO confirm cascade settings on
    NewChatThread.messages.

    Requires CHATS_DELETE permission.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        db_thread = result.scalars().first()
        if not db_thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            db_thread.search_space_id,
            Permission.CHATS_DELETE.value,
            "You don't have permission to delete chats in this search space",
        )
        await session.delete(db_thread)
        await session.commit()
        return {"message": "Thread deleted successfully"}
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400, detail="Cannot delete thread due to existing dependencies."
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while deleting the thread: {e!s}",
        ) from None
# =============================================================================
# Message Endpoints
# =============================================================================
def _extract_title_text(content) -> str:
    """Best-effort plain-text extraction from a message payload.

    Accepts a raw string, a list of assistant-ui parts (dicts with
    type == "text", or bare strings), or anything else (stringified).
    Returns "" when a list contains no usable text part.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Use the first text-bearing part, mirroring assistant-ui ordering.
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                return part.get("text", "")
            if isinstance(part, str):
                return part
        return ""
    return str(content)


@router.post("/threads/{thread_id}/messages", response_model=NewChatMessageRead)
async def append_message(
    thread_id: int,
    request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Append a message to a thread.

    This is used by ThreadHistoryAdapter.append() to persist messages.
    The body is parsed manually so extra adapter fields are ignored;
    only "role" and "content" are consumed. Appending a first user
    message to a thread still titled "New Chat" auto-titles the thread.

    Requires CHATS_UPDATE permission.
    """
    try:
        # Parse raw body - extract only role and content, ignoring extra fields
        raw_body = await request.json()
        role = raw_body.get("role")
        content = raw_body.get("content")
        if not role:
            raise HTTPException(status_code=400, detail="Missing required field: role")
        if content is None:
            raise HTTPException(
                status_code=400, detail="Missing required field: content"
            )
        # Create message object manually
        message = NewChatMessageAppend(role=role, content=content)
        # Get thread
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        # Convert string role to enum
        role_str = (
            message.role.lower() if isinstance(message.role, str) else message.role
        )
        try:
            message_role = NewChatMessageRole(role_str)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid role: {message.role}. Must be 'user', 'assistant', or 'system'.",
            ) from None
        # Create message
        db_message = NewChatMessage(
            thread_id=thread_id,
            role=message_role,
            content=message.content,
        )
        session.add(db_message)
        # Update thread's updated_at timestamp
        thread.updated_at = datetime.now(UTC)
        # Auto-generate title from first user message if title is still default
        if thread.title == "New Chat" and role_str == "user":
            title_text = _extract_title_text(message.content)
            if title_text:
                # Truncate to 100 chars with an ellipsis marker when cut.
                thread.title = title_text[:100] + (
                    "..." if len(title_text) > 100 else ""
                )
        await session.commit()
        await session.refresh(db_message)
        return db_message
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while appending the message: {e!s}",
        ) from None
@router.get("/threads/{thread_id}/messages", response_model=list[NewChatMessageRead])
async def list_messages(
    thread_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List messages in a thread with pagination.

    Args:
        thread_id: Thread whose messages are listed.
        skip: Number of messages to skip (pagination offset).
        limit: Maximum messages to return (default 100).

    Messages are ordered oldest-first by created_at.

    Requires CHATS_READ permission.
    """
    try:
        # Verify thread exists and user has access
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Get messages
        query = (
            select(NewChatMessage)
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at)
            .offset(skip)
            .limit(limit)
        )
        result = await session.execute(query)
        return result.scalars().all()
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching messages: {e!s}",
        ) from None
# =============================================================================
# Chat Streaming Endpoint
# =============================================================================
@router.post("/new_chat")
async def handle_new_chat(
    request: NewChatRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Stream chat responses from the deep agent as Server-Sent Events.

    The SSE payload format is compatible with the Vercel AI SDK.
    Requires CHATS_CREATE permission on the thread's search space.
    """
    try:
        # The target thread determines which search space the permission
        # check applies to.
        thread_row = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == request.chat_id)
        )
        thread = thread_row.scalars().first()
        if thread is None:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to chat in this search space",
        )
        # The search space holds the LLM config preference for the agent.
        space_row = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = space_row.scalars().first()
        if search_space is None:
            raise HTTPException(status_code=404, detail="Search space not found")
        # Positive IDs load from the NewLLMConfig table, negative IDs from the
        # YAML global configs; -1 (first global config) is the fallback.
        llm_config_id = search_space.agent_llm_id
        if llm_config_id is None:
            llm_config_id = -1
        sse_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
        return StreamingResponse(
            stream_new_chat(
                user_query=request.user_query,
                search_space_id=request.search_space_id,
                chat_id=request.chat_id,
                session=session,
                llm_config_id=llm_config_id,
                attachments=request.attachments,
            ),
            media_type="text/event-stream",
            headers=sse_headers,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred: {e!s}",
        ) from None
# =============================================================================
# Attachment Processing Endpoint
# =============================================================================
@router.post("/attachments/process")
async def process_attachment(
    file: UploadFile = File(...),
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Process an attachment file and extract its content as markdown.

    This endpoint uses the configured ETL service to parse files and return
    the extracted content that can be used as context in chat messages.

    Supported file types depend on the configured ETL_SERVICE:
    - Markdown/Text files: .md, .markdown, .txt (always supported)
    - Audio files: .mp3, .mp4, .mpeg, .mpga, .m4a, .wav, .webm (if STT configured)
    - Documents: .pdf, .docx, .doc, .pptx, .xlsx (depends on ETL service)

    Returns:
        JSON with attachment id, name, type, extracted content, and its length.

    Raises:
        HTTPException: 400 if no filename was provided, 422 if the required
            service is not configured or no content could be extracted, 500
            on unexpected processing failures.
    """
    from app.config import config as app_config

    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")
    filename = file.filename
    attachment_id = str(uuid.uuid4())
    # Track the temp path outside the try block so the error handlers can
    # remove the file even when the failure happens before/while writing it.
    temp_path: str | None = None
    try:
        # Persist the upload to a temp file, preserving the extension so
        # downstream parsers can detect the format.
        file_ext = os.path.splitext(filename)[1].lower()
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
            temp_path = temp_file.name
            content = await file.read()
            temp_file.write(content)
        extracted_content = ""
        # Process based on file type
        if file_ext in (".md", ".markdown", ".txt"):
            # Text/markdown files need no parsing; read them directly.
            with open(temp_path, encoding="utf-8") as f:
                extracted_content = f.read()
        elif file_ext in (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm"):
            # Audio files - transcribe if STT service is configured
            if not app_config.STT_SERVICE:
                raise HTTPException(
                    status_code=422,
                    detail="Audio transcription is not configured. Please set STT_SERVICE.",
                )
            stt_service_type = (
                "local" if app_config.STT_SERVICE.startswith("local/") else "external"
            )
            if stt_service_type == "local":
                from app.services.stt_service import stt_service

                result = stt_service.transcribe_file(temp_path)
                extracted_content = result.get("text", "")
            else:
                # External STT goes through litellm's transcription API.
                from litellm import atranscription

                with open(temp_path, "rb") as audio_file:
                    transcription_kwargs = {
                        "model": app_config.STT_SERVICE,
                        "file": audio_file,
                        "api_key": app_config.STT_SERVICE_API_KEY,
                    }
                    if app_config.STT_SERVICE_API_BASE:
                        transcription_kwargs["api_base"] = (
                            app_config.STT_SERVICE_API_BASE
                        )
                    transcription_response = await atranscription(
                        **transcription_kwargs
                    )
                    extracted_content = transcription_response.get("text", "")
            if extracted_content:
                # Include the source filename so the model knows what was
                # transcribed. (Fixes the placeholder that was lost here.)
                extracted_content = (
                    f"# Transcription of {filename}\n\n{extracted_content}"
                )
        else:
            # Document files - use configured ETL service
            if app_config.ETL_SERVICE == "UNSTRUCTURED":
                from langchain_unstructured import UnstructuredLoader

                from app.utils.document_converters import convert_document_to_markdown

                loader = UnstructuredLoader(
                    temp_path,
                    mode="elements",
                    post_processors=[],
                    languages=["eng"],
                    include_orig_elements=False,
                    include_metadata=False,
                    strategy="auto",
                )
                docs = await loader.aload()
                extracted_content = await convert_document_to_markdown(docs)
            elif app_config.ETL_SERVICE == "LLAMACLOUD":
                from llama_cloud_services import LlamaParse
                from llama_cloud_services.parse.utils import ResultType

                parser = LlamaParse(
                    api_key=app_config.LLAMA_CLOUD_API_KEY,
                    num_workers=1,
                    verbose=False,
                    language="en",
                    result_type=ResultType.MD,
                )
                result = await parser.aparse(temp_path)
                markdown_documents = await result.aget_markdown_documents(
                    split_by_page=False
                )
                if markdown_documents:
                    extracted_content = "\n\n".join(
                        doc.text for doc in markdown_documents
                    )
            elif app_config.ETL_SERVICE == "DOCLING":
                from app.services.docling_service import create_docling_service

                docling_service = create_docling_service()
                result = await docling_service.process_document(temp_path, filename)
                extracted_content = result.get("content", "")
            else:
                raise HTTPException(
                    status_code=422,
                    detail=f"ETL service not configured or unsupported file type: {file_ext}",
                )
        # Clean up temp file on the success path.
        with contextlib.suppress(Exception):
            os.unlink(temp_path)
        if not extracted_content:
            raise HTTPException(
                status_code=422,
                # Name the offending file in the error (was an empty f-string).
                detail=f"Could not extract content from file: {filename}",
            )
        # Determine attachment type (must be one of: "image", "document", "file")
        # assistant-ui only supports these three types
        if file_ext in (".png", ".jpg", ".jpeg", ".gif", ".webp"):
            attachment_type = "image"
        else:
            # All other files (including audio, documents, text) are treated as "document"
            attachment_type = "document"
        return {
            "id": attachment_id,
            "name": filename,
            "type": attachment_type,
            "content": extracted_content,
            "contentLength": len(extracted_content),
        }
    except HTTPException:
        # Remove the temp file before propagating the HTTP error; the
        # original code leaked it on this path (e.g. 422 for missing STT).
        if temp_path:
            with contextlib.suppress(Exception):
                os.unlink(temp_path)
        raise
    except Exception as e:
        # Clean up temp file on error. Guarding on temp_path also fixes a
        # NameError when temp-file creation itself failed.
        if temp_path:
            with contextlib.suppress(Exception):
                os.unlink(temp_path)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process attachment: {e!s}",
        ) from e

View file

@ -0,0 +1,376 @@
"""
API routes for NewLLMConfig CRUD operations.
NewLLMConfig combines LLM model settings with prompt configuration:
- LLM provider, model, API key, etc.
- Configurable system instructions
- Citation toggle
"""
import logging
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.system_prompt import get_default_system_instructions
from app.config import config
from app.db import (
NewLLMConfig,
Permission,
User,
get_async_session,
)
from app.schemas import (
DefaultSystemInstructionsResponse,
GlobalNewLLMConfigRead,
NewLLMConfigCreate,
NewLLMConfigRead,
NewLLMConfigUpdate,
)
from app.services.llm_service import validate_llm_config
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
logger = logging.getLogger(__name__)
# =============================================================================
# Global Configs Routes
# =============================================================================
@router.get("/global-new-llm-configs", response_model=list[GlobalNewLLMConfigRead])
async def get_global_new_llm_configs(
    user: User = Depends(current_active_user),
):
    """
    Get all available global NewLLMConfig configurations.

    These are pre-configured by the system administrator and available to all
    users. API keys are never exposed through this endpoint. Global configs
    have negative IDs to distinguish them from user-created configs.
    """
    try:
        sanitized = []
        for raw in config.GLOBAL_LLM_CONFIGS:
            # Build a safe view of the config: everything except the API key.
            sanitized.append(
                {
                    "id": raw.get("id"),
                    "name": raw.get("name"),
                    "description": raw.get("description"),
                    "provider": raw.get("provider"),
                    "custom_provider": raw.get("custom_provider"),
                    "model_name": raw.get("model_name"),
                    "api_base": raw.get("api_base") or None,
                    "litellm_params": raw.get("litellm_params", {}),
                    # Prompt configuration fields with their defaults.
                    "system_instructions": raw.get("system_instructions", ""),
                    "use_default_system_instructions": raw.get(
                        "use_default_system_instructions", True
                    ),
                    "citations_enabled": raw.get("citations_enabled", True),
                    "is_global": True,
                }
            )
        return sanitized
    except Exception as e:
        logger.exception("Failed to fetch global NewLLMConfigs")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch global configurations: {e!s}"
        ) from e
# =============================================================================
# CRUD Routes
# =============================================================================
@router.post("/new-llm-configs", response_model=NewLLMConfigRead)
async def create_new_llm_config(
    config_data: NewLLMConfigCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a new NewLLMConfig for a search space.

    The configuration is validated with a live test call before being saved.
    Requires LLM_CONFIGS_CREATE permission.
    """
    try:
        await check_permission(
            session,
            user,
            config_data.search_space_id,
            Permission.LLM_CONFIGS_CREATE.value,
            "You don't have permission to create LLM configurations in this search space",
        )
        # Reject configurations that fail a live validation call.
        is_valid, error_message = await validate_llm_config(
            provider=config_data.provider.value,
            model_name=config_data.model_name,
            api_key=config_data.api_key,
            api_base=config_data.api_base,
            custom_provider=config_data.custom_provider,
            litellm_params=config_data.litellm_params,
        )
        if not is_valid:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid LLM configuration: {error_message}",
            )
        new_config = NewLLMConfig(**config_data.model_dump())
        session.add(new_config)
        await session.commit()
        # Refresh so server-generated fields (id, timestamps) are populated.
        await session.refresh(new_config)
        return new_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to create NewLLMConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to create configuration: {e!s}"
        ) from e
@router.get("/new-llm-configs", response_model=list[NewLLMConfigRead])
async def list_new_llm_configs(
    search_space_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get all NewLLMConfigs for a search space, newest first, with pagination.

    Requires LLM_CONFIGS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.LLM_CONFIGS_READ.value,
            "You don't have permission to view LLM configurations in this search space",
        )
        stmt = (
            select(NewLLMConfig)
            .filter(NewLLMConfig.search_space_id == search_space_id)
            .order_by(NewLLMConfig.created_at.desc())
            .offset(skip)
            .limit(limit)
        )
        return (await session.execute(stmt)).scalars().all()
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to list NewLLMConfigs")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch configurations: {e!s}"
        ) from e
@router.get(
    "/new-llm-configs/default-system-instructions",
    response_model=DefaultSystemInstructionsResponse,
)
async def get_default_system_instructions_endpoint(
    user: User = Depends(current_active_user),
):
    """
    Return the default SURFSENSE_SYSTEM_INSTRUCTIONS template.

    Useful for pre-populating the UI when creating a new configuration.
    """
    instructions = get_default_system_instructions()
    return DefaultSystemInstructionsResponse(default_system_instructions=instructions)
@router.get("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead)
async def get_new_llm_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get a specific NewLLMConfig by ID.

    Requires LLM_CONFIGS_READ permission on the config's search space.

    Raises:
        HTTPException: 404 if not found, 403 without permission, 500 on
            unexpected errors.
    """
    try:
        result = await session.execute(
            select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
        )
        # Named db_config (not "config") to avoid shadowing the module-level
        # app settings object imported as `config`.
        db_config = result.scalars().first()
        if not db_config:
            raise HTTPException(status_code=404, detail="Configuration not found")
        # Verify user has permission
        await check_permission(
            session,
            user,
            db_config.search_space_id,
            Permission.LLM_CONFIGS_READ.value,
            "You don't have permission to view LLM configurations in this search space",
        )
        return db_config
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get NewLLMConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch configuration: {e!s}"
        ) from e
@router.put("/new-llm-configs/{config_id}", response_model=NewLLMConfigRead)
async def update_new_llm_config(
    config_id: int,
    update_data: NewLLMConfigUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update an existing NewLLMConfig.

    If any LLM connection field changes, the merged configuration (pending
    updates over stored values) is re-validated with a live test call before
    the update is applied. Requires LLM_CONFIGS_UPDATE permission.

    Raises:
        HTTPException: 404 if not found, 400 if validation fails, 403
            without permission, 500 on unexpected errors.
    """
    try:
        result = await session.execute(
            select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
        )
        # Named db_config (not "config") to avoid shadowing the module-level
        # app settings object imported as `config`.
        db_config = result.scalars().first()
        if not db_config:
            raise HTTPException(status_code=404, detail="Configuration not found")
        # Verify user has permission
        await check_permission(
            session,
            user,
            db_config.search_space_id,
            Permission.LLM_CONFIGS_UPDATE.value,
            "You don't have permission to update LLM configurations in this search space",
        )
        update_dict = update_data.model_dump(exclude_unset=True)
        # Re-validate only when a connection-related field is being changed.
        if any(
            key in update_dict
            for key in [
                "provider",
                "model_name",
                "api_key",
                "api_base",
                "custom_provider",
                "litellm_params",
            ]
        ):
            # Merge pending updates over stored values for validation.
            # The provider may arrive as an enum (has .value) or a string.
            provider = update_dict.get("provider", db_config.provider)
            validation_config = {
                "provider": provider.value
                if hasattr(provider, "value")
                else provider,
                "model_name": update_dict.get("model_name", db_config.model_name),
                "api_key": update_dict.get("api_key", db_config.api_key),
                "api_base": update_dict.get("api_base", db_config.api_base),
                "custom_provider": update_dict.get(
                    "custom_provider", db_config.custom_provider
                ),
                "litellm_params": update_dict.get(
                    "litellm_params", db_config.litellm_params
                ),
            }
            is_valid, error_message = await validate_llm_config(
                provider=validation_config["provider"],
                model_name=validation_config["model_name"],
                api_key=validation_config["api_key"],
                api_base=validation_config["api_base"],
                custom_provider=validation_config["custom_provider"],
                litellm_params=validation_config["litellm_params"],
            )
            if not is_valid:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid LLM configuration: {error_message}",
                )
        # Apply updates
        for key, value in update_dict.items():
            setattr(db_config, key, value)
        await session.commit()
        await session.refresh(db_config)
        return db_config
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to update NewLLMConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to update configuration: {e!s}"
        ) from e
@router.delete("/new-llm-configs/{config_id}", response_model=dict)
async def delete_new_llm_config(
    config_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete a NewLLMConfig.

    Requires LLM_CONFIGS_DELETE permission.

    Returns:
        A confirmation message together with the deleted config's ID.
    """
    try:
        result = await session.execute(
            select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
        )
        # Named db_config (not "config") to avoid shadowing the module-level
        # app settings object imported as `config`.
        db_config = result.scalars().first()
        if not db_config:
            raise HTTPException(status_code=404, detail="Configuration not found")
        # Verify user has permission
        await check_permission(
            session,
            user,
            db_config.search_space_id,
            Permission.LLM_CONFIGS_DELETE.value,
            "You don't have permission to delete LLM configurations in this search space",
        )
        await session.delete(db_config)
        await session.commit()
        return {"message": "Configuration deleted successfully", "id": config_id}
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to delete NewLLMConfig")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete configuration: {e!s}"
        ) from e

View file

@ -1,14 +1,22 @@
"""
Podcast routes for task status polling and audio retrieval.
These routes support the podcast generation feature in new-chat.
Note: The old Chat-based podcast generation has been removed.
"""
import os
from pathlib import Path
from celery.result import AsyncResult
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.celery_app import celery_app
from app.db import (
Chat,
Permission,
Podcast,
SearchSpace,
@ -16,62 +24,13 @@ from app.db import (
User,
get_async_session,
)
from app.schemas import (
PodcastCreate,
PodcastGenerateRequest,
PodcastRead,
PodcastUpdate,
)
from app.tasks.podcast_tasks import generate_chat_podcast
from app.schemas import PodcastRead
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
@router.post("/podcasts", response_model=PodcastRead)
async def create_podcast(
    podcast: PodcastCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a new podcast.

    Requires PODCASTS_CREATE permission.

    Raises:
        HTTPException: 400 on a constraint violation, 500 on database or
            unexpected errors (the permission check may also raise).
    """
    try:
        # Verify the caller may create podcasts in the target search space.
        await check_permission(
            session,
            user,
            podcast.search_space_id,
            Permission.PODCASTS_CREATE.value,
            "You don't have permission to create podcasts in this search space",
        )
        db_podcast = Podcast(**podcast.model_dump())
        session.add(db_podcast)
        await session.commit()
        # Refresh so server-generated fields (id, timestamps) are populated.
        await session.refresh(db_podcast)
        return db_podcast
    except HTTPException as he:
        raise he
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Podcast creation failed due to constraint violation",
        ) from None
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error occurred while creating podcast"
        ) from None
    except Exception:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="An unexpected error occurred"
        ) from None
@router.get("/podcasts", response_model=list[PodcastRead])
async def read_podcasts(
skip: int = 0,
@ -159,53 +118,6 @@ async def read_podcast(
) from None
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
async def update_podcast(
    podcast_id: int,
    podcast_update: PodcastUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update a podcast.

    Only fields explicitly provided in the request body are changed
    (``exclude_unset=True``).

    Requires PODCASTS_UPDATE permission for the search space.

    Raises:
        HTTPException: 404 if the podcast does not exist, 400 on a
            constraint violation, 500 on database errors.
    """
    try:
        result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
        db_podcast = result.scalars().first()
        if not db_podcast:
            raise HTTPException(status_code=404, detail="Podcast not found")
        # Check permission for the search space
        await check_permission(
            session,
            user,
            db_podcast.search_space_id,
            Permission.PODCASTS_UPDATE.value,
            "You don't have permission to update podcasts in this search space",
        )
        # Apply only the fields the client actually sent.
        update_data = podcast_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_podcast, key, value)
        await session.commit()
        await session.refresh(db_podcast)
        return db_podcast
    except HTTPException as he:
        raise he
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400, detail="Update failed due to constraint violation"
        ) from None
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error occurred while updating podcast"
        ) from None
@router.delete("/podcasts/{podcast_id}", response_model=dict)
async def delete_podcast(
podcast_id: int,
@ -244,108 +156,8 @@ async def delete_podcast(
) from None
async def generate_chat_podcast_with_new_session(
    chat_id: int,
    search_space_id: int,
    user_id: int,
    podcast_title: str | None = None,
    user_prompt: str | None = None,
):
    """Open a dedicated database session and run chat-podcast generation in it."""
    from app.db import async_session_maker

    async with async_session_maker() as session:
        try:
            await generate_chat_podcast(
                session,
                chat_id,
                search_space_id,
                user_id,
                podcast_title,
                user_prompt,
            )
        except Exception as e:
            import logging

            # Best-effort: the failure is logged but never propagated.
            logging.error(f"Error generating podcast from chat: {e!s}")
@router.post("/podcasts/generate")
async def generate_podcast(
    request: PodcastGenerateRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Generate a podcast from a chat or document.

    For ``request.type == "CHAT"`` every requested chat ID is verified to
    belong to the caller's search space before one Celery task per chat is
    queued. Generation is asynchronous; the response only acknowledges that
    the tasks were started.

    Requires PODCASTS_CREATE permission.
    """
    try:
        # Check if the user has permission to create podcasts
        await check_permission(
            session,
            user,
            request.search_space_id,
            Permission.PODCASTS_CREATE.value,
            "You don't have permission to create podcasts in this search space",
        )
        if request.type == "CHAT":
            # Verify that all chat IDs belong to this user and search space
            query = (
                select(Chat)
                .filter(
                    Chat.id.in_(request.ids),
                    Chat.search_space_id == request.search_space_id,
                )
                .join(SearchSpace)
                .filter(SearchSpace.user_id == user.id)
            )
            result = await session.execute(query)
            valid_chats = result.scalars().all()
            valid_chat_ids = [chat.id for chat in valid_chats]
            # If any requested ID is not in valid IDs, raise error immediately
            if len(valid_chat_ids) != len(request.ids):
                raise HTTPException(
                    status_code=403,
                    detail="One or more chat IDs do not belong to this user or search space",
                )
            # NOTE(review): imported locally, presumably to avoid a circular
            # import at module load time — confirm.
            from app.tasks.celery_tasks.podcast_tasks import (
                generate_chat_podcast_task,
            )

            # Add Celery tasks for each chat ID
            for chat_id in valid_chat_ids:
                generate_chat_podcast_task.delay(
                    chat_id,
                    request.search_space_id,
                    user.id,
                    request.podcast_title,
                    request.user_prompt,
                )
        return {
            "message": "Podcast generation started",
        }
    except HTTPException as he:
        raise he
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Podcast generation failed due to constraint violation",
        ) from None
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="Database error occurred while generating podcast"
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {e!s}"
        ) from e
@router.get("/podcasts/{podcast_id}/stream")
@router.get("/podcasts/{podcast_id}/audio")
async def stream_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
@ -354,6 +166,8 @@ async def stream_podcast(
"""
Stream a podcast audio file.
Requires PODCASTS_READ permission for the search space.
Note: Both /stream and /audio endpoints are supported for compatibility.
"""
try:
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
@ -378,7 +192,7 @@ async def stream_podcast(
file_path = podcast.file_location
# Check if the file exists
if not os.path.isfile(file_path):
if not file_path or not os.path.isfile(file_path):
raise HTTPException(status_code=404, detail="Podcast audio file not found")
# Define a generator function to stream the file
@ -404,43 +218,60 @@ async def stream_podcast(
) from e
@router.get("/podcasts/by-chat/{chat_id}", response_model=PodcastRead | None)
async def get_podcast_by_chat_id(
chat_id: int,
session: AsyncSession = Depends(get_async_session),
@router.get("/podcasts/task/{task_id}/status")
async def get_podcast_task_status(
task_id: str,
user: User = Depends(current_active_user),
):
"""
Get a podcast by its associated chat ID.
Requires PODCASTS_READ permission for the search space.
Get the status of a podcast generation task.
Used by new-chat frontend to poll for completion.
Returns:
- status: "processing" | "success" | "error"
- podcast_id: (only if status == "success")
- title: (only if status == "success")
- error: (only if status == "error")
"""
try:
# First get the chat to find its search space
chat_result = await session.execute(select(Chat).filter(Chat.id == chat_id))
chat = chat_result.scalars().first()
result = AsyncResult(task_id, app=celery_app)
if not chat:
return None
if result.ready():
# Task completed
if result.successful():
task_result = result.result
if isinstance(task_result, dict):
if task_result.get("status") == "success":
return {
"status": "success",
"podcast_id": task_result.get("podcast_id"),
"title": task_result.get("title"),
"transcript_entries": task_result.get("transcript_entries"),
}
else:
return {
"status": "error",
"error": task_result.get("error", "Unknown error"),
}
else:
return {
"status": "error",
"error": "Unexpected task result format",
}
else:
# Task failed
return {
"status": "error",
"error": str(result.result) if result.result else "Task failed",
}
else:
# Task still processing
return {
"status": "processing",
"state": result.state,
}
# Check permission for the search space
await check_permission(
session,
user,
chat.search_space_id,
Permission.PODCASTS_READ.value,
"You don't have permission to read podcasts in this search space",
)
# Get the podcast
result = await session.execute(
select(Podcast).filter(Podcast.chat_id == chat_id)
)
podcast = result.scalars().first()
return podcast
except HTTPException as he:
raise he
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error fetching podcast: {e!s}"
status_code=500, detail=f"Error checking task status: {e!s}"
) from e

View file

@ -1,13 +1,13 @@
import logging
from pathlib import Path
import yaml
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import (
NewLLMConfig,
Permission,
SearchSpace,
SearchSpaceMembership,
@ -17,6 +17,8 @@ from app.db import (
get_default_roles_config,
)
from app.schemas import (
LLMPreferencesRead,
LLMPreferencesUpdate,
SearchSpaceCreate,
SearchSpaceRead,
SearchSpaceUpdate,
@ -184,37 +186,6 @@ async def read_search_spaces(
) from e
@router.get("/searchspaces/prompts/community")
async def get_community_prompts():
    """
    Get community-curated prompts for SearchSpace System Instructions.

    This endpoint does not require authentication as it serves public prompts.
    """
    try:
        # The curated prompts ship with the backend as a YAML resource.
        prompts_file = (
            Path(__file__).parent.parent
            / "prompts"
            / "public_search_space_prompts.yaml"
        )
        if not prompts_file.exists():
            raise HTTPException(
                status_code=404, detail="Community prompts file not found"
            )
        data = yaml.safe_load(prompts_file.read_text(encoding="utf-8"))
        return data.get("prompts", [])
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to load community prompts: {e!s}"
        ) from e
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def read_search_space(
search_space_id: int,
@ -329,3 +300,184 @@ async def delete_search_space(
raise HTTPException(
status_code=500, detail=f"Failed to delete search space: {e!s}"
) from e
# =============================================================================
# LLM Preferences Routes
# =============================================================================
async def _get_llm_config_by_id(
    session: AsyncSession, config_id: int | None
) -> dict | None:
    """
    Resolve an LLM config ID to a plain dictionary.

    Negative IDs are looked up in the YAML global configs, positive IDs in
    the NewLLMConfig database table; ``None`` resolves to ``None``. Returns
    ``None`` when no matching config exists.
    """
    if config_id is None:
        return None

    if config_id < 0:
        # Global (YAML-defined) config; note the API key is never included.
        for cfg in config.GLOBAL_LLM_CONFIGS:
            if cfg.get("id") != config_id:
                continue
            return {
                "id": cfg.get("id"),
                "name": cfg.get("name"),
                "description": cfg.get("description"),
                "provider": cfg.get("provider"),
                "custom_provider": cfg.get("custom_provider"),
                "model_name": cfg.get("model_name"),
                "api_base": cfg.get("api_base"),
                "litellm_params": cfg.get("litellm_params", {}),
                "system_instructions": cfg.get("system_instructions", ""),
                "use_default_system_instructions": cfg.get(
                    "use_default_system_instructions", True
                ),
                "citations_enabled": cfg.get("citations_enabled", True),
                "is_global": True,
            }
        return None

    # Database-backed config: load the row and flatten it into a dict.
    result = await session.execute(
        select(NewLLMConfig).filter(NewLLMConfig.id == config_id)
    )
    db_config = result.scalars().first()
    if db_config is None:
        return None
    created = db_config.created_at
    return {
        "id": db_config.id,
        "name": db_config.name,
        "description": db_config.description,
        "provider": db_config.provider.value if db_config.provider else None,
        "custom_provider": db_config.custom_provider,
        "model_name": db_config.model_name,
        "api_key": db_config.api_key,
        "api_base": db_config.api_base,
        "litellm_params": db_config.litellm_params or {},
        "system_instructions": db_config.system_instructions or "",
        "use_default_system_instructions": db_config.use_default_system_instructions,
        "citations_enabled": db_config.citations_enabled,
        "created_at": created.isoformat() if created else None,
        "search_space_id": db_config.search_space_id,
    }
@router.get(
    "/search-spaces/{search_space_id}/llm-preferences",
    response_model=LLMPreferencesRead,
)
async def get_llm_preferences(
    search_space_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get LLM preferences (role assignments) for a search space.

    Requires LLM_CONFIGS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.LLM_CONFIGS_READ.value,
            "You don't have permission to view LLM preferences",
        )
        space = (
            (
                await session.execute(
                    select(SearchSpace).filter(SearchSpace.id == search_space_id)
                )
            )
            .scalars()
            .first()
        )
        if space is None:
            raise HTTPException(status_code=404, detail="Search space not found")
        # Resolve both role assignments to full config dictionaries.
        return LLMPreferencesRead(
            agent_llm_id=space.agent_llm_id,
            document_summary_llm_id=space.document_summary_llm_id,
            agent_llm=await _get_llm_config_by_id(session, space.agent_llm_id),
            document_summary_llm=await _get_llm_config_by_id(
                session, space.document_summary_llm_id
            ),
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to get LLM preferences")
        raise HTTPException(
            status_code=500, detail=f"Failed to get LLM preferences: {e!s}"
        ) from e
@router.put(
    "/search-spaces/{search_space_id}/llm-preferences",
    response_model=LLMPreferencesRead,
)
async def update_llm_preferences(
    search_space_id: int,
    preferences: LLMPreferencesUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update LLM preferences (role assignments) for a search space.
    Requires LLM_CONFIGS_UPDATE permission.
    """
    try:
        # Caller must be allowed to modify LLM configs in this search space.
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.LLM_CONFIGS_UPDATE.value,
            "You don't have permission to update LLM preferences",
        )

        space_row = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == search_space_id)
        )
        space = space_row.scalars().first()
        if space is None:
            raise HTTPException(status_code=404, detail="Search space not found")

        # Apply only the fields the client explicitly provided.
        for field_name, field_value in preferences.model_dump(
            exclude_unset=True
        ).items():
            setattr(space, field_name, field_value)
        await session.commit()
        await session.refresh(space)

        # Echo back the updated assignments with their resolved configs.
        return LLMPreferencesRead(
            agent_llm_id=space.agent_llm_id,
            document_summary_llm_id=space.document_summary_llm_id,
            agent_llm=await _get_llm_config_by_id(session, space.agent_llm_id),
            document_summary_llm=await _get_llm_config_by_id(
                session, space.document_summary_llm_id
            ),
        )
    except HTTPException:
        raise
    except Exception as e:
        await session.rollback()
        logger.exception("Failed to update LLM preferences")
        raise HTTPException(
            status_code=500, detail=f"Failed to update LLM preferences: {e!s}"
        ) from e

View file

@ -1,13 +1,4 @@
from .base import IDModel, TimestampModel
from .chats import (
AISDKChatRequest,
ChatBase,
ChatCreate,
ChatRead,
ChatReadWithoutMessages,
ChatUpdate,
NewChatRequest,
)
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
from .documents import (
DocumentBase,
@ -19,15 +10,32 @@ from .documents import (
ExtensionDocumentMetadata,
PaginatedResponse,
)
from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate
from .podcasts import (
PodcastBase,
PodcastCreate,
PodcastGenerateRequest,
PodcastRead,
PodcastUpdate,
from .new_chat import (
ChatMessage,
NewChatMessageAppend,
NewChatMessageCreate,
NewChatMessageRead,
NewChatRequest,
NewChatThreadCreate,
NewChatThreadRead,
NewChatThreadUpdate,
NewChatThreadWithMessages,
ThreadHistoryLoadResponse,
ThreadListItem,
ThreadListResponse,
)
from .new_llm_config import (
DefaultSystemInstructionsResponse,
GlobalNewLLMConfigRead,
LLMPreferencesRead,
LLMPreferencesUpdate,
NewLLMConfigCreate,
NewLLMConfigPublic,
NewLLMConfigRead,
NewLLMConfigUpdate,
)
from .podcasts import PodcastBase, PodcastCreate, PodcastRead, PodcastUpdate
from .rbac_schemas import (
InviteAcceptRequest,
InviteAcceptResponse,
@ -61,16 +69,15 @@ from .search_space import (
from .users import UserCreate, UserRead, UserUpdate
__all__ = [
"AISDKChatRequest",
"ChatBase",
"ChatCreate",
"ChatRead",
"ChatReadWithoutMessages",
"ChatUpdate",
# Chat schemas (assistant-ui integration)
"ChatMessage",
# Chunk schemas
"ChunkBase",
"ChunkCreate",
"ChunkRead",
"ChunkUpdate",
"DefaultSystemInstructionsResponse",
# Document schemas
"DocumentBase",
"DocumentRead",
"DocumentUpdate",
@ -78,6 +85,8 @@ __all__ = [
"DocumentsCreate",
"ExtensionDocumentContent",
"ExtensionDocumentMetadata",
"GlobalNewLLMConfigRead",
# Base schemas
"IDModel",
# RBAC schemas
"InviteAcceptRequest",
@ -86,10 +95,10 @@ __all__ = [
"InviteInfoResponse",
"InviteRead",
"InviteUpdate",
"LLMConfigBase",
"LLMConfigCreate",
"LLMConfigRead",
"LLMConfigUpdate",
# LLM Preferences schemas
"LLMPreferencesRead",
"LLMPreferencesUpdate",
# Log schemas
"LogBase",
"LogCreate",
"LogFilter",
@ -98,28 +107,46 @@ __all__ = [
"MembershipRead",
"MembershipReadWithUser",
"MembershipUpdate",
"NewChatMessageAppend",
"NewChatMessageCreate",
"NewChatMessageRead",
"NewChatRequest",
"NewChatThreadCreate",
"NewChatThreadRead",
"NewChatThreadUpdate",
"NewChatThreadWithMessages",
# NewLLMConfig schemas
"NewLLMConfigCreate",
"NewLLMConfigPublic",
"NewLLMConfigRead",
"NewLLMConfigUpdate",
"PaginatedResponse",
"PermissionInfo",
"PermissionsListResponse",
# Podcast schemas
"PodcastBase",
"PodcastCreate",
"PodcastGenerateRequest",
"PodcastRead",
"PodcastUpdate",
"RoleCreate",
"RoleRead",
"RoleUpdate",
# Search source connector schemas
"SearchSourceConnectorBase",
"SearchSourceConnectorCreate",
"SearchSourceConnectorRead",
"SearchSourceConnectorUpdate",
# Search space schemas
"SearchSpaceBase",
"SearchSpaceCreate",
"SearchSpaceRead",
"SearchSpaceUpdate",
"SearchSpaceWithStats",
"ThreadHistoryLoadResponse",
"ThreadListItem",
"ThreadListResponse",
"TimestampModel",
# User schemas
"UserCreate",
"UserRead",
"UserSearchSpaceAccess",

View file

@ -1,72 +0,0 @@
from typing import Any
from pydantic import BaseModel, ConfigDict
from app.db import ChatType
from .base import IDModel, TimestampModel
class ChatBase(BaseModel):
    """Base schema for a chat, including its full message history."""
    type: ChatType
    title: str
    initial_connectors: list[str] | None = None
    messages: list[Any]  # raw message payloads; shape is not enforced here
    search_space_id: int
    state_version: int = 1
class ChatBaseWithoutMessages(BaseModel):
    """Chat metadata without the (potentially large) messages payload."""
    type: ChatType
    title: str
    search_space_id: int
    state_version: int = 1
class ClientAttachment(BaseModel):
    """An attachment sent from the client (referenced by URL)."""
    name: str
    content_type: str
    url: str
class ToolInvocation(BaseModel):
    """A single tool call with its arguments and result."""
    tool_call_id: str
    tool_name: str
    args: dict
    result: dict
# class ClientMessage(BaseModel):
# role: str
# content: str
# experimental_attachments: Optional[List[ClientAttachment]] = None
# toolInvocations: Optional[List[ToolInvocation]] = None
class AISDKChatRequest(BaseModel):
    """Request body shape used by the Vercel AI SDK chat endpoint."""
    messages: list[Any]
    data: dict[str, Any] | None = None
class NewChatRequest(BaseModel):
    """Request schema for the new deep agent chat endpoint."""
    chat_id: int
    user_query: str
    search_space_id: int
class ChatCreate(ChatBase):
    """Payload for creating a chat (same fields as ChatBase)."""
    pass
class ChatUpdate(ChatBase):
    """Payload for updating a chat (same fields as ChatBase)."""
    pass
class ChatRead(ChatBase, IDModel, TimestampModel):
    """Read schema: ChatBase plus id and timestamps, ORM-compatible."""
    model_config = ConfigDict(from_attributes=True)
class ChatReadWithoutMessages(ChatBaseWithoutMessages, IDModel, TimestampModel):
    """Read schema without messages, for listing chats cheaply."""
    model_config = ConfigDict(from_attributes=True)

View file

@ -1,72 +0,0 @@
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import LiteLLMProvider
from .base import IDModel, TimestampModel
class LLMConfigBase(BaseModel):
    """Base schema for a user-defined LiteLLM configuration."""
    name: str = Field(
        ..., max_length=100, description="User-friendly name for the LLM configuration"
    )
    provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
    custom_provider: str | None = Field(
        None, max_length=100, description="Custom provider name when provider is CUSTOM"
    )
    model_name: str = Field(
        ..., max_length=100, description="Model name without provider prefix"
    )
    api_key: str = Field(..., description="API key for the provider")
    api_base: str | None = Field(
        None, max_length=500, description="Optional API base URL"
    )
    litellm_params: dict[str, Any] | None = Field(
        default=None, description="Additional LiteLLM parameters"
    )
    language: str | None = Field(
        default="English", max_length=50, description="Language for the LLM"
    )
class LLMConfigCreate(LLMConfigBase):
    """Creation payload: base fields plus the owning search space."""
    search_space_id: int = Field(
        ..., description="Search space ID to associate the LLM config with"
    )
class LLMConfigUpdate(BaseModel):
    """Partial-update payload; None means 'leave the field unchanged'."""
    name: str | None = Field(
        None, max_length=100, description="User-friendly name for the LLM configuration"
    )
    provider: LiteLLMProvider | None = Field(None, description="LiteLLM provider type")
    custom_provider: str | None = Field(
        None, max_length=100, description="Custom provider name when provider is CUSTOM"
    )
    model_name: str | None = Field(
        None, max_length=100, description="Model name without provider prefix"
    )
    api_key: str | None = Field(None, description="API key for the provider")
    api_base: str | None = Field(
        None, max_length=500, description="Optional API base URL"
    )
    language: str | None = Field(
        None, max_length=50, description="Language for the LLM"
    )
    litellm_params: dict[str, Any] | None = Field(
        None, description="Additional LiteLLM parameters"
    )
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
    """Read schema; created_at/search_space_id are None for global configs."""
    id: int
    created_at: datetime | None = Field(
        None, description="Creation timestamp (None for global configs)"
    )
    search_space_id: int | None = Field(
        None, description="Search space ID (None for global configs)"
    )
    model_config = ConfigDict(from_attributes=True)

View file

@ -0,0 +1,162 @@
"""
Pydantic schemas for the new chat feature with assistant-ui integration.
These schemas follow the assistant-ui ThreadHistoryAdapter pattern:
- ThreadRecord: id, title, archived, createdAt, updatedAt
- MessageRecord: id, threadId, role, content, createdAt
"""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import NewChatMessageRole
from .base import IDModel, TimestampModel
# =============================================================================
# Message Schemas
# =============================================================================
class NewChatMessageBase(BaseModel):
"""Base schema for new chat messages."""
role: NewChatMessageRole
content: Any # JSONB content - can be text, tool calls, etc.
class NewChatMessageCreate(NewChatMessageBase):
"""Schema for creating a new message."""
thread_id: int
class NewChatMessageRead(NewChatMessageBase, IDModel, TimestampModel):
"""Schema for reading a message."""
thread_id: int
model_config = ConfigDict(from_attributes=True)
class NewChatMessageAppend(BaseModel):
"""
Schema for appending a message via the history adapter.
This is the format assistant-ui sends when calling append().
"""
role: str # Accept string and validate in route handler
content: Any
# =============================================================================
# Thread Schemas
# =============================================================================
class NewChatThreadBase(BaseModel):
"""Base schema for new chat threads."""
title: str = Field(default="New Chat", max_length=500)
archived: bool = False
class NewChatThreadCreate(NewChatThreadBase):
"""Schema for creating a new thread."""
search_space_id: int
class NewChatThreadUpdate(BaseModel):
"""Schema for updating a thread."""
title: str | None = None
archived: bool | None = None
class NewChatThreadRead(NewChatThreadBase, IDModel):
"""
Schema for reading a thread (matches assistant-ui ThreadRecord).
"""
search_space_id: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
class NewChatThreadWithMessages(NewChatThreadRead):
"""Schema for reading a thread with its messages."""
messages: list[NewChatMessageRead] = []
# =============================================================================
# History Adapter Response Schemas
# =============================================================================
class ThreadHistoryLoadResponse(BaseModel):
"""
Response format for the ThreadHistoryAdapter.load() method.
Returns messages array for the current thread.
"""
messages: list[NewChatMessageRead]
class ThreadListItem(BaseModel):
"""
Thread list item for sidebar display.
Matches assistant-ui ThreadListPrimitive expected format.
"""
id: int
title: str
archived: bool
created_at: datetime = Field(alias="createdAt")
updated_at: datetime = Field(alias="updatedAt")
model_config = ConfigDict(from_attributes=True, populate_by_name=True)
class ThreadListResponse(BaseModel):
"""Response containing list of threads for the sidebar."""
threads: list[ThreadListItem]
archived_threads: list[ThreadListItem]
# =============================================================================
# Chat Request Schemas (for deep agent)
# =============================================================================
class ChatMessage(BaseModel):
"""A single message in the chat history."""
role: str # "user" or "assistant"
content: str
class ChatAttachment(BaseModel):
"""An attachment with its extracted content for chat context."""
id: str # Unique attachment ID
name: str # Original filename
type: str # Attachment type: document, image, audio
content: str # Extracted markdown content from the file
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
chat_id: int
user_query: str
search_space_id: int
messages: list[ChatMessage] | None = None # Optional chat history from frontend
attachments: list[ChatAttachment] | None = (
None # Optional attachments with extracted content
)

View file

@ -0,0 +1,191 @@
"""
Pydantic schemas for the NewLLMConfig API.
NewLLMConfig combines LLM model settings with prompt configuration:
- LLM provider, model, API key, etc.
- Configurable system instructions
- Citation toggle
"""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from app.db import LiteLLMProvider
class NewLLMConfigBase(BaseModel):
"""Base schema with common fields for NewLLMConfig."""
name: str = Field(
..., max_length=100, description="User-friendly name for the configuration"
)
description: str | None = Field(
None, max_length=500, description="Optional description"
)
# LLM Model Configuration
provider: LiteLLMProvider = Field(..., description="LiteLLM provider type")
custom_provider: str | None = Field(
None, max_length=100, description="Custom provider name when provider is CUSTOM"
)
model_name: str = Field(
..., max_length=100, description="Model name without provider prefix"
)
api_key: str = Field(..., description="API key for the provider")
api_base: str | None = Field(
None, max_length=500, description="Optional API base URL"
)
litellm_params: dict[str, Any] | None = Field(
default=None, description="Additional LiteLLM parameters"
)
# Prompt Configuration
system_instructions: str = Field(
default="",
description="Custom system instructions. Empty string uses default SURFSENSE_SYSTEM_INSTRUCTIONS.",
)
use_default_system_instructions: bool = Field(
default=True,
description="Whether to use default instructions when system_instructions is empty",
)
citations_enabled: bool = Field(
default=True,
description="Whether to include citation instructions in the system prompt",
)
class NewLLMConfigCreate(NewLLMConfigBase):
"""Schema for creating a new NewLLMConfig."""
search_space_id: int = Field(
..., description="Search space ID to associate the config with"
)
class NewLLMConfigUpdate(BaseModel):
"""Schema for updating an existing NewLLMConfig. All fields are optional."""
name: str | None = Field(None, max_length=100)
description: str | None = Field(None, max_length=500)
# LLM Model Configuration
provider: LiteLLMProvider | None = None
custom_provider: str | None = Field(None, max_length=100)
model_name: str | None = Field(None, max_length=100)
api_key: str | None = None
api_base: str | None = Field(None, max_length=500)
litellm_params: dict[str, Any] | None = None
# Prompt Configuration
system_instructions: str | None = None
use_default_system_instructions: bool | None = None
citations_enabled: bool | None = None
class NewLLMConfigRead(NewLLMConfigBase):
"""Schema for reading a NewLLMConfig (includes id and timestamps)."""
id: int
created_at: datetime
search_space_id: int
model_config = ConfigDict(from_attributes=True)
class NewLLMConfigPublic(BaseModel):
"""
Public schema for NewLLMConfig that hides the API key.
Used when returning configs in list views or to users who shouldn't see keys.
"""
id: int
name: str
description: str | None = None
# LLM Model Configuration (no api_key)
provider: LiteLLMProvider
custom_provider: str | None = None
model_name: str
api_base: str | None = None
litellm_params: dict[str, Any] | None = None
# Prompt Configuration
system_instructions: str
use_default_system_instructions: bool
citations_enabled: bool
created_at: datetime
search_space_id: int
model_config = ConfigDict(from_attributes=True)
class DefaultSystemInstructionsResponse(BaseModel):
"""Response schema for getting default system instructions."""
default_system_instructions: str = Field(
..., description="The default SURFSENSE_SYSTEM_INSTRUCTIONS template"
)
class GlobalNewLLMConfigRead(BaseModel):
"""
Schema for reading global LLM configs from YAML.
Global configs have negative IDs and no search_space_id.
API key is hidden for security.
"""
id: int = Field(..., description="Negative ID for global configs")
name: str
description: str | None = None
# LLM Model Configuration (no api_key)
provider: str # String because YAML doesn't enforce enum
custom_provider: str | None = None
model_name: str
api_base: str | None = None
litellm_params: dict[str, Any] | None = None
# Prompt Configuration
system_instructions: str = ""
use_default_system_instructions: bool = True
citations_enabled: bool = True
is_global: bool = True # Always true for global configs
# =============================================================================
# LLM Preferences Schemas (for role assignments)
# =============================================================================
class LLMPreferencesRead(BaseModel):
"""Schema for reading LLM preferences (role assignments) for a search space."""
agent_llm_id: int | None = Field(
None, description="ID of the LLM config to use for agent/chat tasks"
)
document_summary_llm_id: int | None = Field(
None, description="ID of the LLM config to use for document summarization"
)
agent_llm: dict[str, Any] | None = Field(
None, description="Full config for agent LLM"
)
document_summary_llm: dict[str, Any] | None = Field(
None, description="Full config for document summary LLM"
)
model_config = ConfigDict(from_attributes=True)
class LLMPreferencesUpdate(BaseModel):
"""Schema for updating LLM preferences."""
agent_llm_id: int | None = Field(
None, description="ID of the LLM config to use for agent/chat tasks"
)
document_summary_llm_id: int | None = Field(
None, description="ID of the LLM config to use for document summarization"
)

View file

@ -1,33 +1,39 @@
from typing import Any, Literal
"""Podcast schemas for API responses."""
from pydantic import BaseModel, ConfigDict
from datetime import datetime
from typing import Any
from .base import IDModel, TimestampModel
from pydantic import BaseModel
class PodcastBase(BaseModel):
"""Base podcast schema."""
title: str
podcast_transcript: list[Any]
file_location: str = ""
podcast_transcript: list[dict[str, Any]] | None = None
file_location: str | None = None
search_space_id: int
chat_state_version: int | None = None
class PodcastCreate(PodcastBase):
"""Schema for creating a podcast."""
pass
class PodcastUpdate(PodcastBase):
pass
class PodcastUpdate(BaseModel):
"""Schema for updating a podcast."""
title: str | None = None
podcast_transcript: list[dict[str, Any]] | None = None
file_location: str | None = None
class PodcastRead(PodcastBase, IDModel, TimestampModel):
model_config = ConfigDict(from_attributes=True)
class PodcastRead(PodcastBase):
"""Schema for reading a podcast."""
id: int
created_at: datetime
class PodcastGenerateRequest(BaseModel):
type: Literal["DOCUMENT", "CHAT"]
ids: list[int]
search_space_id: int
podcast_title: str | None = None
user_prompt: str | None = None
class Config:
from_attributes = True

View file

@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import LLMConfig, SearchSpace
from app.db import NewLLMConfig, SearchSpace
# Configure litellm to automatically drop unsupported parameters
litellm.drop_params = True
@ -16,9 +16,8 @@ logger = logging.getLogger(__name__)
class LLMRole:
LONG_CONTEXT = "long_context"
FAST = "fast"
STRATEGIC = "strategic"
AGENT = "agent" # For agent/chat operations
DOCUMENT_SUMMARY = "document_summary" # For document summarization
def get_global_llm_config(llm_config_id: int) -> dict | None:
@ -155,7 +154,7 @@ async def get_search_space_llm_instance(
Args:
session: Database session
search_space_id: Search Space ID
role: LLM role ('long_context', 'fast', or 'strategic')
role: LLM role ('agent' or 'document_summary')
Returns:
ChatLiteLLM instance or None if not found
@ -173,12 +172,10 @@ async def get_search_space_llm_instance(
# Get the appropriate LLM config ID based on role
llm_config_id = None
if role == LLMRole.LONG_CONTEXT:
llm_config_id = search_space.long_context_llm_id
elif role == LLMRole.FAST:
llm_config_id = search_space.fast_llm_id
elif role == LLMRole.STRATEGIC:
llm_config_id = search_space.strategic_llm_id
if role == LLMRole.AGENT:
llm_config_id = search_space.agent_llm_id
elif role == LLMRole.DOCUMENT_SUMMARY:
llm_config_id = search_space.document_summary_llm_id
else:
logger.error(f"Invalid LLM role: {role}")
return None
@ -250,11 +247,11 @@ async def get_search_space_llm_instance(
return ChatLiteLLM(**litellm_kwargs)
# Get the LLM configuration from database (user-specific config)
# Get the LLM configuration from database (NewLLMConfig)
result = await session.execute(
select(LLMConfig).where(
LLMConfig.id == llm_config_id,
LLMConfig.search_space_id == search_space_id,
select(NewLLMConfig).where(
NewLLMConfig.id == llm_config_id,
NewLLMConfig.search_space_id == search_space_id,
)
)
llm_config = result.scalars().first()
@ -265,11 +262,11 @@ async def get_search_space_llm_instance(
)
return None
# Build the model string for litellm / 构建 LiteLLM 的模型字符串
# Build the model string for litellm
if llm_config.custom_provider:
model_string = f"{llm_config.custom_provider}/{llm_config.model_name}"
else:
# Map provider enum to litellm format / 将提供商枚举映射为 LiteLLM 格式
# Map provider enum to litellm format
provider_map = {
"OPENAI": "openai",
"ANTHROPIC": "anthropic",
@ -283,7 +280,7 @@ async def get_search_space_llm_instance(
"COMETAPI": "cometapi",
"XAI": "xai",
"BEDROCK": "bedrock",
"AWS_BEDROCK": "bedrock", # Legacy support (backward compatibility)
"AWS_BEDROCK": "bedrock",
"VERTEX_AI": "vertex_ai",
"TOGETHER_AI": "together_ai",
"FIREWORKS_AI": "fireworks_ai",
@ -296,7 +293,6 @@ async def get_search_space_llm_instance(
"AI21": "ai21",
"CLOUDFLARE": "cloudflare",
"DATABRICKS": "databricks",
# Chinese LLM providers
"DEEPSEEK": "openai",
"ALIBABA_QWEN": "openai",
"MOONSHOT": "openai",
@ -330,28 +326,19 @@ async def get_search_space_llm_instance(
return None
async def get_long_context_llm(
async def get_agent_llm(
session: AsyncSession, search_space_id: int
) -> ChatLiteLLM | None:
"""Get the search space's long context LLM instance."""
"""Get the search space's agent LLM instance for chat operations."""
return await get_search_space_llm_instance(session, search_space_id, LLMRole.AGENT)
async def get_document_summary_llm(
session: AsyncSession, search_space_id: int
) -> ChatLiteLLM | None:
"""Get the search space's document summary LLM instance."""
return await get_search_space_llm_instance(
session, search_space_id, LLMRole.LONG_CONTEXT
)
async def get_fast_llm(
session: AsyncSession, search_space_id: int
) -> ChatLiteLLM | None:
"""Get the search space's fast LLM instance."""
return await get_search_space_llm_instance(session, search_space_id, LLMRole.FAST)
async def get_strategic_llm(
session: AsyncSession, search_space_id: int
) -> ChatLiteLLM | None:
"""Get the search space's strategic LLM instance."""
return await get_search_space_llm_instance(
session, search_space_id, LLMRole.STRATEGIC
session, search_space_id, LLMRole.DOCUMENT_SUMMARY
)
@ -366,22 +353,54 @@ async def get_user_llm_instance(
return await get_search_space_llm_instance(session, search_space_id, role)
# Legacy aliases for backward compatibility
async def get_long_context_llm(
session: AsyncSession, search_space_id: int
) -> ChatLiteLLM | None:
"""Deprecated: Use get_document_summary_llm instead."""
return await get_document_summary_llm(session, search_space_id)
async def get_fast_llm(
session: AsyncSession, search_space_id: int
) -> ChatLiteLLM | None:
"""Deprecated: Use get_agent_llm instead."""
return await get_agent_llm(session, search_space_id)
async def get_strategic_llm(
session: AsyncSession, search_space_id: int
) -> ChatLiteLLM | None:
"""Deprecated: Use get_document_summary_llm instead."""
return await get_document_summary_llm(session, search_space_id)
# User-based legacy aliases (LLM preferences are now per-search-space, not per-user)
async def get_user_long_context_llm(
session: AsyncSession, user_id: str, search_space_id: int
) -> ChatLiteLLM | None:
"""Deprecated: Use get_long_context_llm instead."""
return await get_long_context_llm(session, search_space_id)
"""
Deprecated: Use get_document_summary_llm instead.
The user_id parameter is ignored as LLM preferences are now per-search-space.
"""
return await get_document_summary_llm(session, search_space_id)
async def get_user_fast_llm(
session: AsyncSession, user_id: str, search_space_id: int
) -> ChatLiteLLM | None:
"""Deprecated: Use get_fast_llm instead."""
return await get_fast_llm(session, search_space_id)
"""
Deprecated: Use get_agent_llm instead.
The user_id parameter is ignored as LLM preferences are now per-search-space.
"""
return await get_agent_llm(session, search_space_id)
async def get_user_strategic_llm(
session: AsyncSession, user_id: str, search_space_id: int
) -> ChatLiteLLM | None:
"""Deprecated: Use get_strategic_llm instead."""
return await get_strategic_llm(session, search_space_id)
"""
Deprecated: Use get_document_summary_llm instead.
The user_id parameter is ignored as LLM preferences are now per-search-space.
"""
return await get_document_summary_llm(session, search_space_id)

View file

@ -450,6 +450,35 @@ class VercelStreamingService:
"""
return self.format_data("further-questions", {"questions": questions})
def format_thinking_step(
self,
step_id: str,
title: str,
status: str = "in_progress",
items: list[str] | None = None,
) -> str:
"""
Format a thinking step for chain-of-thought display (SurfSense specific).
Args:
step_id: Unique identifier for the step
title: The step title (e.g., "Analyzing your request")
status: Step status - "pending", "in_progress", or "completed"
items: Optional list of sub-items/details for this step
Returns:
str: SSE formatted thinking step data part
"""
return self.format_data(
"thinking-step",
{
"id": step_id,
"title": title,
"status": status,
"items": items or [],
},
)
# =========================================================================
# Error Part
# =========================================================================

View file

@ -4,7 +4,7 @@ from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.llm_service import get_strategic_llm
from app.services.llm_service import get_document_summary_llm
class QueryService:
@ -20,7 +20,7 @@ class QueryService:
chat_history_str: str | None = None,
) -> str:
"""
Reformulate the user query using the search space's strategic LLM to make it more
Reformulate the user query using the search space's document summary LLM to make it more
effective for information retrieval and research purposes.
Args:
@ -36,11 +36,11 @@ class QueryService:
return user_query
try:
# Get the search space's strategic LLM instance
llm = await get_strategic_llm(session, search_space_id)
# Get the search space's document summary LLM instance
llm = await get_document_summary_llm(session, search_space_id)
if not llm:
print(
f"Warning: No strategic LLM configured for search space {search_space_id}. Using original query."
f"Warning: No document summary LLM configured for search space {search_space_id}. Using original query."
)
return user_query

View file

@ -1,191 +0,0 @@
import json
from typing import Any
class StreamingService:
    """
    Builds Vercel AI SDK data-stream parts as newline-terminated strings:
    annotation parts ("8:..."), text parts ("0:..."), error parts ("3:...")
    and completion parts ("d:..."), while mirroring the streamed state in
    self.message_annotations for reference.
    """
    def __init__(self):
        # Monotonically increasing id assigned to terminal-info messages.
        self.terminal_idx = 1
        # Index 0: TERMINAL_INFO, 1: SOURCES, 2: ANSWER, 3: FURTHER_QUESTIONS.
        self.message_annotations = [
            {"type": "TERMINAL_INFO", "content": []},
            {"type": "SOURCES", "content": []},
            {"type": "ANSWER", "content": []},
            {"type": "FURTHER_QUESTIONS", "content": []},
        ]
    # DEPRECATED: This sends the full annotation array every time (inefficient)
    def _format_annotations(self) -> str:
        """
        Format the annotations as a string
        DEPRECATED: This method sends the full annotation state every time.
        Use the delta formatters instead for optimal streaming.
        Returns:
            str: The formatted annotations string
        """
        return f"8:{json.dumps(self.message_annotations)}\n"
    def format_terminal_info_delta(self, text: str, message_type: str = "info") -> str:
        """
        Format a single terminal info message as a delta annotation
        Args:
            text: The terminal message text
            message_type: The message type (info, error, success, etc.)
        Returns:
            str: The formatted annotation delta string
        """
        message = {"id": self.terminal_idx, "text": text, "type": message_type}
        self.terminal_idx += 1
        # Update internal state for reference
        self.message_annotations[0]["content"].append(message)
        # Return only the delta annotation
        annotation = {"type": "TERMINAL_INFO", "data": message}
        return f"8:[{json.dumps(annotation)}]\n"
    def format_sources_delta(self, sources: list[dict[str, Any]]) -> str:
        """
        Format sources as a delta annotation
        Args:
            sources: List of source objects
        Returns:
            str: The formatted annotation delta string
        """
        # Update internal state (replaces, rather than appends to, SOURCES)
        self.message_annotations[1]["content"] = sources
        # Flatten grouped sources into node records for the delta annotation
        nodes = []
        for group in sources:
            for source in group.get("sources", []):
                node = {
                    "id": str(source.get("id", "")),
                    "text": source.get("description", "").strip(),
                    "url": source.get("url", ""),
                    "metadata": {
                        "title": source.get("title", ""),
                        "source_type": group.get("type", ""),
                        "group_name": group.get("name", ""),
                    },
                }
                nodes.append(node)
        annotation = {"type": "sources", "data": {"nodes": nodes}}
        return f"8:[{json.dumps(annotation)}]\n"
    def format_answer_delta(self, answer_chunk: str) -> str:
        """
        Format a single answer chunk as a delta annotation
        Args:
            answer_chunk: The new answer chunk to add
        Returns:
            str: The formatted annotation delta string
        """
        # Update internal state by appending the chunk
        if isinstance(self.message_annotations[2]["content"], list):
            self.message_annotations[2]["content"].append(answer_chunk)
        else:
            self.message_annotations[2]["content"] = [answer_chunk]
        # Return only the delta annotation with the new chunk
        annotation = {"type": "ANSWER", "content": [answer_chunk]}
        return f"8:[{json.dumps(annotation)}]\n"
    def format_answer_annotation(self, answer_lines: list[str]) -> str:
        """
        Format the complete answer as a replacement annotation
        Args:
            answer_lines: Complete list of answer lines
        Returns:
            str: The formatted annotation string
        """
        # Update internal state
        self.message_annotations[2]["content"] = answer_lines
        # Return the full answer annotation
        annotation = {"type": "ANSWER", "content": answer_lines}
        return f"8:[{json.dumps(annotation)}]\n"
    def format_further_questions_delta(
        self, further_questions: list[dict[str, Any]]
    ) -> str:
        """
        Format further questions as a delta annotation
        Args:
            further_questions: List of further question objects
        Returns:
            str: The formatted annotation delta string
        """
        # Update internal state
        self.message_annotations[3]["content"] = further_questions
        # Return only the delta annotation; empty question strings are dropped
        annotation = {
            "type": "FURTHER_QUESTIONS",
            "data": [
                question.get("question", "")
                for question in further_questions
                if question.get("question", "") != ""
            ],
        }
        return f"8:[{json.dumps(annotation)}]\n"
    def format_text_chunk(self, text: str) -> str:
        """
        Format a text chunk using the text stream part
        Args:
            text: The text chunk to stream
        Returns:
            str: The formatted text part string
        """
        return f"0:{json.dumps(text)}\n"
    def format_error(self, error_message: str) -> str:
        """
        Format an error using the error stream part
        Args:
            error_message: The error message
        Returns:
            str: The formatted error part string
        """
        return f"3:{json.dumps(error_message)}\n"
    def format_completion(
        self, prompt_tokens: int = 156, completion_tokens: int = 204
    ) -> str:
        """
        Format a completion message
        Args:
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens
        Returns:
            str: The formatted completion string
        """
        # NOTE(review): the token defaults (156/204) look like placeholders —
        # callers should pass real usage numbers; confirm against call sites.
        total_tokens = prompt_tokens + completion_tokens
        completion_data = {
            "finishReason": "stop",
            "usage": {
                "promptTokens": prompt_tokens,
                "completionTokens": completion_tokens,
                "totalTokens": total_tokens,
            },
        }
        return f"d:{json.dumps(completion_data)}\n"

View file

@ -7,9 +7,12 @@ import sys
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
# Import for content-based podcast (new-chat)
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config
from app.tasks.podcast_tasks import generate_chat_podcast
from app.db import Podcast
logger = logging.getLogger(__name__)
@ -36,53 +39,140 @@ def get_celery_session_maker():
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="generate_chat_podcast", bind=True)
def generate_chat_podcast_task(
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================
def _clear_active_podcast_redis_key(search_space_id: int) -> None:
    """Best-effort removal of the "active podcast" marker key from Redis.

    Invoked when a podcast task completes (success or failure) so a new
    podcast can be started for the search space. Any failure is logged and
    swallowed on purpose: a stale marker must never crash the task itself.
    """
    import os

    import redis

    try:
        broker_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
        connection = redis.from_url(broker_url, decode_responses=True)
        connection.delete(f"podcast:active:{search_space_id}")
        logger.info(f"Cleared active podcast key for search_space_id={search_space_id}")
    except Exception as e:
        logger.warning(f"Could not clear active podcast key: {e}")
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
chat_id: int,
source_content: str,
search_space_id: int,
user_id: int,
podcast_title: str | None = None,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
):
) -> dict:
"""
Celery task to generate podcast from chat.
Celery task to generate podcast from source content (for new-chat).
This task generates a podcast directly from provided content.
Args:
chat_id: ID of the chat to generate podcast from
source_content: The text content to convert into a podcast
search_space_id: ID of the search space
user_id: ID of the user,
podcast_title: Title for the podcast
user_prompt: Optional prompt from the user to guide the podcast generation
user_prompt: Optional instructions for podcast style/tone
Returns:
dict with podcast_id on success, or error info on failure
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_generate_chat_podcast(
chat_id, search_space_id, user_id, podcast_title, user_prompt
result = loop.run_until_complete(
_generate_content_podcast(
source_content,
search_space_id,
podcast_title,
user_prompt,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
return {"status": "error", "error": str(e)}
finally:
# Always clear the active podcast key when task completes (success or failure)
_clear_active_podcast_redis_key(search_space_id)
asyncio.set_event_loop(None)
loop.close()
async def _generate_chat_podcast(
chat_id: int,
async def _generate_content_podcast(
source_content: str,
search_space_id: int,
user_id: int,
podcast_title: str | None = None,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
):
"""Generate chat podcast with new session."""
) -> dict:
"""Generate content-based podcast with new session."""
async with get_celery_session_maker()() as session:
try:
await generate_chat_podcast(
session, chat_id, search_space_id, user_id, podcast_title, user_prompt
# Configure the podcaster graph
graph_config = {
"configurable": {
"podcast_title": podcast_title,
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Initialize the podcaster state with the source content
initial_state = PodcasterState(
source_content=source_content,
db_session=session,
)
# Run the podcaster graph
result = await podcaster_graph.ainvoke(initial_state, config=graph_config)
# Extract results
podcast_transcript = result.get("podcast_transcript", [])
file_path = result.get("final_podcast_file_path", "")
# Convert transcript to serializable format
serializable_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
serializable_transcript.append(
{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
)
else:
serializable_transcript.append(
{
"speaker_id": entry.get("speaker_id", 0),
"dialog": entry.get("dialog", ""),
}
)
# Save podcast to database
podcast = Podcast(
title=podcast_title,
podcast_transcript=serializable_transcript,
file_location=file_path,
search_space_id=search_space_id,
)
session.add(podcast)
await session.commit()
await session.refresh(podcast)
logger.info(f"Successfully generated content podcast: {podcast.id}")
return {
"status": "success",
"podcast_id": podcast.id,
"title": podcast_title,
"transcript_entries": len(serializable_transcript),
}
except Exception as e:
logger.error(f"Error generating podcast from chat: {e!s}")
logger.error(f"Error in _generate_content_podcast: {e!s}")
await session.rollback()
raise

View file

@ -1,75 +0,0 @@
from collections.abc import AsyncGenerator
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.researcher.graph import graph as researcher_graph
from app.agents.researcher.state import State
from app.services.streaming_service import StreamingService
async def stream_connector_search_results(
    user_query: str,
    user_id: str | UUID,
    search_space_id: int,
    session: AsyncSession,
    research_mode: str,
    selected_connectors: list[str],
    langchain_chat_history: list[Any],
    document_ids_to_add_in_context: list[int],
    language: str | None = None,
    top_k: int = 10,
) -> AsyncGenerator[str, None]:
    """Run the researcher graph and stream its formatted output to the client.

    Args:
        user_query: The user's query text.
        user_id: Owner of the request; UUID objects are normalised to strings.
        search_space_id: Search space to scope retrieval to.
        session: Active async database session, handed to the graph state.
        research_mode: Requested research mode (not referenced in this
            function's body; the graph reads its settings from the config).
        selected_connectors: Names of the connectors to search.
        langchain_chat_history: Prior conversation in LangChain message form.
        document_ids_to_add_in_context: Documents to force into the context.
        language: Optional response-language hint passed to the graph.
        top_k: Result limit passed to the graph.

    Yields:
        str: Protocol-formatted response fragments, followed by a final
        completion frame.
    """
    service = StreamingService()

    # The graph configuration expects a plain string user id.
    normalized_user_id = str(user_id) if isinstance(user_id, UUID) else user_id

    graph_config = {
        "configurable": {
            "user_query": user_query,
            "connectors_to_search": selected_connectors,
            "user_id": normalized_user_id,
            "search_space_id": search_space_id,
            "document_ids_to_add_in_context": document_ids_to_add_in_context,
            "language": language,
            "top_k": top_k,
        }
    }

    # State carries the session and streaming service into the graph nodes.
    state = State(
        db_session=session,
        streaming_service=service,
        chat_history=langchain_chat_history,
    )

    print("\nRunning the complete researcher workflow...")

    # Custom stream mode: only dict chunks carrying "yield_value" are
    # forwarded to the client.
    async for chunk in researcher_graph.astream(
        state,
        config=graph_config,
        stream_mode="custom",
    ):
        if isinstance(chunk, dict) and "yield_value" in chunk:
            yield chunk["yield_value"]

    yield service.format_completion()

View file

@ -3,69 +3,115 @@ Streaming task for the new SurfSense deep agent chat.
This module streams responses from the deep agent using the Vercel AI SDK
Data Stream Protocol (SSE format).
Supports loading LLM configurations from:
- YAML files (negative IDs for global configs)
- NewLLMConfig database table (positive IDs for user-created configs with prompt settings)
"""
import json
from collections.abc import AsyncGenerator
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.chat_deepagent import (
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.llm_config import (
AgentConfig,
create_chat_litellm_from_agent_config,
create_chat_litellm_from_config,
create_surfsense_deep_agent,
load_agent_config,
load_llm_config_from_yaml,
)
from app.schemas.new_chat import ChatAttachment
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
def format_attachments_as_context(attachments: list[ChatAttachment]) -> str:
    """Render user attachments as an XML-style context block for the agent.

    Args:
        attachments: Uploaded attachments; each exposes ``name``, ``type``
            and ``content`` attributes.

    Returns:
        str: A ``<user_attachments>`` block with one ``<attachment>`` entry
        per item (content wrapped in CDATA), or the empty string when there
        are no attachments.
    """
    if not attachments:
        return ""

    lines = ["<user_attachments>"]
    for index, item in enumerate(attachments, start=1):
        lines.extend(
            [
                f"<attachment index='{index}' name='{item.name}' type='{item.type}'>",
                f"<![CDATA[{item.content}]]>",
                "</attachment>",
            ]
        )
    lines.append("</user_attachments>")
    return "\n".join(lines)
async def stream_new_chat(
user_query: str,
user_id: str | UUID,
search_space_id: int,
chat_id: int,
session: AsyncSession,
llm_config_id: int = -1,
attachments: list[ChatAttachment] | None = None,
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing,
so chat history is automatically managed by LangGraph.
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
Message history can be passed from the frontend for context.
Args:
user_query: The user's query
user_id: The user's ID (can be UUID object or string)
search_space_id: The search space ID
chat_id: The chat ID (used as LangGraph thread_id for memory)
session: The database session
llm_config_id: The LLM configuration ID (default: -1 for first global config)
messages: Optional chat history from frontend (list of ChatMessage)
Yields:
str: SSE formatted response strings
"""
streaming_service = VercelStreamingService()
# Convert UUID to string if needed
str(user_id) if isinstance(user_id, UUID) else user_id
# Track the current text block for streaming (defined early for exception handling)
current_text_id: str | None = None
try:
# Load LLM config
llm_config = load_llm_config_from_yaml(llm_config_id=llm_config_id)
if not llm_config:
yield streaming_service.format_error(
f"Failed to load LLM config with id {llm_config_id}"
)
yield streaming_service.format_done()
return
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
agent_config: AgentConfig | None = None
if llm_config_id >= 0:
# Positive ID: Load from NewLLMConfig database table
agent_config = await load_agent_config(
session=session,
config_id=llm_config_id,
search_space_id=search_space_id,
)
if not agent_config:
yield streaming_service.format_error(
f"Failed to load NewLLMConfig with id {llm_config_id}"
)
yield streaming_service.format_done()
return
# Create ChatLiteLLM from AgentConfig
llm = create_chat_litellm_from_agent_config(agent_config)
else:
# Negative ID: Load from YAML (global configs)
llm_config = load_llm_config_from_yaml(llm_config_id=llm_config_id)
if not llm_config:
yield streaming_service.format_error(
f"Failed to load LLM config with id {llm_config_id}"
)
yield streaming_service.format_done()
return
# Create ChatLiteLLM from YAML config dict
llm = create_chat_litellm_from_config(llm_config)
# Create AgentConfig from YAML for consistency (uses defaults for prompt settings)
agent_config = AgentConfig.from_yaml_config(llm_config)
# Create ChatLiteLLM instance
llm = create_chat_litellm_from_config(llm_config)
if not llm:
yield streaming_service.format_error("Failed to create LLM instance")
yield streaming_service.format_done()
@ -74,18 +120,45 @@ async def stream_new_chat(
# Create connector service
connector_service = ConnectorService(session, search_space_id=search_space_id)
# Create the deep agent
# Get the PostgreSQL checkpointer for persistent conversation memory
checkpointer = await get_checkpointer()
# Create the deep agent with checkpointer and configurable prompts
agent = create_surfsense_deep_agent(
llm=llm,
search_space_id=search_space_id,
db_session=session,
connector_service=connector_service,
checkpointer=checkpointer,
agent_config=agent_config, # Pass prompt configuration
)
# Build input with just the current user query
# Chat history is managed by LangGraph via thread_id
# Build input with message history from frontend
langchain_messages = []
# Format the user query with attachment context if any
final_query = user_query
if attachments:
attachment_context = format_attachments_as_context(attachments)
final_query = (
f"{attachment_context}\n\n<user_query>{user_query}</user_query>"
)
# if messages:
# # Convert frontend messages to LangChain format
# for msg in messages:
# if msg.role == "user":
# langchain_messages.append(HumanMessage(content=msg.content))
# elif msg.role == "assistant":
# langchain_messages.append(AIMessage(content=msg.content))
# else:
# Fallback: just use the current user query with attachment context
langchain_messages.append(HumanMessage(content=final_query))
input_state = {
"messages": [HumanMessage(content=user_query)],
# Lets not pass this message atm because we are using the checkpointer to manage the conversation history
# We will use this to simulate group chat functionality in the future
"messages": langchain_messages,
"search_space_id": search_space_id,
}
@ -103,6 +176,51 @@ async def stream_new_chat(
# Reset text tracking for this stream
accumulated_text = ""
# Track thinking steps for chain-of-thought display
thinking_step_counter = 0
# Map run_id -> step_id for tool calls so we can update them on completion
tool_step_ids: dict[str, str] = {}
# Track the last active step so we can mark it complete at the end
last_active_step_id: str | None = None
last_active_step_title: str = ""
last_active_step_items: list[str] = []
# Track which steps have been completed to avoid duplicate completions
completed_step_ids: set[str] = set()
# Track if we just finished a tool (text flows silently after tools)
just_finished_tool: bool = False
def next_thinking_step_id() -> str:
nonlocal thinking_step_counter
thinking_step_counter += 1
return f"thinking-{thinking_step_counter}"
def complete_current_step() -> str | None:
"""Complete the current active step and return the completion event, if any."""
nonlocal last_active_step_id, last_active_step_title, last_active_step_items
if last_active_step_id and last_active_step_id not in completed_step_ids:
completed_step_ids.add(last_active_step_id)
return streaming_service.format_thinking_step(
step_id=last_active_step_id,
title=last_active_step_title,
status="completed",
items=last_active_step_items if last_active_step_items else None,
)
return None
# Initial thinking step - analyzing the request
analyze_step_id = next_thinking_step_id()
last_active_step_id = analyze_step_id
last_active_step_title = "Understanding your request"
last_active_step_items = [
f"Processing: {user_query[:80]}{'...' if len(user_query) > 80 else ''}"
]
yield streaming_service.format_thinking_step(
step_id=analyze_step_id,
title="Understanding your request",
status="in_progress",
items=last_active_step_items,
)
# Stream the agent response with thread config for memory
async for event in agent.astream_events(
input_state, config=config, version="v2"
@ -117,6 +235,18 @@ async def stream_new_chat(
if content and isinstance(content, str):
# Start a new text block if needed
if current_text_id is None:
# Complete any previous step
completion_event = complete_current_step()
if completion_event:
yield completion_event
if just_finished_tool:
# Clear the active step tracking - text flows without a dedicated step
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
just_finished_tool = False
current_text_id = streaming_service.generate_text_id()
yield streaming_service.format_text_start(current_text_id)
@ -137,6 +267,122 @@ async def stream_new_chat(
yield streaming_service.format_text_end(current_text_id)
current_text_id = None
# Complete any previous step EXCEPT "Synthesizing response"
# (we want to reuse the Synthesizing step after tools complete)
if last_active_step_title != "Synthesizing response":
completion_event = complete_current_step()
if completion_event:
yield completion_event
# Reset the just_finished_tool flag since we're starting a new tool
just_finished_tool = False
# Create thinking step for the tool call and store it for later update
tool_step_id = next_thinking_step_id()
tool_step_ids[run_id] = tool_step_id
last_active_step_id = tool_step_id
if tool_name == "search_knowledge_base":
query = (
tool_input.get("query", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
last_active_step_title = "Searching knowledge base"
last_active_step_items = [
f"Query: {query[:100]}{'...' if len(query) > 100 else ''}"
]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Searching knowledge base",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "link_preview":
url = (
tool_input.get("url", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
last_active_step_title = "Fetching link preview"
last_active_step_items = [
f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Fetching link preview",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "display_image":
src = (
tool_input.get("src", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
title = (
tool_input.get("title", "")
if isinstance(tool_input, dict)
else ""
)
last_active_step_title = "Displaying image"
last_active_step_items = [
f"Image: {title[:50] if title else src[:50]}{'...' if len(title or src) > 50 else ''}"
]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Displaying image",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "scrape_webpage":
url = (
tool_input.get("url", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
last_active_step_title = "Scraping webpage"
last_active_step_items = [
f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Scraping webpage",
status="in_progress",
items=last_active_step_items,
)
elif tool_name == "generate_podcast":
podcast_title = (
tool_input.get("podcast_title", "SurfSense Podcast")
if isinstance(tool_input, dict)
else "SurfSense Podcast"
)
# Get content length for context
content_len = len(
tool_input.get("source_content", "")
if isinstance(tool_input, dict)
else ""
)
last_active_step_title = "Generating podcast"
last_active_step_items = [
f"Title: {podcast_title}",
f"Content: {content_len:,} characters",
"Preparing audio generation...",
]
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title="Generating podcast",
status="in_progress",
items=last_active_step_items,
)
else:
last_active_step_title = f"Using {tool_name.replace('_', ' ')}"
last_active_step_items = []
yield streaming_service.format_thinking_step(
step_id=tool_step_id,
title=last_active_step_title,
status="in_progress",
)
# Stream tool info
tool_call_id = (
f"call_{run_id[:32]}"
@ -163,22 +409,358 @@ async def stream_new_chat(
f"Searching knowledge base: {query[:100]}{'...' if len(query) > 100 else ''}",
"info",
)
elif tool_name == "link_preview":
url = (
tool_input.get("url", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
yield streaming_service.format_terminal_info(
f"Fetching link preview: {url[:80]}{'...' if len(url) > 80 else ''}",
"info",
)
elif tool_name == "display_image":
src = (
tool_input.get("src", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
yield streaming_service.format_terminal_info(
f"Displaying image: {src[:60]}{'...' if len(src) > 60 else ''}",
"info",
)
elif tool_name == "scrape_webpage":
url = (
tool_input.get("url", "")
if isinstance(tool_input, dict)
else str(tool_input)
)
yield streaming_service.format_terminal_info(
f"Scraping webpage: {url[:70]}{'...' if len(url) > 70 else ''}",
"info",
)
elif tool_name == "generate_podcast":
title = (
tool_input.get("podcast_title", "SurfSense Podcast")
if isinstance(tool_input, dict)
else "SurfSense Podcast"
)
yield streaming_service.format_terminal_info(
f"Generating podcast: {title}",
"info",
)
elif event_type == "on_tool_end":
run_id = event.get("run_id", "")
tool_output = event.get("data", {}).get("output", "")
tool_name = event.get("name", "unknown_tool")
raw_output = event.get("data", {}).get("output", "")
# Extract content from ToolMessage if needed
# LangGraph may return a ToolMessage object instead of raw dict
if hasattr(raw_output, "content"):
# It's a ToolMessage object - extract the content
content = raw_output.content
# If content is a string that looks like JSON, try to parse it
if isinstance(content, str):
try:
tool_output = json.loads(content)
except (json.JSONDecodeError, TypeError):
tool_output = {"result": content}
elif isinstance(content, dict):
tool_output = content
else:
tool_output = {"result": str(content)}
elif isinstance(raw_output, dict):
tool_output = raw_output
else:
tool_output = {
"result": str(raw_output) if raw_output else "completed"
}
tool_call_id = f"call_{run_id[:32]}" if run_id else "call_unknown"
# Don't stream the full output (can be very large), just acknowledge
yield streaming_service.format_tool_output_available(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
# Get the original tool step ID to update it (not create a new one)
original_step_id = tool_step_ids.get(
run_id, f"thinking-unknown-{run_id[:8]}"
)
yield streaming_service.format_terminal_info(
"Knowledge base search completed", "success"
)
# Mark the tool thinking step as completed using the SAME step ID
# Also add to completed set so we don't try to complete it again
completed_step_ids.add(original_step_id)
if tool_name == "search_knowledge_base":
# Get result count if available
result_info = "Search completed"
if isinstance(tool_output, dict):
result_len = tool_output.get("result_length", 0)
if result_len > 0:
result_info = (
f"Found relevant information ({result_len} chars)"
)
# Include original query in completed items
completed_items = [*last_active_step_items, result_info]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Searching knowledge base",
status="completed",
items=completed_items,
)
elif tool_name == "link_preview":
# Build completion items based on link preview result
if isinstance(tool_output, dict):
title = tool_output.get("title", "Link")
domain = tool_output.get("domain", "")
has_error = "error" in tool_output
if has_error:
completed_items = [
*last_active_step_items,
f"Error: {tool_output.get('error', 'Failed to fetch')}",
]
else:
completed_items = [
*last_active_step_items,
f"Title: {title[:60]}{'...' if len(title) > 60 else ''}",
f"Domain: {domain}" if domain else "Preview loaded",
]
else:
completed_items = [*last_active_step_items, "Preview loaded"]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Fetching link preview",
status="completed",
items=completed_items,
)
elif tool_name == "display_image":
# Build completion items for image display
if isinstance(tool_output, dict):
title = tool_output.get("title", "")
alt = tool_output.get("alt", "Image")
display_name = title or alt
completed_items = [
*last_active_step_items,
f"Showing: {display_name[:50]}{'...' if len(display_name) > 50 else ''}",
]
else:
completed_items = [*last_active_step_items, "Image displayed"]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Displaying image",
status="completed",
items=completed_items,
)
elif tool_name == "scrape_webpage":
# Build completion items for webpage scraping
if isinstance(tool_output, dict):
title = tool_output.get("title", "Webpage")
word_count = tool_output.get("word_count", 0)
has_error = "error" in tool_output
if has_error:
completed_items = [
*last_active_step_items,
f"Error: {tool_output.get('error', 'Failed to scrape')[:50]}",
]
else:
completed_items = [
*last_active_step_items,
f"Title: {title[:50]}{'...' if len(title) > 50 else ''}",
f"Extracted: {word_count:,} words",
]
else:
completed_items = [*last_active_step_items, "Content extracted"]
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Scraping webpage",
status="completed",
items=completed_items,
)
elif tool_name == "generate_podcast":
# Build detailed completion items based on podcast status
podcast_status = (
tool_output.get("status", "unknown")
if isinstance(tool_output, dict)
else "unknown"
)
podcast_title = (
tool_output.get("title", "Podcast")
if isinstance(tool_output, dict)
else "Podcast"
)
if podcast_status == "processing":
completed_items = [
f"Title: {podcast_title}",
"Audio generation started",
"Processing in background...",
]
elif podcast_status == "already_generating":
completed_items = [
f"Title: {podcast_title}",
"Podcast already in progress",
"Please wait for it to complete",
]
elif podcast_status == "error":
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
else "Unknown error"
)
completed_items = [
f"Title: {podcast_title}",
f"Error: {error_msg[:50]}",
]
else:
completed_items = last_active_step_items
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title="Generating podcast",
status="completed",
items=completed_items,
)
else:
yield streaming_service.format_thinking_step(
step_id=original_step_id,
title=f"Using {tool_name.replace('_', ' ')}",
status="completed",
items=last_active_step_items,
)
# Mark that we just finished a tool - "Synthesizing response" will be created
# when text actually starts flowing (not immediately)
just_finished_tool = True
# Clear the active step since the tool is done
last_active_step_id = None
last_active_step_title = ""
last_active_step_items = []
# Handle different tool outputs
if tool_name == "generate_podcast":
# Stream the full podcast result so frontend can render the audio player
yield streaming_service.format_tool_output_available(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
else {"result": tool_output},
)
# Send appropriate terminal message based on status
if (
isinstance(tool_output, dict)
and tool_output.get("status") == "success"
):
yield streaming_service.format_terminal_info(
f"Podcast generated successfully: {tool_output.get('title', 'Podcast')}",
"success",
)
else:
error_msg = (
tool_output.get("error", "Unknown error")
if isinstance(tool_output, dict)
else "Unknown error"
)
yield streaming_service.format_terminal_info(
f"Podcast generation failed: {error_msg}",
"error",
)
elif tool_name == "link_preview":
# Stream the full link preview result so frontend can render the MediaCard
yield streaming_service.format_tool_output_available(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
else {"result": tool_output},
)
# Send appropriate terminal message
if isinstance(tool_output, dict) and "error" not in tool_output:
title = tool_output.get("title", "Link")
yield streaming_service.format_terminal_info(
f"Link preview loaded: {title[:50]}{'...' if len(title) > 50 else ''}",
"success",
)
else:
error_msg = (
tool_output.get("error", "Failed to fetch")
if isinstance(tool_output, dict)
else "Failed to fetch"
)
yield streaming_service.format_terminal_info(
f"Link preview failed: {error_msg}",
"error",
)
elif tool_name == "display_image":
# Stream the full image result so frontend can render the Image component
yield streaming_service.format_tool_output_available(
tool_call_id,
tool_output
if isinstance(tool_output, dict)
else {"result": tool_output},
)
# Send terminal message
if isinstance(tool_output, dict):
title = tool_output.get("title") or tool_output.get(
"alt", "Image"
)
yield streaming_service.format_terminal_info(
f"Image displayed: {title[:40]}{'...' if len(title) > 40 else ''}",
"success",
)
elif tool_name == "scrape_webpage":
# Stream the scrape result so frontend can render the Article component
# Note: We send metadata for display, but content goes to LLM for processing
if isinstance(tool_output, dict):
# Create a display-friendly output (without full content for the card)
display_output = {
k: v for k, v in tool_output.items() if k != "content"
}
# But keep a truncated content preview
if "content" in tool_output:
content = tool_output.get("content", "")
display_output["content_preview"] = (
content[:500] + "..." if len(content) > 500 else content
)
yield streaming_service.format_tool_output_available(
tool_call_id,
display_output,
)
else:
yield streaming_service.format_tool_output_available(
tool_call_id,
{"result": tool_output},
)
# Send terminal message
if isinstance(tool_output, dict) and "error" not in tool_output:
title = tool_output.get("title", "Webpage")
word_count = tool_output.get("word_count", 0)
yield streaming_service.format_terminal_info(
f"Scraped: {title[:40]}{'...' if len(title) > 40 else ''} ({word_count:,} words)",
"success",
)
else:
error_msg = (
tool_output.get("error", "Failed to scrape")
if isinstance(tool_output, dict)
else "Failed to scrape"
)
yield streaming_service.format_terminal_info(
f"Scrape failed: {error_msg}",
"error",
)
elif tool_name == "search_knowledge_base":
# Don't stream the full output for search (can be very large), just acknowledge
yield streaming_service.format_tool_output_available(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
)
yield streaming_service.format_terminal_info(
"Knowledge base search completed", "success"
)
else:
# Default handling for other tools
yield streaming_service.format_tool_output_available(
tool_call_id,
{"status": "completed", "result_length": len(str(tool_output))},
)
yield streaming_service.format_terminal_info(
f"Tool {tool_name} completed", "success"
)
# Handle chain/agent end to close any open text blocks
elif event_type in ("on_chain_end", "on_agent_end"):
@ -190,6 +772,11 @@ async def stream_new_chat(
if current_text_id is not None:
yield streaming_service.format_text_end(current_text_id)
# Mark the last active thinking step as completed using the same title
completion_event = complete_current_step()
if completion_event:
yield completion_event
# Finish the step and message
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()

View file

@ -1,211 +0,0 @@
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State
from app.db import Chat, Podcast
from app.services.task_logging_service import TaskLoggingService
async def generate_chat_podcast(
    session: AsyncSession,
    chat_id: int,
    search_space_id: int,
    user_id: int,
    podcast_title: str | None = None,
    user_prompt: str | None = None,
):
    """Generate (or regenerate) a podcast episode from a chat's message history.

    Fetches the chat, serializes its user/assistant messages into an XML-like
    transcript string, runs the SurfSense podcaster graph on it, and persists
    the resulting transcript and audio file location as a ``Podcast`` row.
    If a podcast already exists for the chat, it is updated in place
    (regeneration) instead of creating a duplicate.

    Args:
        session: Async SQLAlchemy session used for all database access.
        chat_id: ID of the chat to turn into a podcast.
        search_space_id: Search space the chat must belong to (scoping check).
        user_id: Owner of the chat; forwarded to the podcaster graph config.
        podcast_title: Optional episode title; defaults to "SurfSense Podcast".
        user_prompt: Optional free-form instructions forwarded to the graph.

    Returns:
        The created or updated ``Podcast`` ORM instance.

    Raises:
        ValueError: If the chat does not exist in the given search space.
        SQLAlchemyError: On database failures (session is rolled back first).
        RuntimeError: For any other unexpected failure during generation.
    """
    task_logger = TaskLoggingService(session, search_space_id)

    # Record the start of the task so progress/failure can be traced later.
    log_entry = await task_logger.log_task_start(
        task_name="generate_chat_podcast",
        source="podcast_task",
        message=f"Starting podcast generation for chat {chat_id}",
        metadata={
            "chat_id": chat_id,
            "search_space_id": search_space_id,
            "podcast_title": podcast_title,
            "user_id": str(user_id),
            "user_prompt": user_prompt,
        },
    )

    try:
        # Fetch the chat, scoped to the search space so users cannot reach
        # chats outside their space.
        await task_logger.log_task_progress(
            log_entry, f"Fetching chat {chat_id} from database", {"stage": "fetch_chat"}
        )

        query = select(Chat).filter(
            Chat.id == chat_id, Chat.search_space_id == search_space_id
        )
        result = await session.execute(query)
        chat = result.scalars().first()

        if not chat:
            await task_logger.log_task_failure(
                log_entry,
                f"Chat with id {chat_id} not found in search space {search_space_id}",
                "Chat not found",
                {"error_type": "ChatNotFound"},
            )
            raise ValueError(
                f"Chat with id {chat_id} not found in search space {search_space_id}"
            )

        # Serialize the chat history into the tagged-text format the
        # podcaster graph expects as source content.
        await task_logger.log_task_progress(
            log_entry,
            f"Processing chat history for chat {chat_id}",
            {"stage": "process_chat_history", "message_count": len(chat.messages)},
        )

        chat_history_str = "<chat_history>"
        processed_messages = 0
        for message in chat.messages:
            if message["role"] == "user":
                chat_history_str += f"<user_message>{message['content']}</user_message>"
                processed_messages += 1
            elif message["role"] == "assistant":
                chat_history_str += (
                    f"<assistant_message>{message['content']}</assistant_message>"
                )
                processed_messages += 1
        chat_history_str += "</chat_history>"

        # Hand the serialized history to the SurfSense podcaster graph.
        await task_logger.log_task_progress(
            log_entry,
            f"Initializing podcast generation for chat {chat_id}",
            {
                "stage": "initialize_podcast_generation",
                "processed_messages": processed_messages,
                "content_length": len(chat_history_str),
            },
        )

        config = {
            "configurable": {
                "podcast_title": podcast_title or "SurfSense Podcast",
                "user_id": str(user_id),
                "search_space_id": search_space_id,
                "user_prompt": user_prompt,
            }
        }

        # Initialize state with the source content and database session.
        initial_state = State(source_content=chat_history_str, db_session=session)

        # Run the graph to completion (no streaming needed for背 background task).
        await task_logger.log_task_progress(
            log_entry,
            f"Running podcast generation graph for chat {chat_id}",
            {"stage": "run_podcast_graph"},
        )

        result = await podcaster_graph.ainvoke(initial_state, config=config)

        # Flatten transcript entry objects into JSON-serializable dicts
        # suitable for a JSON column.
        await task_logger.log_task_progress(
            log_entry,
            f"Processing podcast transcript for chat {chat_id}",
            {
                "stage": "process_transcript",
                "transcript_entries": len(result["podcast_transcript"]),
            },
        )

        serializable_transcript = [
            {"speaker_id": entry.speaker_id, "dialog": entry.dialog}
            for entry in result["podcast_transcript"]
        ]

        await task_logger.log_task_progress(
            log_entry,
            f"Creating podcast database entry for chat {chat_id}",
            {
                "stage": "create_podcast_entry",
                "file_location": result.get("final_podcast_file_path"),
            },
        )

        # If a podcast already exists for this chat, update it in place
        # (re-generation) rather than inserting a duplicate row.
        existing_result = await session.execute(
            select(Podcast).filter(Podcast.chat_id == chat_id)
        )
        existing_podcast = existing_result.scalars().first()

        if existing_podcast:
            existing_podcast.podcast_transcript = serializable_transcript
            existing_podcast.file_location = result["final_podcast_file_path"]
            existing_podcast.chat_state_version = chat.state_version
            await session.commit()
            await session.refresh(existing_podcast)
            return existing_podcast
        else:
            # FIX: previously `title=f"{podcast_title}"`, which stored the
            # literal string "None" when no title was supplied. Fall back to
            # the same default used in the graph config above.
            podcast = Podcast(
                title=podcast_title or "SurfSense Podcast",
                podcast_transcript=serializable_transcript,
                file_location=result["final_podcast_file_path"],
                search_space_id=search_space_id,
                chat_state_version=chat.state_version,
                chat_id=chat.id,
            )

            session.add(podcast)
            await session.commit()
            await session.refresh(podcast)

        # Log success (only reached on the newly-created-podcast path;
        # regeneration returns early above, matching prior behavior).
        await task_logger.log_task_success(
            log_entry,
            f"Successfully generated podcast for chat {chat_id}",
            {
                "podcast_id": podcast.id,
                "podcast_title": podcast_title,
                "transcript_entries": len(serializable_transcript),
                "file_location": result.get("final_podcast_file_path"),
                "processed_messages": processed_messages,
                "content_length": len(chat_history_str),
            },
        )

        return podcast

    except ValueError as ve:
        # The chat-not-found case was already logged above; only log other
        # ValueErrors here to avoid duplicate failure entries.
        if "not found" not in str(ve):
            await task_logger.log_task_failure(
                log_entry,
                f"Value error during podcast generation for chat {chat_id}",
                str(ve),
                {"error_type": "ValueError"},
            )
        raise ve
    except SQLAlchemyError as db_error:
        await session.rollback()
        await task_logger.log_task_failure(
            log_entry,
            f"Database error during podcast generation for chat {chat_id}",
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        raise db_error
    except Exception as e:
        await session.rollback()
        await task_logger.log_task_failure(
            log_entry,
            f"Unexpected error during podcast generation for chat {chat_id}",
            str(e),
            {"error_type": type(e).__name__},
        )
        raise RuntimeError(
            f"Failed to generate podcast for chat {chat_id}: {e!s}"
        ) from e

View file

@ -1,9 +1,15 @@
import argparse
import asyncio
import logging
import sys
import uvicorn
from dotenv import load_dotenv
# Fix for Windows: psycopg requires SelectorEventLoop, not ProactorEventLoop
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
from app.config.uvicorn import load_uvicorn_config
logging.basicConfig(

View file

@ -54,6 +54,8 @@ dependencies = [
"trafilatura>=2.0.0",
"fastapi-users[oauth,sqlalchemy]>=15.0.3",
"chonkie[all]>=1.5.0",
"langgraph-checkpoint-postgres>=3.0.2",
"psycopg[binary,pool]>=3.3.2",
]
[dependency-groups]

View file

@ -2983,6 +2983,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/48/e3/616e3a7ff737d98c1bbb5700dd62278914e2a9ded09a79a1fa93cf24ce12/langgraph_checkpoint-3.0.1-py3-none-any.whl", hash = "sha256:9b04a8d0edc0474ce4eaf30c5d731cee38f11ddff50a6177eead95b5c4e4220b", size = 46249 },
]
[[package]]
name = "langgraph-checkpoint-postgres"
version = "3.0.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langgraph-checkpoint" },
{ name = "orjson" },
{ name = "psycopg" },
{ name = "psycopg-pool" },
]
sdist = { url = "https://files.pythonhosted.org/packages/68/4e/ffea5b0d667e10d408b3b2d6dd967ea79e208eef73fe6ee5622625496238/langgraph_checkpoint_postgres-3.0.2.tar.gz", hash = "sha256:448cb8ec245b6fe10171a0f90e9aa047e24a9d3febba6a914644b0c1323da158", size = 127766 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ac/e4/b4248e10289b6e2c2d33586c87c5eb421e566ef5f336ee45269223cc3b92/langgraph_checkpoint_postgres-3.0.2-py3-none-any.whl", hash = "sha256:15c0fb638edfbc54d496f1758d0327d1a081e0ef94dda8f0c91d4b307d6d8545", size = 42710 },
]
[[package]]
name = "langgraph-prebuilt"
version = "1.0.5"
@ -4785,6 +4800,79 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885 },
]
[[package]]
name = "psycopg"
version = "3.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
{ name = "tzdata", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e0/1a/7d9ef4fdc13ef7f15b934c393edc97a35c281bb7d3c3329fbfcbe915a7c2/psycopg-3.3.2.tar.gz", hash = "sha256:707a67975ee214d200511177a6a80e56e654754c9afca06a7194ea6bbfde9ca7", size = 165630 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8c/51/2779ccdf9305981a06b21a6b27e8547c948d85c41c76ff434192784a4c93/psycopg-3.3.2-py3-none-any.whl", hash = "sha256:3e94bc5f4690247d734599af56e51bae8e0db8e4311ea413f801fef82b14a99b", size = 212774 },
]
[package.optional-dependencies]
binary = [
{ name = "psycopg-binary", marker = "implementation_name != 'pypy'" },
]
pool = [
{ name = "psycopg-pool" },
]
[[package]]
name = "psycopg-binary"
version = "3.3.2"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4e/1e/8614b01c549dd7e385dacdcd83fe194f6b3acb255a53cc67154ee6bf00e7/psycopg_binary-3.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a9387ab615f929e71ef0f4a8a51e986fa06236ccfa9f3ec98a88f60fbf230634", size = 4579832 },
{ url = "https://files.pythonhosted.org/packages/26/97/0bb093570fae2f4454d42c1ae6000f15934391867402f680254e4a7def54/psycopg_binary-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3ff7489df5e06c12d1829544eaec64970fe27fe300f7cf04c8495fe682064688", size = 4658786 },
{ url = "https://files.pythonhosted.org/packages/61/20/1d9383e3f2038826900a14137b0647d755f67551aab316e1021443105ed5/psycopg_binary-3.3.2-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:9742580ecc8e1ac45164e98d32ca6df90da509c2d3ff26be245d94c430f92db4", size = 5454896 },
{ url = "https://files.pythonhosted.org/packages/a6/62/513c80ad8bbb545e364f7737bf2492d34a4c05eef4f7b5c16428dc42260d/psycopg_binary-3.3.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d45acedcaa58619355f18e0f42af542fcad3fd84ace4b8355d3a5dea23318578", size = 5132731 },
{ url = "https://files.pythonhosted.org/packages/f3/28/ddf5f5905f088024bccb19857949467407c693389a14feb527d6171d8215/psycopg_binary-3.3.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d88f32ff8c47cb7f4e7e7a9d1747dcee6f3baa19ed9afa9e5694fd2fb32b61ed", size = 6724495 },
{ url = "https://files.pythonhosted.org/packages/6e/93/a1157ebcc650960b264542b547f7914d87a42ff0cc15a7584b29d5807e6b/psycopg_binary-3.3.2-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:59d0163c4617a2c577cb34afbed93d7a45b8c8364e54b2bd2020ff25d5f5f860", size = 4964979 },
{ url = "https://files.pythonhosted.org/packages/0e/27/65939ba6798f9c5be4a5d9cd2061ebaf0851798525c6811d347821c8132d/psycopg_binary-3.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e750afe74e6c17b2c7046d2c3e3173b5a3f6080084671c8aa327215323df155b", size = 4493648 },
{ url = "https://files.pythonhosted.org/packages/8a/c4/5e9e4b9b1c1e27026e43387b0ba4aaf3537c7806465dd3f1d5bde631752a/psycopg_binary-3.3.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f26f113013c4dcfbfe9ced57b5bad2035dda1a7349f64bf726021968f9bccad3", size = 4173392 },
{ url = "https://files.pythonhosted.org/packages/c6/81/cf43fb76993190cee9af1cbcfe28afb47b1928bdf45a252001017e5af26e/psycopg_binary-3.3.2-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8309ee4569dced5e81df5aa2dcd48c7340c8dee603a66430f042dfbd2878edca", size = 3909241 },
{ url = "https://files.pythonhosted.org/packages/9d/20/c6377a0d17434674351627489deca493ea0b137c522b99c81d3a106372c8/psycopg_binary-3.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c6464150e25b68ae3cb04c4e57496ea11ebfaae4d98126aea2f4702dd43e3c12", size = 4219746 },
{ url = "https://files.pythonhosted.org/packages/25/32/716c57b28eefe02a57a4c9d5bf956849597f5ea476c7010397199e56cfde/psycopg_binary-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:716a586f99bbe4f710dc58b40069fcb33c7627e95cc6fc936f73c9235e07f9cf", size = 3537494 },
{ url = "https://files.pythonhosted.org/packages/14/73/7ca7cb22b9ac7393fb5de7d28ca97e8347c375c8498b3bff2c99c1f38038/psycopg_binary-3.3.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fc5a189e89cbfff174588665bb18d28d2d0428366cc9dae5864afcaa2e57380b", size = 4579068 },
{ url = "https://files.pythonhosted.org/packages/f5/42/0cf38ff6c62c792fc5b55398a853a77663210ebd51ed6f0c4a05b06f95a6/psycopg_binary-3.3.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:083c2e182be433f290dc2c516fd72b9b47054fcd305cce791e0a50d9e93e06f2", size = 4657520 },
{ url = "https://files.pythonhosted.org/packages/3b/60/df846bc84cbf2231e01b0fff48b09841fe486fa177665e50f4995b1bfa44/psycopg_binary-3.3.2-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:ac230e3643d1c436a2dfb59ca84357dfc6862c9f372fc5dbd96bafecae581f9f", size = 5452086 },
{ url = "https://files.pythonhosted.org/packages/ab/85/30c846a00db86b1b53fd5bfd4b4edfbd0c00de8f2c75dd105610bd7568fc/psycopg_binary-3.3.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d8c899a540f6c7585cee53cddc929dd4d2db90fd828e37f5d4017b63acbc1a5d", size = 5131125 },
{ url = "https://files.pythonhosted.org/packages/6d/15/9968732013373f36f8a2a3fb76104dffc8efd9db78709caa5ae1a87b1f80/psycopg_binary-3.3.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:50ff10ab8c0abdb5a5451b9315538865b50ba64c907742a1385fdf5f5772b73e", size = 6722914 },
{ url = "https://files.pythonhosted.org/packages/b2/ba/29e361fe02143ac5ff5a1ca3e45697344cfbebe2eaf8c4e7eec164bff9a0/psycopg_binary-3.3.2-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:23d2594af848c1fd3d874a9364bef50730124e72df7bb145a20cb45e728c50ed", size = 4966081 },
{ url = "https://files.pythonhosted.org/packages/99/45/1be90c8f1a1a237046903e91202fb06708745c179f220b361d6333ed7641/psycopg_binary-3.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ea4fe6b4ead3bbbe27244ea224fcd1f53cb119afc38b71a2f3ce570149a03e30", size = 4493332 },
{ url = "https://files.pythonhosted.org/packages/2e/b5/bbdc07d5f0a5e90c617abd624368182aa131485e18038b2c6c85fc054aed/psycopg_binary-3.3.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:742ce48cde825b8e52fb1a658253d6d1ff66d152081cbc76aa45e2986534858d", size = 4170781 },
{ url = "https://files.pythonhosted.org/packages/d1/2a/0d45e4f4da2bd78c3237ffa03475ef3751f69a81919c54a6e610eb1a7c96/psycopg_binary-3.3.2-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:e22bf6b54df994aff37ab52695d635f1ef73155e781eee1f5fa75bc08b58c8da", size = 3910544 },
{ url = "https://files.pythonhosted.org/packages/3a/62/a8e0f092f4dbef9a94b032fb71e214cf0a375010692fbe7493a766339e47/psycopg_binary-3.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8db9034cde3bcdafc66980f0130813f5c5d19e74b3f2a19fb3cfbc25ad113121", size = 4220070 },
{ url = "https://files.pythonhosted.org/packages/09/e6/5fc8d8aff8afa114bb4a94a0341b9309311e8bf3ab32d816032f8b984d4e/psycopg_binary-3.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:df65174c7cf6b05ea273ce955927d3270b3a6e27b0b12762b009ce6082b8d3fc", size = 3540922 },
{ url = "https://files.pythonhosted.org/packages/bd/75/ad18c0b97b852aba286d06befb398cc6d383e9dfd0a518369af275a5a526/psycopg_binary-3.3.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9ca24062cd9b2270e4d77576042e9cc2b1d543f09da5aba1f1a3d016cea28390", size = 4596371 },
{ url = "https://files.pythonhosted.org/packages/5a/79/91649d94c8d89f84af5da7c9d474bfba35b08eb8f492ca3422b08f0a6427/psycopg_binary-3.3.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c749770da0947bc972e512f35366dd4950c0e34afad89e60b9787a37e97cb443", size = 4675139 },
{ url = "https://files.pythonhosted.org/packages/56/ac/b26e004880f054549ec9396594e1ffe435810b0673e428e619ed722e4244/psycopg_binary-3.3.2-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:03b7cd73fb8c45d272a34ae7249713e32492891492681e3cf11dff9531cf37e9", size = 5456120 },
{ url = "https://files.pythonhosted.org/packages/4b/8d/410681dccd6f2999fb115cc248521ec50dd2b0aba66ae8de7e81efdebbee/psycopg_binary-3.3.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:43b130e3b6edcb5ee856c7167ccb8561b473308c870ed83978ae478613764f1c", size = 5133484 },
{ url = "https://files.pythonhosted.org/packages/66/30/ebbab99ea2cfa099d7b11b742ce13415d44f800555bfa4ad2911dc645b71/psycopg_binary-3.3.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7c1feba5a8c617922321aef945865334e468337b8fc5c73074f5e63143013b5a", size = 6731818 },
{ url = "https://files.pythonhosted.org/packages/70/02/d260646253b7ad805d60e0de47f9b811d6544078452579466a098598b6f4/psycopg_binary-3.3.2-cp314-cp314-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cabb2a554d9a0a6bf84037d86ca91782f087dfff2a61298d0b00c19c0bc43f6d", size = 4983859 },
{ url = "https://files.pythonhosted.org/packages/72/8d/e778d7bad1a7910aa36281f092bd85c5702f508fd9bb0ea2020ffbb6585c/psycopg_binary-3.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:74bc306c4b4df35b09bc8cecf806b271e1c5d708f7900145e4e54a2e5dedfed0", size = 4516388 },
{ url = "https://files.pythonhosted.org/packages/bd/f1/64e82098722e2ab3521797584caf515284be09c1e08a872551b6edbb0074/psycopg_binary-3.3.2-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:d79b0093f0fbf7a962d6a46ae292dc056c65d16a8ee9361f3cfbafd4c197ab14", size = 4192382 },
{ url = "https://files.pythonhosted.org/packages/fa/d0/c20f4e668e89494972e551c31be2a0016e3f50d552d7ae9ac07086407599/psycopg_binary-3.3.2-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:1586e220be05547c77afc326741dd41cc7fba38a81f9931f616ae98865439678", size = 3928660 },
{ url = "https://files.pythonhosted.org/packages/0f/e1/99746c171de22539fd5eb1c9ca21dc805b54cfae502d7451d237d1dbc349/psycopg_binary-3.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:458696a5fa5dad5b6fb5d5862c22454434ce4fe1cf66ca6c0de5f904cbc1ae3e", size = 4239169 },
{ url = "https://files.pythonhosted.org/packages/72/f7/212343c1c9cfac35fd943c527af85e9091d633176e2a407a0797856ff7b9/psycopg_binary-3.3.2-cp314-cp314-win_amd64.whl", hash = "sha256:04bb2de4ba69d6f8395b446ede795e8884c040ec71d01dd07ac2b2d18d4153d1", size = 3642122 },
]
[[package]]
name = "psycopg-pool"
version = "3.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/56/9a/9470d013d0d50af0da9c4251614aeb3c1823635cab3edc211e3839db0bcf/psycopg_pool-3.3.0.tar.gz", hash = "sha256:fa115eb2860bd88fce1717d75611f41490dec6135efb619611142b24da3f6db5", size = 31606 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/c3/26b8a0908a9db249de3b4169692e1c7c19048a9bc41a4d3209cee7dbb758/psycopg_pool-3.3.0-py3-none-any.whl", hash = "sha256:2e44329155c410b5e8666372db44276a8b1ebd8c90f1c3026ebba40d4bc81063", size = 39995 },
]
[[package]]
name = "psycopg2-binary"
version = "2.9.11"
@ -6293,6 +6381,7 @@ dependencies = [
{ name = "langchain-litellm" },
{ name = "langchain-unstructured" },
{ name = "langgraph" },
{ name = "langgraph-checkpoint-postgres" },
{ name = "linkup-sdk" },
{ name = "litellm" },
{ name = "llama-cloud-services" },
@ -6301,6 +6390,7 @@ dependencies = [
{ name = "numpy" },
{ name = "pgvector" },
{ name = "playwright" },
{ name = "psycopg", extra = ["binary", "pool"] },
{ name = "pypdf" },
{ name = "python-ffmpeg" },
{ name = "redis" },
@ -6351,6 +6441,7 @@ requires-dist = [
{ name = "langchain-litellm", specifier = ">=0.3.5" },
{ name = "langchain-unstructured", specifier = ">=1.0.0" },
{ name = "langgraph", specifier = ">=1.0.5" },
{ name = "langgraph-checkpoint-postgres", specifier = ">=3.0.2" },
{ name = "linkup-sdk", specifier = ">=0.2.4" },
{ name = "litellm", specifier = ">=1.80.10" },
{ name = "llama-cloud-services", specifier = ">=0.6.25" },
@ -6359,6 +6450,7 @@ requires-dist = [
{ name = "numpy", specifier = ">=1.24.0" },
{ name = "pgvector", specifier = ">=0.3.6" },
{ name = "playwright", specifier = ">=1.50.0" },
{ name = "psycopg", extras = ["binary", "pool"], specifier = ">=3.3.2" },
{ name = "pypdf", specifier = ">=5.1.0" },
{ name = "python-ffmpeg", specifier = ">=2.0.12" },
{ name = "redis", specifier = ">=5.2.1" },

View file

@ -3,10 +3,8 @@
import { CTAHomepage } from "@/components/homepage/cta";
import { FeaturesBentoGrid } from "@/components/homepage/features-bento-grid";
import { FeaturesCards } from "@/components/homepage/features-card";
import { Footer } from "@/components/homepage/footer";
import { HeroSection } from "@/components/homepage/hero-section";
import ExternalIntegrations from "@/components/homepage/integrations";
import { Navbar } from "@/components/homepage/navbar";
export default function HomePage() {
return (

View file

@ -6,9 +6,9 @@ import { usersTable } from "@/app/db/schema";
// Define validation schema matching the database schema
const contactSchema = z.object({
name: z.string().min(1, "Name is required").max(255, "Name is too long"),
email: z.string().email("Invalid email address").max(255, "Email is too long"),
email: z.email("Invalid email address").max(255, "Email is too long"),
company: z.string().min(1, "Company is required").max(255, "Company name is too long"),
message: z.string().optional().default(""),
message: z.string().optional().prefault(""),
});
export async function POST(request: NextRequest) {
@ -43,7 +43,7 @@ export async function POST(request: NextRequest) {
{
success: false,
message: "Validation error",
errors: error.errors,
errors: error.issues,
},
{ status: 400 }
);

View file

@ -1,24 +1,25 @@
"use client";
import { useAtom, useAtomValue, useSetAtom } from "jotai";
import { Loader2, PanelRight } from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useAtomValue, useSetAtom } from "jotai";
import { Loader2 } from "lucide-react";
import { useParams, usePathname, useRouter } from "next/navigation";
import { useTranslations } from "next-intl";
import type React from "react";
import { useCallback, useEffect, useMemo, useState } from "react";
import { activeChathatUIAtom, activeChatIdAtom } from "@/atoms/chats/ui.atoms";
import { llmPreferencesAtom } from "@/atoms/llm-config/llm-config-query.atoms";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { myAccessAtom } from "@/atoms/members/members-query.atoms";
import { updateLLMPreferencesMutationAtom } from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
import {
globalNewLLMConfigsAtom,
llmPreferencesAtom,
} from "@/atoms/new-llm-config/new-llm-config-query.atoms";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { ChatPanelContainer } from "@/components/chat/ChatPanel/ChatPanelContainer";
import { DashboardBreadcrumb } from "@/components/dashboard-breadcrumb";
import { LanguageSwitcher } from "@/components/LanguageSwitcher";
import { AppSidebarProvider } from "@/components/sidebar/AppSidebarProvider";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Separator } from "@/components/ui/separator";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
export function DashboardClientLayout({
children,
@ -34,43 +35,27 @@ export function DashboardClientLayout({
const t = useTranslations("dashboard");
const router = useRouter();
const pathname = usePathname();
const searchSpaceIdNum = Number(searchSpaceId);
const { search_space_id, chat_id } = useParams();
const [chatUIState, setChatUIState] = useAtom(activeChathatUIAtom);
const activeChatId = useAtomValue(activeChatIdAtom);
const { search_space_id } = useParams();
const setActiveSearchSpaceIdState = useSetAtom(activeSearchSpaceIdAtom);
const setActiveChatIdState = useSetAtom(activeChatIdAtom);
const [showIndicator, setShowIndicator] = useState(false);
const { isChatPannelOpen } = chatUIState;
// Check if we're on the researcher page
const isResearcherPage = pathname?.includes("/researcher");
// Show indicator when chat becomes active and panel is closed
useEffect(() => {
if (activeChatId && !isChatPannelOpen) {
setShowIndicator(true);
// Hide indicator after 5 seconds
const timer = setTimeout(() => setShowIndicator(false), 5000);
return () => clearTimeout(timer);
} else {
setShowIndicator(false);
}
}, [activeChatId, isChatPannelOpen]);
const { data: preferences = {}, isFetching: loading, error } = useAtomValue(llmPreferencesAtom);
const {
data: preferences = {},
isFetching: loading,
error,
refetch: refetchPreferences,
} = useAtomValue(llmPreferencesAtom);
const { data: globalConfigs = [], isFetching: globalConfigsLoading } =
useAtomValue(globalNewLLMConfigsAtom);
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
const isOnboardingComplete = useCallback(() => {
return !!(
preferences.long_context_llm_id &&
preferences.fast_llm_id &&
preferences.strategic_llm_id
);
return !!(preferences.agent_llm_id && preferences.document_summary_llm_id);
}, [preferences]);
const { data: access = null, isLoading: accessLoading } = useAtomValue(myAccessAtom);
const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false);
const [isAutoConfiguring, setIsAutoConfiguring] = useState(false);
const hasAttemptedAutoConfig = useRef(false);
// Skip onboarding check if we're already on the onboarding page
const isOnboardingPage = pathname?.includes("/onboard");
@ -115,27 +100,82 @@ export function DashboardClientLayout({
return;
}
// Wait for both preferences and access data to load
if (!loading && !accessLoading && !hasCheckedOnboarding) {
// Wait for all data to load
if (
!loading &&
!accessLoading &&
!globalConfigsLoading &&
!hasCheckedOnboarding &&
!isAutoConfiguring
) {
const onboardingComplete = isOnboardingComplete();
// Only redirect to onboarding if user is the owner and onboarding is not complete
// Invited members (non-owners) should skip onboarding and use existing config
if (!onboardingComplete && isOwner) {
router.push(`/dashboard/${searchSpaceId}/onboard`);
// If onboarding is complete, nothing to do
if (onboardingComplete) {
setHasCheckedOnboarding(true);
return;
}
// Only handle onboarding for owners
if (!isOwner) {
setHasCheckedOnboarding(true);
return;
}
// If global configs available, auto-configure without going to onboard page
if (globalConfigs.length > 0 && !hasAttemptedAutoConfig.current) {
hasAttemptedAutoConfig.current = true;
setIsAutoConfiguring(true);
const autoConfigureWithGlobal = async () => {
try {
const firstGlobalConfig = globalConfigs[0];
await updatePreferences({
search_space_id: Number(searchSpaceId),
data: {
agent_llm_id: firstGlobalConfig.id,
document_summary_llm_id: firstGlobalConfig.id,
},
});
await refetchPreferences();
toast.success("AI configured automatically!", {
description: `Using ${firstGlobalConfig.name}. Customize in Settings.`,
});
setHasCheckedOnboarding(true);
} catch (error) {
console.error("Auto-configuration failed:", error);
// Fall back to onboard page
router.push(`/dashboard/${searchSpaceId}/onboard`);
} finally {
setIsAutoConfiguring(false);
}
};
autoConfigureWithGlobal();
return;
}
// No global configs - redirect to onboard page
router.push(`/dashboard/${searchSpaceId}/onboard`);
setHasCheckedOnboarding(true);
}
}, [
loading,
accessLoading,
globalConfigsLoading,
isOnboardingComplete,
isOnboardingPage,
isOwner,
isAutoConfiguring,
globalConfigs,
router,
searchSpaceId,
hasCheckedOnboarding,
updatePreferences,
refetchPreferences,
]);
// Synchronize active search space and chat IDs with URL
@ -148,27 +188,27 @@ export function DashboardClientLayout({
: "";
if (!activeSeacrhSpaceId) return;
setActiveSearchSpaceIdState(activeSeacrhSpaceId);
}, [search_space_id]);
}, [search_space_id, setActiveSearchSpaceIdState]);
useEffect(() => {
const activeChatId =
typeof chat_id === "string"
? chat_id
: Array.isArray(chat_id) && chat_id.length > 0
? chat_id[0]
: "";
if (!activeChatId) return;
setActiveChatIdState(activeChatId);
}, [chat_id, search_space_id]);
// Show loading screen while checking onboarding status (only on first load)
if (!hasCheckedOnboarding && (loading || accessLoading) && !isOnboardingPage) {
// Show loading screen while checking onboarding status or auto-configuring
if (
(!hasCheckedOnboarding &&
(loading || accessLoading || globalConfigsLoading) &&
!isOnboardingPage) ||
isAutoConfiguring
) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">{t("loading_config")}</CardTitle>
<CardDescription>{t("checking_llm_prefs")}</CardDescription>
<CardTitle className="text-xl font-medium">
{isAutoConfiguring ? "Setting up AI..." : t("loading_config")}
</CardTitle>
<CardDescription>
{isAutoConfiguring
? "Auto-configuring with available settings"
: t("checking_llm_prefs")}
</CardDescription>
</CardHeader>
<CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" />
@ -212,123 +252,20 @@ export function DashboardClientLayout({
navMain={translatedNavMain}
/>
<SidebarInset className="h-full ">
<main className="flex h-full">
<div className="flex grow flex-col h-full border-r">
<header className="sticky top-0 z-50 flex h-16 shrink-0 items-center gap-2 bg-background/95 backdrop-blur supports-backdrop-filter:bg-background/60 border-b">
<div className="flex items-center justify-between w-full gap-2 px-4">
<div className="flex items-center gap-2">
<SidebarTrigger className="-ml-1" />
<Separator orientation="vertical" className="h-6" />
<DashboardBreadcrumb />
</div>
<div className="flex items-center gap-2">
<LanguageSwitcher />
{/* Only show artifacts toggle on researcher page */}
{isResearcherPage && (
<motion.div
className="relative"
animate={
showIndicator
? {
scale: [1, 1.05, 1],
}
: {}
}
transition={{
duration: 2,
repeat: showIndicator ? Number.POSITIVE_INFINITY : 0,
ease: "easeInOut",
}}
>
<motion.button
type="button"
onClick={() => {
setChatUIState((prev) => ({
...prev,
isChatPannelOpen: !isChatPannelOpen,
}));
setShowIndicator(false);
}}
className={cn(
"shrink-0 rounded-full p-2 transition-all duration-300 relative",
showIndicator
? "bg-primary/20 hover:bg-primary/30 shadow-lg shadow-primary/25"
: "hover:bg-muted",
activeChatId && !showIndicator && "hover:bg-primary/10"
)}
title="Toggle Artifacts Panel"
whileHover={{ scale: 1.05 }}
whileTap={{ scale: 0.95 }}
>
<motion.div
animate={
showIndicator
? {
rotate: [0, -10, 10, -10, 0],
}
: {}
}
transition={{
duration: 0.5,
repeat: showIndicator ? Number.POSITIVE_INFINITY : 0,
repeatDelay: 2,
}}
>
<PanelRight
className={cn(
"h-4 w-4 transition-colors",
showIndicator && "text-primary"
)}
/>
</motion.div>
</motion.button>
{/* Pulsing indicator badge */}
<AnimatePresence>
{showIndicator && (
<motion.div
initial={{ opacity: 0, scale: 0 }}
animate={{ opacity: 1, scale: 1 }}
exit={{ opacity: 0, scale: 0 }}
className="absolute -right-1 -top-1 pointer-events-none"
>
<motion.div
animate={{
scale: [1, 1.3, 1],
}}
transition={{
duration: 1.5,
repeat: Number.POSITIVE_INFINITY,
ease: "easeInOut",
}}
className="relative"
>
<div className="h-2.5 w-2.5 rounded-full bg-primary shadow-lg" />
<motion.div
animate={{
scale: [1, 2.5, 1],
opacity: [0.6, 0, 0.6],
}}
transition={{
duration: 1.5,
repeat: Number.POSITIVE_INFINITY,
ease: "easeInOut",
}}
className="absolute inset-0 h-2.5 w-2.5 rounded-full bg-primary"
/>
</motion.div>
</motion.div>
)}
</AnimatePresence>
</motion.div>
)}
</div>
<main className="flex flex-col h-full">
<header className="sticky top-0 z-50 flex h-16 shrink-0 items-center gap-2 bg-background/95 backdrop-blur supports-backdrop-filter:bg-background/60 border-b">
<div className="flex items-center justify-between w-full gap-2 px-4">
<div className="flex items-center gap-2">
<SidebarTrigger className="-ml-1" />
<Separator orientation="vertical" className="h-6" />
<DashboardBreadcrumb />
</div>
</header>
<div className="grow flex-1 overflow-auto min-h-[calc(100vh-64px)]">{children}</div>
</div>
{/* Only render chat panel on researcher page */}
{isResearcherPage && <ChatPanelContainer />}
<div className="flex items-center gap-2">
<LanguageSwitcher />
</div>
</div>
</header>
<div className="grow flex-1 overflow-auto min-h-[calc(100vh-64px)]">{children}</div>
</main>
</SidebarInset>
</SidebarProvider>

View file

@ -29,7 +29,7 @@ export default function DashboardLayout({
const customNavMain = [
{
title: "Chat",
url: `/dashboard/${search_space_id}/researcher`,
url: `/dashboard/${search_space_id}/new-chat`,
icon: "SquareTerminal",
items: [],
},

View file

@ -0,0 +1,637 @@
"use client";
import {
type AppendMessage,
AssistantRuntimeProvider,
type ThreadMessageLike,
useExternalStoreRuntime,
} from "@assistant-ui/react";
import { useParams, useRouter } from "next/navigation";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "sonner";
import { Thread } from "@/components/assistant-ui/thread";
import { ChatHeader } from "@/components/new-chat/chat-header";
import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
import { DisplayImageToolUI } from "@/components/tool-ui/display-image";
import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast";
import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview";
import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage";
import { getBearerToken } from "@/lib/auth-utils";
import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter";
import {
isPodcastGenerating,
looksLikePodcastRequest,
setActivePodcastTaskId,
} from "@/lib/chat/podcast-state";
import {
appendMessage,
createThread,
getThreadMessages,
type MessageRecord,
} from "@/lib/chat/thread-persistence";
/**
 * Pull chain-of-thought steps out of a persisted message's content parts.
 * Returns an empty list when the content is not a parts array or carries no
 * "thinking-steps" part.
 */
function extractThinkingSteps(content: unknown): ThinkingStep[] {
	if (!Array.isArray(content)) {
		return [];
	}
	for (const part of content) {
		if (
			typeof part === "object" &&
			part !== null &&
			"type" in part &&
			(part as { type: string }).type === "thinking-steps"
		) {
			const { steps } = part as { type: "thinking-steps"; steps: ThinkingStep[] };
			return steps || [];
		}
	}
	return [];
}
/**
 * Convert a backend message record into assistant-ui's ThreadMessageLike shape.
 *
 * The persisted "thinking-steps" part is stripped here because thinking steps
 * are rendered separately (via the messageThinkingSteps map); leaving the part
 * in the content would trigger "unsupported part type" errors in assistant-ui.
 *
 * Fix: non-string, non-array object content is now JSON-serialized instead of
 * being coerced with String(), which previously rendered as "[object Object]".
 */
function convertToThreadMessage(msg: MessageRecord): ThreadMessageLike {
	let content: ThreadMessageLike["content"];
	if (typeof msg.content === "string") {
		content = [{ type: "text", text: msg.content }];
	} else if (Array.isArray(msg.content)) {
		// Filter out thinking-steps part - it's handled separately via messageThinkingSteps
		const filteredContent = msg.content.filter(
			(part: unknown) =>
				!(
					typeof part === "object" &&
					part !== null &&
					"type" in part &&
					(part as { type: string }).type === "thinking-steps"
				)
		);
		content =
			filteredContent.length > 0
				? (filteredContent as ThreadMessageLike["content"])
				: [{ type: "text", text: "" }];
	} else if (msg.content !== null && typeof msg.content === "object") {
		// Plain objects would stringify to "[object Object]"; serialize them instead.
		content = [{ type: "text", text: JSON.stringify(msg.content) }];
	} else {
		content = [{ type: "text", text: String(msg.content) }];
	}
	return {
		id: `msg-${msg.id}`,
		role: msg.role,
		content,
		createdAt: new Date(msg.created_at),
	};
}
/**
 * Tools that should render custom UI in the chat.
 *
 * The names correspond to the *ToolUI components registered in NewChatPage
 * (GeneratePodcastToolUI, LinkPreviewToolUI, DisplayImageToolUI,
 * ScrapeWebpageToolUI). Tool calls for any other tool name are filtered out
 * of the rendered message content.
 */
const TOOLS_WITH_UI = new Set([
	"generate_podcast",
	"link_preview",
	"display_image",
	"scrape_webpage",
]);
/**
 * Shape of a single thinking-step payload streamed from the backend via
 * `data-thinking-step` SSE events. Stored values are reused as the
 * ThinkingStep entries passed to the Thread UI.
 */
interface ThinkingStepData {
	// Stable identifier; repeated events with the same id upsert the step.
	id: string;
	// Short human-readable label for the step.
	title: string;
	// Lifecycle of the step as reported by the agent.
	status: "pending" | "in_progress" | "completed";
	// Detail items associated with the step.
	items: string[];
}
/**
 * New chat page wired to an assistant-ui external store runtime.
 *
 * Lifecycle:
 * 1. On mount, loads the thread referenced by the /new-chat/[chat_id] route
 *    (restoring messages and their persisted thinking steps), or creates a
 *    new thread and rewrites the URL to include its id.
 * 2. On each user message, streams the assistant response from
 *    /api/v1/new_chat as SSE, building an ordered list of text segments and
 *    inline tool calls as events arrive.
 * 3. Persists both user and assistant messages; assistant content includes a
 *    "thinking-steps" part so chain-of-thought can be restored on refresh.
 *
 * Fix: the SSE read loop previously dropped any final event still sitting in
 * `buffer` when the stream ended without a trailing blank line, and never
 * flushed the TextDecoder. Event handling is now factored into a local
 * `processEvent` closure and the tail is flushed/processed after the loop.
 */
export default function NewChatPage() {
	const params = useParams();
	const router = useRouter();
	const [isInitializing, setIsInitializing] = useState(true);
	const [threadId, setThreadId] = useState<number | null>(null);
	const [messages, setMessages] = useState<ThreadMessageLike[]>([]);
	const [isRunning, setIsRunning] = useState(false);
	// Store thinking steps per message ID - kept separate from content to avoid
	// "unsupported part type" errors from assistant-ui
	const [messageThinkingSteps, setMessageThinkingSteps] = useState<Map<string, ThinkingStep[]>>(
		new Map()
	);
	const abortControllerRef = useRef<AbortController | null>(null);
	// Create the attachment adapter for file processing
	const attachmentAdapter = useMemo(() => createAttachmentAdapter(), []);
	// Extract search_space_id from URL params (0 when missing/unparsable)
	const searchSpaceId = useMemo(() => {
		const id = params.search_space_id;
		const parsed = typeof id === "string" ? Number.parseInt(id, 10) : 0;
		return Number.isNaN(parsed) ? 0 : parsed;
	}, [params.search_space_id]);
	// Extract chat_id from URL params (0 when missing/unparsable)
	const urlChatId = useMemo(() => {
		const id = params.chat_id;
		let parsed = 0;
		if (Array.isArray(id) && id.length > 0) {
			parsed = Number.parseInt(id[0], 10);
		} else if (typeof id === "string") {
			parsed = Number.parseInt(id, 10);
		}
		return Number.isNaN(parsed) ? 0 : parsed;
	}, [params.chat_id]);
	// Initialize thread and load messages
	const initializeThread = useCallback(async () => {
		setIsInitializing(true);
		try {
			if (urlChatId > 0) {
				// Thread exists - load messages
				setThreadId(urlChatId);
				const response = await getThreadMessages(urlChatId);
				if (response.messages && response.messages.length > 0) {
					const loadedMessages = response.messages.map(convertToThreadMessage);
					setMessages(loadedMessages);
					// Extract and restore thinking steps from persisted messages
					const restoredThinkingSteps = new Map<string, ThinkingStep[]>();
					for (const msg of response.messages) {
						if (msg.role === "assistant") {
							const steps = extractThinkingSteps(msg.content);
							if (steps.length > 0) {
								restoredThinkingSteps.set(`msg-${msg.id}`, steps);
							}
						}
					}
					if (restoredThinkingSteps.size > 0) {
						setMessageThinkingSteps(restoredThinkingSteps);
					}
				}
			} else {
				// Create new thread and put its id in the URL
				const newThread = await createThread(searchSpaceId, "New Chat");
				setThreadId(newThread.id);
				router.replace(`/dashboard/${searchSpaceId}/new-chat/${newThread.id}`);
			}
		} catch (error) {
			console.error("[NewChatPage] Failed to initialize thread:", error);
			// Keep threadId as null - don't use Date.now() as it creates an invalid ID
			// that will cause 404 errors on subsequent API calls
			setThreadId(null);
			toast.error("Failed to initialize chat. Please try again.");
		} finally {
			setIsInitializing(false);
		}
	}, [urlChatId, searchSpaceId, router]);
	// Initialize on mount
	useEffect(() => {
		initializeThread();
	}, [initializeThread]);
	// Cancel ongoing request
	const cancelRun = useCallback(async () => {
		if (abortControllerRef.current) {
			abortControllerRef.current.abort();
			abortControllerRef.current = null;
		}
		setIsRunning(false);
	}, []);
	// Handle new message from user
	const onNew = useCallback(
		async (message: AppendMessage) => {
			if (!threadId) return;
			// Extract user query text from content parts
			let userQuery = "";
			for (const part of message.content) {
				if (part.type === "text") {
					userQuery += part.text;
				}
			}
			// Extract attachments from message
			// AppendMessage.attachments contains the processed attachment objects (from adapter.send())
			const messageAttachments: Array<Record<string, unknown>> = [];
			if (message.attachments && message.attachments.length > 0) {
				for (const att of message.attachments) {
					messageAttachments.push(att as unknown as Record<string, unknown>);
				}
			}
			if (!userQuery.trim() && messageAttachments.length === 0) return;
			// Check if podcast is already generating
			if (isPodcastGenerating() && looksLikePodcastRequest(userQuery)) {
				toast.warning("A podcast is already being generated.");
				return;
			}
			const token = getBearerToken();
			if (!token) {
				toast.error("Not authenticated. Please log in again.");
				return;
			}
			// Add user message to state
			const userMsgId = `msg-user-${Date.now()}`;
			const userMessage: ThreadMessageLike = {
				id: userMsgId,
				role: "user",
				content: message.content,
				createdAt: new Date(),
			};
			setMessages((prev) => [...prev, userMessage]);
			// Persist user message (don't await, fire and forget)
			appendMessage(threadId, {
				role: "user",
				content: message.content,
			}).catch((err) => console.error("Failed to persist user message:", err));
			// Start streaming response
			setIsRunning(true);
			const controller = new AbortController();
			abortControllerRef.current = controller;
			// Prepare assistant message
			const assistantMsgId = `msg-assistant-${Date.now()}`;
			const currentThinkingSteps = new Map<string, ThinkingStepData>();
			// Ordered content parts to preserve inline tool call positions
			// Each part is either a text segment or a tool call
			type ContentPart =
				| { type: "text"; text: string }
				| {
						type: "tool-call";
						toolCallId: string;
						toolName: string;
						args: Record<string, unknown>;
						result?: unknown;
				  };
			const contentParts: ContentPart[] = [];
			// Track the current text segment index (for appending text deltas)
			let currentTextPartIndex = -1;
			// Map to track tool call indices for updating results
			const toolCallIndices = new Map<string, number>();
			// Helper to get or create the current text part for appending text
			const appendText = (delta: string) => {
				if (currentTextPartIndex >= 0 && contentParts[currentTextPartIndex]?.type === "text") {
					// Append to existing text part
					(contentParts[currentTextPartIndex] as { type: "text"; text: string }).text += delta;
				} else {
					// Create new text part
					contentParts.push({ type: "text", text: delta });
					currentTextPartIndex = contentParts.length - 1;
				}
			};
			// Helper to add a tool call (this "breaks" the current text segment)
			const addToolCall = (toolCallId: string, toolName: string, args: Record<string, unknown>) => {
				if (TOOLS_WITH_UI.has(toolName)) {
					contentParts.push({
						type: "tool-call",
						toolCallId,
						toolName,
						args,
					});
					toolCallIndices.set(toolCallId, contentParts.length - 1);
					// Reset text part index so next text creates a new segment
					currentTextPartIndex = -1;
				}
			};
			// Helper to update a tool call's args or result
			const updateToolCall = (
				toolCallId: string,
				update: { args?: Record<string, unknown>; result?: unknown }
			) => {
				const index = toolCallIndices.get(toolCallId);
				if (index !== undefined && contentParts[index]?.type === "tool-call") {
					const tc = contentParts[index] as ContentPart & { type: "tool-call" };
					if (update.args) tc.args = update.args;
					if (update.result !== undefined) tc.result = update.result;
				}
			};
			// Helper to build content for UI (without thinking-steps to avoid assistant-ui errors)
			const buildContentForUI = (): ThreadMessageLike["content"] => {
				// Filter to only include text parts with content and tool-calls with UI
				const filtered = contentParts.filter((part) => {
					if (part.type === "text") return part.text.length > 0;
					if (part.type === "tool-call") return TOOLS_WITH_UI.has(part.toolName);
					return false;
				});
				return filtered.length > 0
					? (filtered as ThreadMessageLike["content"])
					: [{ type: "text", text: "" }];
			};
			// Helper to build content for persistence (includes thinking-steps for restoration)
			const buildContentForPersistence = (): unknown[] => {
				const parts: unknown[] = [];
				// Include thinking steps for persistence
				if (currentThinkingSteps.size > 0) {
					parts.push({
						type: "thinking-steps",
						steps: Array.from(currentThinkingSteps.values()),
					});
				}
				// Add content parts (filtered)
				for (const part of contentParts) {
					if (part.type === "text" && part.text.length > 0) {
						parts.push(part);
					} else if (part.type === "tool-call" && TOOLS_WITH_UI.has(part.toolName)) {
						parts.push(part);
					}
				}
				return parts.length > 0 ? parts : [{ type: "text", text: "" }];
			};
			// Add placeholder assistant message
			setMessages((prev) => [
				...prev,
				{
					id: assistantMsgId,
					role: "assistant",
					content: [{ type: "text", text: "" }],
					createdAt: new Date(),
				},
			]);
			try {
				const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
				// Build message history for context
				const messageHistory = messages
					.filter((m) => m.role === "user" || m.role === "assistant")
					.map((m) => {
						let text = "";
						for (const part of m.content) {
							if (typeof part === "object" && part.type === "text" && "text" in part) {
								text += part.text;
							}
						}
						return { role: m.role, content: text };
					})
					.filter((m) => m.content.length > 0);
				// Extract attachment content to send with the request
				const attachments = extractAttachmentContent(messageAttachments);
				const response = await fetch(`${backendUrl}/api/v1/new_chat`, {
					method: "POST",
					headers: {
						"Content-Type": "application/json",
						Authorization: `Bearer ${token}`,
					},
					body: JSON.stringify({
						chat_id: threadId,
						user_query: userQuery.trim(),
						search_space_id: searchSpaceId,
						messages: messageHistory,
						attachments: attachments.length > 0 ? attachments : undefined,
					}),
					signal: controller.signal,
				});
				if (!response.ok) {
					throw new Error(`Backend error: ${response.status}`);
				}
				if (!response.body) {
					throw new Error("No response body");
				}
				// Parse SSE stream
				const reader = response.body.getReader();
				const decoder = new TextDecoder();
				let buffer = "";
				// Apply a single SSE event (one or more "data: ..." lines) to the
				// in-progress assistant message.
				const processEvent = (event: string) => {
					const lines = event.split(/\r?\n/);
					for (const line of lines) {
						if (!line.startsWith("data: ")) continue;
						const data = line.slice(6).trim();
						if (!data || data === "[DONE]") continue;
						try {
							const parsed = JSON.parse(data);
							switch (parsed.type) {
								case "text-delta":
									appendText(parsed.delta);
									setMessages((prev) =>
										prev.map((m) =>
											m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m
										)
									);
									break;
								case "tool-input-start":
									// Add tool call inline - this breaks the current text segment
									addToolCall(parsed.toolCallId, parsed.toolName, {});
									setMessages((prev) =>
										prev.map((m) =>
											m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m
										)
									);
									break;
								case "tool-input-available": {
									// Update existing tool call's args, or add if not exists
									if (toolCallIndices.has(parsed.toolCallId)) {
										updateToolCall(parsed.toolCallId, { args: parsed.input || {} });
									} else {
										addToolCall(parsed.toolCallId, parsed.toolName, parsed.input || {});
									}
									setMessages((prev) =>
										prev.map((m) =>
											m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m
										)
									);
									break;
								}
								case "tool-output-available": {
									// Update the tool call with its result
									updateToolCall(parsed.toolCallId, { result: parsed.output });
									// Handle podcast-specific logic
									if (parsed.output?.status === "processing" && parsed.output?.task_id) {
										// Check if this is a podcast tool by looking at the content part
										const idx = toolCallIndices.get(parsed.toolCallId);
										if (idx !== undefined) {
											const part = contentParts[idx];
											if (part?.type === "tool-call" && part.toolName === "generate_podcast") {
												setActivePodcastTaskId(parsed.output.task_id);
											}
										}
									}
									setMessages((prev) =>
										prev.map((m) =>
											m.id === assistantMsgId ? { ...m, content: buildContentForUI() } : m
										)
									);
									break;
								}
								case "data-thinking-step": {
									// Handle thinking step events for chain-of-thought display
									const stepData = parsed.data as ThinkingStepData;
									if (stepData?.id) {
										currentThinkingSteps.set(stepData.id, stepData);
										// Update thinking steps state for rendering
										// The ThinkingStepsScrollHandler in Thread component
										// will handle auto-scrolling when this state changes
										setMessageThinkingSteps((prev) => {
											const newMap = new Map(prev);
											newMap.set(assistantMsgId, Array.from(currentThinkingSteps.values()));
											return newMap;
										});
									}
									break;
								}
								case "error":
									throw new Error(parsed.errorText || "Server error");
							}
						} catch (e) {
							// Ignore malformed/partial JSON payloads; rethrow real errors
							if (e instanceof SyntaxError) continue;
							throw e;
						}
					}
				};
				try {
					while (true) {
						const { done, value } = await reader.read();
						if (done) break;
						buffer += decoder.decode(value, { stream: true });
						const events = buffer.split(/\r?\n\r?\n/);
						buffer = events.pop() || "";
						for (const event of events) {
							processEvent(event);
						}
					}
					// Flush the decoder and process any final event that was not
					// terminated by a blank line; without this the last SSE event
					// of the stream would be silently dropped.
					buffer += decoder.decode();
					if (buffer.trim().length > 0) {
						for (const event of buffer.split(/\r?\n\r?\n/)) {
							processEvent(event);
						}
					}
				} finally {
					reader.releaseLock();
				}
				// Persist assistant message (with thinking steps for restoration on refresh)
				const finalContent = buildContentForPersistence();
				if (contentParts.length > 0) {
					appendMessage(threadId, {
						role: "assistant",
						content: finalContent,
					}).catch((err) => console.error("Failed to persist assistant message:", err));
				}
			} catch (error) {
				if (error instanceof Error && error.name === "AbortError") {
					// Request was cancelled
					return;
				}
				console.error("[NewChatPage] Chat error:", error);
				toast.error("Failed to get response. Please try again.");
				// Update assistant message with error
				setMessages((prev) =>
					prev.map((m) =>
						m.id === assistantMsgId
							? {
									...m,
									content: [
										{
											type: "text",
											text: "Sorry, there was an error. Please try again.",
										},
									],
							  }
							: m
					)
				);
			} finally {
				setIsRunning(false);
				abortControllerRef.current = null;
				// Note: We no longer clear thinking steps - they persist with the message
			}
		},
		[threadId, searchSpaceId, messages]
	);
	// Convert message (pass through since already in correct format)
	const convertMessage = useCallback(
		(message: ThreadMessageLike): ThreadMessageLike => message,
		[]
	);
	// Handle editing a message - removes messages after the edited one and sends as new
	const onEdit = useCallback(
		async (message: AppendMessage) => {
			// Find the message being edited by looking at the parentId
			// The parentId tells us which message's response we're editing
			// For now, we'll just treat edits like new messages
			// A more sophisticated implementation would truncate the history
			await onNew(message);
		},
		[onNew]
	);
	// Create external store runtime with attachment support
	const runtime = useExternalStoreRuntime({
		messages,
		isRunning,
		onNew,
		onEdit,
		convertMessage,
		onCancel: cancelRun,
		adapters: {
			attachments: attachmentAdapter,
		},
	});
	// Show loading state
	if (isInitializing) {
		return (
			<div className="flex h-[calc(100vh-64px)] items-center justify-center">
				<div className="text-muted-foreground">Loading chat...</div>
			</div>
		);
	}
	// Show error state if thread initialization failed
	if (!threadId) {
		return (
			<div className="flex h-[calc(100vh-64px)] flex-col items-center justify-center gap-4">
				<div className="text-destructive">Failed to initialize chat</div>
				<button
					type="button"
					onClick={() => {
						setIsInitializing(true);
						initializeThread();
					}}
					className="rounded-md bg-primary px-4 py-2 text-primary-foreground hover:bg-primary/90"
				>
					Try Again
				</button>
			</div>
		);
	}
	return (
		<AssistantRuntimeProvider runtime={runtime}>
			<GeneratePodcastToolUI />
			<LinkPreviewToolUI />
			<DisplayImageToolUI />
			<ScrapeWebpageToolUI />
			<div className="flex flex-col h-[calc(100vh-64px)] max-h-[calc(100vh-64px)] overflow-hidden">
				<ChatHeader searchSpaceId={searchSpaceId} />
				<div className="flex-1 min-h-0 overflow-hidden">
					<Thread messageThinkingSteps={messageThinkingSteps} />
				</div>
			</div>
		</AssistantRuntimeProvider>
	);
}

View file

@ -1,312 +1,268 @@
"use client";
import { useAtomValue } from "jotai";
import { FileText, MessageSquare, UserPlus, Users } from "lucide-react";
import { Loader2 } from "lucide-react";
import { motion } from "motion/react";
import { useParams, useRouter } from "next/navigation";
import { useTranslations } from "next-intl";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import { updateLLMPreferencesMutationAtom } from "@/atoms/llm-config/llm-config-mutation.atoms";
import {
globalLLMConfigsAtom,
llmConfigsAtom,
createNewLLMConfigMutationAtom,
updateLLMPreferencesMutationAtom,
} from "@/atoms/new-llm-config/new-llm-config-mutation.atoms";
import {
globalNewLLMConfigsAtom,
llmPreferencesAtom,
} from "@/atoms/llm-config/llm-config-query.atoms";
import { OnboardActionCard } from "@/components/onboard/onboard-action-card";
import { OnboardAdvancedSettings } from "@/components/onboard/onboard-advanced-settings";
import { OnboardHeader } from "@/components/onboard/onboard-header";
import { OnboardLLMSetup } from "@/components/onboard/onboard-llm-setup";
import { OnboardLoading } from "@/components/onboard/onboard-loading";
import { OnboardStats } from "@/components/onboard/onboard-stats";
} from "@/atoms/new-llm-config/new-llm-config-query.atoms";
import { Logo } from "@/components/Logo";
import { LLMConfigForm, type LLMConfigFormData } from "@/components/shared/llm-config-form";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { getBearerToken, redirectToLogin } from "@/lib/auth-utils";
const OnboardPage = () => {
const t = useTranslations("onboard");
export default function OnboardPage() {
const router = useRouter();
const params = useParams();
const searchSpaceId = Number(params.search_space_id);
// Queries
const {
data: llmConfigs = [],
isFetching: configsLoading,
refetch: refreshConfigs,
} = useAtomValue(llmConfigsAtom);
const { data: globalConfigs = [], isFetching: globalConfigsLoading } =
useAtomValue(globalLLMConfigsAtom);
const {
data: preferences = {},
isFetching: preferencesLoading,
refetch: refreshPreferences,
} = useAtomValue(llmPreferencesAtom);
const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
data: globalConfigs = [],
isFetching: globalConfigsLoading,
isSuccess: globalConfigsLoaded,
} = useAtomValue(globalNewLLMConfigsAtom);
const { data: preferences = {}, isFetching: preferencesLoading } =
useAtomValue(llmPreferencesAtom);
// Compute isOnboardingComplete
const isOnboardingComplete = useMemo(() => {
return !!(
preferences.long_context_llm_id &&
preferences.fast_llm_id &&
preferences.strategic_llm_id
);
}, [preferences]);
// Mutations
const { mutateAsync: createConfig, isPending: isCreating } = useAtomValue(
createNewLLMConfigMutationAtom
);
const { mutateAsync: updatePreferences, isPending: isUpdatingPreferences } = useAtomValue(
updateLLMPreferencesMutationAtom
);
// State
const [isAutoConfiguring, setIsAutoConfiguring] = useState(false);
const [autoConfigComplete, setAutoConfigComplete] = useState(false);
const [showAdvancedSettings, setShowAdvancedSettings] = useState(false);
const [showPromptSettings, setShowPromptSettings] = useState(false);
const handleRefreshPreferences = useCallback(async () => {
await refreshPreferences();
}, []);
// Track if we've already attempted auto-configuration
const hasAttemptedAutoConfig = useRef(false);
// Track if onboarding was complete on initial mount
const wasCompleteOnMount = useRef<boolean | null>(null);
const hasCheckedInitialState = useRef(false);
// Check if user is authenticated
// Check authentication
useEffect(() => {
const token = getBearerToken();
if (!token) {
// Save current path and redirect to login
redirectToLogin();
return;
}
}, []);
// Capture onboarding state on first load
// Check if onboarding is already complete
const isOnboardingComplete = preferences.agent_llm_id && preferences.document_summary_llm_id;
// If onboarding is already complete, redirect immediately
useEffect(() => {
if (
!hasCheckedInitialState.current &&
!preferencesLoading &&
!configsLoading &&
!globalConfigsLoading
) {
wasCompleteOnMount.current = isOnboardingComplete;
hasCheckedInitialState.current = true;
if (!preferencesLoading && isOnboardingComplete) {
router.push(`/dashboard/${searchSpaceId}/new-chat`);
}
}, [preferencesLoading, configsLoading, globalConfigsLoading, isOnboardingComplete]);
}, [preferencesLoading, isOnboardingComplete, router, searchSpaceId]);
// Redirect to dashboard if onboarding was already complete
// Auto-configure if global configs are available
useEffect(() => {
if (
wasCompleteOnMount.current === true &&
!preferencesLoading &&
!configsLoading &&
!globalConfigsLoading
) {
const timer = setTimeout(() => {
router.push(`/dashboard/${searchSpaceId}`);
}, 300);
return () => clearTimeout(timer);
}
}, [preferencesLoading, configsLoading, globalConfigsLoading, router, searchSpaceId]);
const autoConfigureWithGlobal = async () => {
if (hasAttemptedAutoConfig.current) return;
if (globalConfigsLoading || preferencesLoading) return;
if (!globalConfigsLoaded) return;
if (isOnboardingComplete) return;
// Auto-configure LLM roles if global configs are available
const autoConfigureLLMs = useCallback(async () => {
if (hasAttemptedAutoConfig.current) return;
if (globalConfigs.length === 0) return;
if (isOnboardingComplete) {
setAutoConfigComplete(true);
return;
}
// Only auto-configure if we have global configs
if (globalConfigs.length > 0) {
hasAttemptedAutoConfig.current = true;
setIsAutoConfiguring(true);
hasAttemptedAutoConfig.current = true;
setIsAutoConfiguring(true);
try {
const firstGlobalConfig = globalConfigs[0];
try {
const allConfigs = [...globalConfigs, ...llmConfigs];
await updatePreferences({
search_space_id: searchSpaceId,
data: {
agent_llm_id: firstGlobalConfig.id,
document_summary_llm_id: firstGlobalConfig.id,
},
});
if (allConfigs.length === 0) {
setIsAutoConfiguring(false);
return;
toast.success("AI configured automatically!", {
description: `Using ${firstGlobalConfig.name}. You can customize this later in Settings.`,
});
// Redirect to new-chat
router.push(`/dashboard/${searchSpaceId}/new-chat`);
} catch (error) {
console.error("Auto-configuration failed:", error);
toast.error("Auto-configuration failed. Please add a configuration manually.");
setIsAutoConfiguring(false);
}
}
};
// Use first available config for all roles
const defaultConfigId = allConfigs[0].id;
autoConfigureWithGlobal();
}, [
globalConfigs,
globalConfigsLoading,
globalConfigsLoaded,
preferencesLoading,
isOnboardingComplete,
updatePreferences,
searchSpaceId,
router,
]);
const newPreferences = {
long_context_llm_id: defaultConfigId,
fast_llm_id: defaultConfigId,
strategic_llm_id: defaultConfigId,
};
// Handle form submission
const handleSubmit = async (formData: LLMConfigFormData) => {
try {
// Create the config
const newConfig = await createConfig(formData);
// Auto-assign to all roles
await updatePreferences({
search_space_id: searchSpaceId,
data: newPreferences,
data: {
agent_llm_id: newConfig.id,
document_summary_llm_id: newConfig.id,
},
});
await refreshPreferences();
setAutoConfigComplete(true);
toast.success("AI models configured automatically!", {
description: "You can customize these in advanced settings.",
toast.success("Configuration created!", {
description: "Redirecting to chat...",
});
// Redirect to new-chat
router.push(`/dashboard/${searchSpaceId}/new-chat`);
} catch (error) {
console.error("Auto-configuration failed:", error);
} finally {
setIsAutoConfiguring(false);
console.error("Failed to create config:", error);
if (error instanceof Error) {
toast.error(error.message || "Failed to create configuration");
}
}
}, [globalConfigs, llmConfigs, isOnboardingComplete, updatePreferences, refreshPreferences]);
};
// Trigger auto-configuration once data is loaded
useEffect(() => {
if (!configsLoading && !globalConfigsLoading && !preferencesLoading) {
autoConfigureLLMs();
}
}, [configsLoading, globalConfigsLoading, preferencesLoading, autoConfigureLLMs]);
const allConfigs = [...globalConfigs, ...llmConfigs];
const isReady = autoConfigComplete || isOnboardingComplete;
const isSubmitting = isCreating || isUpdatingPreferences;
// Loading state
if (configsLoading || preferencesLoading || globalConfigsLoading || isAutoConfiguring) {
if (globalConfigsLoading || preferencesLoading || isAutoConfiguring) {
return (
<OnboardLoading
title={isAutoConfiguring ? "Setting up your AI assistant..." : t("loading_config")}
subtitle={
isAutoConfiguring
? "Auto-configuring optimal settings for you"
: "Please wait while we load your configuration"
}
/>
);
}
// Show LLM setup if no configs available OR if roles are not assigned yet
// This forces users to complete role assignment before seeing the final screen
if (allConfigs.length === 0 || !isOnboardingComplete) {
return (
<OnboardLLMSetup
searchSpaceId={searchSpaceId}
title={t("welcome_title")}
configTitle={
allConfigs.length === 0 ? t("setup_llm_configuration") : t("assign_llm_roles_title")
}
configDescription={
allConfigs.length === 0
? t("configure_providers_and_assign_roles")
: t("complete_role_assignment")
}
onConfigCreated={() => refreshConfigs()}
onConfigDeleted={() => refreshConfigs()}
onPreferencesUpdated={handleRefreshPreferences}
/>
);
}
// Main onboarding view
return (
<div className="min-h-screen bg-background">
<div className="flex items-center justify-center min-h-screen p-4 md:p-8">
<div className="min-h-screen bg-gradient-to-b from-background to-muted/20 flex items-center justify-center">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.6 }}
className="w-full max-w-5xl"
initial={{ opacity: 0, scale: 0.95 }}
animate={{ opacity: 1, scale: 1 }}
className="text-center space-y-6"
>
<div className="relative">
<div className="absolute inset-0 blur-3xl bg-gradient-to-r from-violet-500/20 to-cyan-500/20 rounded-full" />
<div className="relative flex items-center justify-center w-24 h-24 mx-auto rounded-2xl bg-gradient-to-br from-violet-500 to-purple-600 shadow-2xl shadow-violet-500/25">
<Loader2 className="h-12 w-12 text-white animate-spin" />
</div>
</div>
<div className="space-y-2">
<h2 className="text-2xl font-bold tracking-tight">
{isAutoConfiguring ? "Setting up your AI..." : "Loading..."}
</h2>
<p className="text-muted-foreground">
{isAutoConfiguring
? "Auto-configuring with available settings"
: "Please wait while we check your configuration"}
</p>
</div>
<div className="flex justify-center gap-1">
{[0, 1, 2].map((i) => (
<motion.div
key={i}
className="w-2 h-2 rounded-full bg-violet-500"
animate={{ scale: [1, 1.5, 1], opacity: [0.5, 1, 0.5] }}
transition={{ duration: 1, repeat: Infinity, delay: i * 0.2 }}
/>
))}
</div>
</motion.div>
</div>
);
}
// If global configs exist but auto-config failed, show simple message
if (globalConfigs.length > 0 && !isAutoConfiguring) {
return null; // Will redirect via useEffect
}
// No global configs - show the config form
return (
<div className="min-h-screen bg-gradient-to-b from-background via-background to-muted/30">
<div className="container mx-auto px-4 py-8 md:py-12 max-w-3xl">
<motion.div
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
transition={{ duration: 0.5 }}
className="space-y-8"
>
{/* Header */}
<OnboardHeader
title={t("welcome_title")}
subtitle={
isReady ? "You're all set! Choose what you'd like to do next." : t("welcome_subtitle")
}
isReady={isReady}
/>
<div className="text-center space-y-4">
<motion.div
initial={{ scale: 0 }}
animate={{ scale: 1 }}
transition={{ type: "spring", delay: 0.2 }}
className="relative inline-block"
>
<Logo className="w-20 h-20 mx-auto rounded-full" />
</motion.div>
{/* Quick Stats */}
<OnboardStats
globalConfigsCount={globalConfigs.length}
userConfigsCount={llmConfigs.length}
/>
<div className="space-y-2">
<h1 className="text-3xl font-bold tracking-tight">Configure Your AI</h1>
<p className="text-muted-foreground text-lg">
Add your LLM provider to get started with SurfSense
</p>
</div>
</div>
{/* Action Cards */}
{/* Config Form */}
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.6 }}
className="grid grid-cols-1 md:grid-cols-3 gap-6 mb-10"
initial={{ opacity: 0, y: 20 }}
animate={{ opacity: 1, y: 0 }}
transition={{ delay: 0.3 }}
>
<OnboardActionCard
title="Start Chatting"
description="Jump right into the AI researcher and start asking questions"
icon={MessageSquare}
features={[
"AI-powered conversations",
"Research and explore topics",
"Get instant insights",
]}
buttonText="Start Chatting"
onClick={() => router.push(`/dashboard/${searchSpaceId}/researcher`)}
colorScheme="violet"
delay={0.9}
/>
<OnboardActionCard
title="Add Sources"
description="Connect your data sources to start building your knowledge base"
icon={FileText}
features={[
"Connect documents and files",
"Import from various sources",
"Build your knowledge base",
]}
buttonText="Add Sources"
onClick={() => router.push(`/dashboard/${searchSpaceId}/sources/add`)}
colorScheme="blue"
delay={0.8}
/>
<OnboardActionCard
title="Manage Team"
description="Invite team members and collaborate on your search space"
icon={Users}
features={[
"Invite team members",
"Assign roles & permissions",
"Collaborate together",
]}
buttonText="Manage Team"
onClick={() => router.push(`/dashboard/${searchSpaceId}/team`)}
colorScheme="emerald"
delay={0.7}
/>
<Card className="border-2 border-muted shadow-xl overflow-hidden">
<CardHeader className="pb-4">
<CardTitle className="text-xl">LLM Configuration</CardTitle>
</CardHeader>
<CardContent>
<LLMConfigForm
searchSpaceId={searchSpaceId}
onSubmit={handleSubmit}
isSubmitting={isSubmitting}
mode="create"
showAdvanced={true}
submitLabel="Start Using SurfSense"
initialData={{
citations_enabled: true,
use_default_system_instructions: true,
}}
/>
</CardContent>
</Card>
</motion.div>
{/* Advanced Settings */}
<OnboardAdvancedSettings
searchSpaceId={searchSpaceId}
showLLMSettings={showAdvancedSettings}
setShowLLMSettings={setShowAdvancedSettings}
showPromptSettings={showPromptSettings}
setShowPromptSettings={setShowPromptSettings}
onConfigCreated={() => refreshConfigs()}
onConfigDeleted={() => refreshConfigs()}
onPreferencesUpdated={handleRefreshPreferences}
/>
{/* Footer */}
<motion.div
{/* Footer note */}
<motion.p
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 1.1 }}
className="text-center mt-10 text-muted-foreground text-sm"
transition={{ delay: 0.5 }}
className="text-center text-sm text-muted-foreground"
>
<p>
You can always adjust these settings later in{" "}
<button
type="button"
onClick={() => router.push(`/dashboard/${searchSpaceId}/settings`)}
className="text-primary hover:underline underline-offset-2 transition-colors"
>
Settings
</button>
</p>
</motion.div>
You can add more configurations and customize settings anytime in{" "}
<button
type="button"
onClick={() => router.push(`/dashboard/${searchSpaceId}/settings`)}
className="text-violet-500 hover:underline"
>
Settings
</button>
</motion.p>
</motion.div>
</div>
</div>
);
};
export default OnboardPage;
}

View file

@ -1,24 +0,0 @@
import { Suspense } from "react";
import PodcastsPageClient from "./podcasts-client";
interface PageProps {
	// Next.js 15 App Router delivers dynamic route `params` as a Promise.
	// Typing it as such lets us `await` it directly instead of papering over
	// the mismatch with `await Promise.resolve(params)`.
	params: Promise<{
		search_space_id: string;
	}>;
}

/**
 * Server component for the podcasts route.
 *
 * Awaits the dynamic route params and renders the client-side podcasts page
 * inside a Suspense boundary with a spinner fallback.
 */
export default async function PodcastsPage({ params }: PageProps) {
	const { search_space_id: searchSpaceId } = await params;

	return (
		<Suspense
			fallback={
				<div className="flex items-center justify-center h-[60vh]">
					<div className="h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"></div>
				</div>
			}
		>
			<PodcastsPageClient searchSpaceId={searchSpaceId} />
		</Suspense>
	);
}

View file

@ -1,957 +0,0 @@
"use client";
import { format } from "date-fns";
import { useAtom, useAtomValue } from "jotai";
import {
Calendar,
MoreHorizontal,
Pause,
Play,
Podcast as PodcastIcon,
Search,
SkipBack,
SkipForward,
Trash2,
Volume2,
VolumeX,
X,
} from "lucide-react";
import { AnimatePresence, motion, type Variants } from "motion/react";
import Image from "next/image";
import { useEffect, useRef, useState } from "react";
import { toast } from "sonner";
import { deletePodcastMutationAtom } from "@/atoms/podcasts/podcast-mutation.atoms";
import { podcastsAtom } from "@/atoms/podcasts/podcast-query.atoms";
// UI Components
import { Button } from "@/components/ui/button";
import { Card } from "@/components/ui/card";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Slider } from "@/components/ui/slider";
import type { Podcast } from "@/contracts/types/podcast.types";
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
interface PodcastsPageClientProps {
	// Route param identifying the search space whose podcasts are listed.
	searchSpaceId: string;
}
// Page-level fade in/out; children stagger in on enter (100ms apart).
const pageVariants: Variants = {
	initial: { opacity: 0 },
	enter: {
		opacity: 1,
		transition: { duration: 0.4, ease: "easeInOut", staggerChildren: 0.1 },
	},
	exit: { opacity: 0, transition: { duration: 0.3, ease: "easeInOut" } },
};
// Per-card spring animation: cards scale/slide in, and lift slightly on hover.
const podcastCardVariants: Variants = {
	initial: { scale: 0.95, y: 20, opacity: 0 },
	animate: {
		scale: 1,
		y: 0,
		opacity: 1,
		transition: { type: "spring", stiffness: 300, damping: 25 },
	},
	exit: { scale: 0.95, y: -20, opacity: 0 },
	hover: { y: -5, scale: 1.02, transition: { duration: 0.2 } },
};

// Card wrapped with motion() so it can consume the variants above.
const MotionCard = motion(Card);
/**
 * Client-side podcasts page for a search space.
 *
 * Lists podcasts with search/sort, per-card playback controls, a fixed
 * bottom audio player, and delete-with-confirmation. Audio is fetched as a
 * blob (30s timeout via AbortController) and played through a hidden
 * <audio> element using an object URL that is revoked when replaced or on
 * unmount.
 */
export default function PodcastsPageClient({ searchSpaceId }: PodcastsPageClientProps) {
	const [filteredPodcasts, setFilteredPodcasts] = useState<Podcast[]>([]);
	const [searchQuery, setSearchQuery] = useState("");
	const [sortOrder, setSortOrder] = useState<string>("newest");
	const [deleteDialogOpen, setDeleteDialogOpen] = useState(false);
	const [podcastToDelete, setPodcastToDelete] = useState<{
		id: number;
		title: string;
	} | null>(null);

	// Audio player state
	const [currentPodcast, setCurrentPodcast] = useState<Podcast | null>(null);
	const [audioSrc, setAudioSrc] = useState<string | undefined>(undefined);
	const [isAudioLoading, setIsAudioLoading] = useState(false);
	const [isPlaying, setIsPlaying] = useState(false);
	const [currentTime, setCurrentTime] = useState(0);
	const [duration, setDuration] = useState(0);
	const [volume, setVolume] = useState(0.7);
	const [isMuted, setIsMuted] = useState(false);
	const audioRef = useRef<HTMLAudioElement | null>(null);
	// Tracks the live object URL so it can be revoked when replaced/unmounted.
	const currentObjectUrlRef = useRef<string | null>(null);

	const [{ isPending: isDeletingPodcast, mutateAsync: deletePodcast, error: deleteError }] =
		useAtom(deletePodcastMutationAtom);

	const {
		data: podcasts,
		isLoading: isFetchingPodcasts,
		error: fetchError,
	} = useAtomValue(podcastsAtom);

	// Add podcast image URL constant
	const PODCAST_IMAGE_URL =
		"https://static.vecteezy.com/system/resources/thumbnails/002/157/611/small_2x/illustrations-concept-design-podcast-channel-free-vector.jpg";

	// Seed/clear the visible list whenever the fetch settles.
	// FIX: the dependency array was previously empty ([]), so this effect ran
	// exactly once on mount with stale initial values (podcasts undefined,
	// isFetchingPodcasts true) and never reacted to the fetch completing or
	// failing. It now re-runs when the query state changes.
	useEffect(() => {
		if (isFetchingPodcasts) return;

		if (fetchError) {
			console.error("Error fetching podcasts:", fetchError);
			setFilteredPodcasts([]);
			return;
		}

		if (!podcasts) {
			setFilteredPodcasts([]);
			return;
		}

		setFilteredPodcasts(podcasts);
	}, [podcasts, isFetchingPodcasts, fetchError]);

	// Filter and sort podcasts based on search query and sort order
	useEffect(() => {
		if (!podcasts) return;

		let result = [...podcasts];

		// Filter by search term
		if (searchQuery) {
			const query = searchQuery.toLowerCase();
			result = result.filter((podcast) => podcast.title.toLowerCase().includes(query));
		}

		// Filter by search space
		result = result.filter((podcast) => podcast.search_space_id === parseInt(searchSpaceId));

		// Sort podcasts
		result.sort((a, b) => {
			const dateA = new Date(a.created_at).getTime();
			const dateB = new Date(b.created_at).getTime();
			return sortOrder === "newest" ? dateB - dateA : dateA - dateB;
		});

		setFilteredPodcasts(result);
	}, [podcasts, searchQuery, sortOrder, searchSpaceId]);

	// Cleanup object URL on unmount so the blob's memory can be reclaimed.
	useEffect(() => {
		return () => {
			if (currentObjectUrlRef.current) {
				URL.revokeObjectURL(currentObjectUrlRef.current);
				currentObjectUrlRef.current = null;
			}
		};
	}, []);

	// Audio player time update handler
	const handleTimeUpdate = () => {
		if (audioRef.current) {
			setCurrentTime(audioRef.current.currentTime);
		}
	};

	// Audio player metadata loaded handler
	const handleMetadataLoaded = () => {
		if (audioRef.current) {
			setDuration(audioRef.current.duration);
		}
	};

	// Play/pause toggle
	const togglePlayPause = () => {
		if (audioRef.current) {
			if (isPlaying) {
				audioRef.current.pause();
			} else {
				audioRef.current.play();
			}
			setIsPlaying(!isPlaying);
		}
	};

	// Close the bottom player and detach the current podcast.
	const closePlayer = () => {
		if (isPlaying) {
			audioRef.current?.pause();
		}
		setIsPlaying(false);
		setAudioSrc(undefined);
		setCurrentTime(0);
		setCurrentPodcast(null);
	};

	// Seek to position (value comes from a single-thumb Slider)
	const handleSeek = (value: number[]) => {
		if (audioRef.current) {
			audioRef.current.currentTime = value[0];
			setCurrentTime(value[0]);
		}
	};

	// Volume change
	const handleVolumeChange = (value: number[]) => {
		if (audioRef.current) {
			const newVolume = value[0];

			// Set volume
			audioRef.current.volume = newVolume;
			setVolume(newVolume);

			// Handle mute state based on volume
			if (newVolume === 0) {
				audioRef.current.muted = true;
				setIsMuted(true);
			} else {
				audioRef.current.muted = false;
				setIsMuted(false);
			}
		}
	};

	// Toggle mute
	const toggleMute = () => {
		if (audioRef.current) {
			const newMutedState = !isMuted;
			audioRef.current.muted = newMutedState;
			setIsMuted(newMutedState);

			// If unmuting, restore previous volume if it was 0
			if (!newMutedState && volume === 0) {
				const restoredVolume = 0.5;
				audioRef.current.volume = restoredVolume;
				setVolume(restoredVolume);
			}
		}
	};

	// Skip forward 10 seconds (clamped to the track's end)
	const skipForward = () => {
		if (audioRef.current) {
			audioRef.current.currentTime = Math.min(
				audioRef.current.duration,
				audioRef.current.currentTime + 10
			);
		}
	};

	// Skip backward 10 seconds (clamped to 0)
	const skipBackward = () => {
		if (audioRef.current) {
			audioRef.current.currentTime = Math.max(0, audioRef.current.currentTime - 10);
		}
	};

	// Format time in MM:SS
	const formatTime = (time: number) => {
		const minutes = Math.floor(time / 60);
		const seconds = Math.floor(time % 60);
		return `${minutes}:${seconds < 10 ? "0" : ""}${seconds}`;
	};

	// Play podcast - Fetch blob and set object URL
	const playPodcast = async (podcast: Podcast) => {
		// If the same podcast is selected, just toggle play/pause
		if (currentPodcast && currentPodcast.id === podcast.id) {
			togglePlayPause();
			return;
		}

		// Prevent multiple simultaneous loading requests
		if (isAudioLoading) {
			return;
		}

		try {
			// Reset player state and show loading
			setCurrentPodcast(podcast);
			setAudioSrc(undefined);
			setCurrentTime(0);
			setDuration(0);
			setIsPlaying(false);
			setIsAudioLoading(true);

			// Revoke previous object URL if exists (only after we've started the new request)
			if (currentObjectUrlRef.current) {
				URL.revokeObjectURL(currentObjectUrlRef.current);
				currentObjectUrlRef.current = null;
			}

			// Use AbortController to handle timeout or cancellation
			const controller = new AbortController();
			const timeoutId = setTimeout(() => controller.abort(), 30000); // 30 second timeout

			try {
				const response = await podcastsApiService.loadPodcast({
					request: { id: podcast.id },
					controller,
				});

				const objectUrl = URL.createObjectURL(response);
				currentObjectUrlRef.current = objectUrl;

				// Set audio source
				setAudioSrc(objectUrl);

				// Wait for the audio to be ready before playing
				// We'll handle actual playback in the onLoadedData event instead of here
			} catch (error) {
				if (error instanceof DOMException && error.name === "AbortError") {
					throw new Error("Request timed out. Please try again.");
				}
				throw error;
			} finally {
				clearTimeout(timeoutId);
			}
		} catch (error) {
			console.error("Error fetching or playing podcast:", error);
			toast.error(error instanceof Error ? error.message : "Failed to load podcast audio.");

			// Reset state on error
			setCurrentPodcast(null);
			setAudioSrc(undefined);
		} finally {
			setIsAudioLoading(false);
		}
	};

	// Function to handle podcast deletion
	const handleDeletePodcast = async () => {
		if (!podcastToDelete) return;

		try {
			await deletePodcast({ id: podcastToDelete.id });

			// Close dialog
			setDeleteDialogOpen(false);
			setPodcastToDelete(null);

			// If the current playing podcast is deleted, stop playback
			if (currentPodcast && currentPodcast.id === podcastToDelete.id) {
				if (audioRef.current) {
					audioRef.current.pause();
				}
				setCurrentPodcast(null);
				setIsPlaying(false);
			}
		} catch (error) {
			console.error("Error deleting podcast:", error);
			toast.error(error instanceof Error ? error.message : "Failed to delete podcast");
		}
	};

	return (
		<motion.div
			className="container p-6 mx-auto"
			initial="initial"
			animate="enter"
			exit="exit"
			variants={pageVariants}
		>
			<div className="flex flex-col space-y-4 md:space-y-6">
				<div className="flex flex-col space-y-2">
					<h1 className="text-3xl font-bold tracking-tight">Podcasts</h1>
					<p className="text-muted-foreground">Listen to generated podcasts.</p>
				</div>

				{/* Filter and Search Bar */}
				<div className="flex flex-col space-y-4 md:flex-row md:items-center md:justify-between md:space-y-0">
					<div className="flex flex-1 items-center gap-2">
						<div className="relative w-full md:w-80">
							<Search className="absolute left-2.5 top-2.5 h-4 w-4 text-muted-foreground" />
							<Input
								type="text"
								placeholder="Search podcasts..."
								className="pl-8"
								value={searchQuery}
								onChange={(e) => setSearchQuery(e.target.value)}
							/>
						</div>
					</div>
					<div>
						<Select value={sortOrder} onValueChange={setSortOrder}>
							<SelectTrigger className="w-40">
								<SelectValue placeholder="Sort order" />
							</SelectTrigger>
							<SelectContent>
								<SelectGroup>
									<SelectItem value="newest">Newest First</SelectItem>
									<SelectItem value="oldest">Oldest First</SelectItem>
								</SelectGroup>
							</SelectContent>
						</Select>
					</div>
				</div>

				{/* Status Messages */}
				{isFetchingPodcasts && (
					<div className="flex items-center justify-center h-40">
						<div className="flex flex-col items-center gap-2">
							<div className="h-8 w-8 animate-spin rounded-full border-4 border-primary border-t-transparent"></div>
							<p className="text-sm text-muted-foreground">Loading podcasts...</p>
						</div>
					</div>
				)}

				{fetchError && !isFetchingPodcasts && (
					<div className="border border-destructive/50 text-destructive p-4 rounded-md">
						<h3 className="font-medium">Error loading podcasts</h3>
						<p className="text-sm">{fetchError.message ?? "Failed to load podcasts"}</p>
					</div>
				)}

				{!isFetchingPodcasts && !fetchError && filteredPodcasts.length === 0 && (
					<div className="flex flex-col items-center justify-center h-40 gap-2 text-center">
						<PodcastIcon className="h-8 w-8 text-muted-foreground" />
						<h3 className="font-medium">No podcasts found</h3>
						<p className="text-sm text-muted-foreground">
							{searchQuery
								? "Try adjusting your search filters"
								: "Generate podcasts from your chats to get started"}
						</p>
					</div>
				)}

				{/* Podcast Grid */}
				{!isFetchingPodcasts && !fetchError && filteredPodcasts.length > 0 && (
					<AnimatePresence mode="wait">
						<motion.div
							className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6"
							variants={pageVariants}
							initial="initial"
							animate="enter"
							exit="exit"
						>
							{filteredPodcasts.map((podcast, index) => (
								<MotionCard
									key={podcast.id}
									variants={podcastCardVariants}
									initial="initial"
									animate="animate"
									exit="exit"
									whileHover="hover"
									transition={{ duration: 0.2, delay: index * 0.05 }}
									className={`
                    bg-card/60 dark:bg-card/40 backdrop-blur-lg rounded-xl p-4
                    shadow-md hover:shadow-xl transition-all duration-300
                    border-border overflow-hidden cursor-pointer
                    ${currentPodcast?.id === podcast.id ? "ring-2 ring-primary ring-offset-2 ring-offset-background" : ""}
                  `}
									layout
									onClick={() => playPodcast(podcast)}
								>
									<div className="relative w-full aspect-[16/10] mb-4 rounded-lg overflow-hidden">
										{/* Podcast image with gradient overlay */}
										<Image
											src={PODCAST_IMAGE_URL}
											alt="Podcast illustration"
											className="w-full h-full object-cover transition-transform duration-500 group-hover:scale-105 brightness-[0.85] contrast-[1.1]"
											loading="lazy"
											width={100}
											height={100}
										/>
										{/* Better overlay with gradient for improved text legibility */}
										<div className="absolute inset-0 bg-gradient-to-t from-black/60 to-black/10 transition-opacity duration-300"></div>

										{/* Loading indicator with improved animation */}
										{currentPodcast?.id === podcast.id && isAudioLoading && (
											<motion.div
												className="absolute inset-0 flex items-center justify-center bg-background/60 backdrop-blur-md z-10"
												initial={{ opacity: 0 }}
												animate={{ opacity: 1 }}
												exit={{ opacity: 0 }}
												transition={{ duration: 0.2 }}
											>
												<motion.div
													className="flex flex-col items-center gap-3"
													initial={{ scale: 0.9 }}
													animate={{ scale: 1 }}
													transition={{ type: "spring", damping: 20 }}
												>
													<div className="h-14 w-14 rounded-full border-4 border-primary/30 border-t-primary animate-spin"></div>
													<p className="text-sm text-foreground font-medium">Loading podcast...</p>
												</motion.div>
											</motion.div>
										)}

										{/* Play button with animations */}
										{!(currentPodcast?.id === podcast.id && (isPlaying || isAudioLoading)) && (
											<motion.div
												className="absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 z-10"
												whileHover={{ scale: 1.1 }}
												whileTap={{ scale: 0.9 }}
											>
												<Button
													variant="secondary"
													size="icon"
													className="h-16 w-16 rounded-full
                            bg-background/80 hover:bg-background/95 backdrop-blur-md
                            transition-all duration-200 shadow-xl border-0
                            flex items-center justify-center"
													onClick={(e) => {
														e.stopPropagation();
														playPodcast(podcast);
													}}
													disabled={isAudioLoading}
												>
													<motion.div
														initial={{ scale: 0.8 }}
														animate={{ scale: 1 }}
														transition={{
															type: "spring",
															stiffness: 400,
															damping: 10,
														}}
														className="text-primary w-10 h-10 flex items-center justify-center"
													>
														<Play className="h-8 w-8 ml-1" />
													</motion.div>
												</Button>
											</motion.div>
										)}

										{/* Pause button with animations */}
										{currentPodcast?.id === podcast.id && isPlaying && !isAudioLoading && (
											<motion.div
												className="absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 z-10"
												whileHover={{ scale: 1.1 }}
												whileTap={{ scale: 0.9 }}
											>
												<Button
													variant="secondary"
													size="icon"
													className="h-16 w-16 rounded-full
                            bg-background/80 hover:bg-background/95 backdrop-blur-md
                            transition-all duration-200 shadow-xl border-0
                            flex items-center justify-center"
													onClick={(e) => {
														e.stopPropagation();
														togglePlayPause();
													}}
													disabled={isAudioLoading}
												>
													<motion.div
														initial={{ scale: 0.8 }}
														animate={{ scale: 1 }}
														transition={{
															type: "spring",
															stiffness: 400,
															damping: 10,
														}}
														className="text-primary w-10 h-10 flex items-center justify-center"
													>
														<Pause className="h-8 w-8" />
													</motion.div>
												</Button>
											</motion.div>
										)}

										{/* Now playing indicator */}
										{currentPodcast?.id === podcast.id && !isAudioLoading && (
											<div className="absolute top-2 left-2 bg-primary text-primary-foreground text-xs px-2 py-1 rounded-full z-10 flex items-center gap-1.5">
												<span className="relative flex h-2 w-2">
													<span className="animate-ping absolute inline-flex h-full w-full rounded-full bg-primary-foreground opacity-75"></span>
													<span className="relative inline-flex rounded-full h-2 w-2 bg-primary-foreground"></span>
												</span>
												Now Playing
											</div>
										)}
									</div>

									<div className="mb-3 px-1">
										<h3
											className="text-base font-semibold text-foreground truncate"
											title={podcast.title}
										>
											{podcast.title || "Untitled Podcast"}
										</h3>
										<p className="text-xs text-muted-foreground mt-0.5 flex items-center gap-1.5">
											<Calendar className="h-3 w-3" />
											{format(new Date(podcast.created_at), "MMM d, yyyy")}
										</p>
									</div>

									{currentPodcast?.id === podcast.id && !isAudioLoading && (
										<motion.div
											className="mb-3 px-1"
											initial={{ opacity: 0, y: 5 }}
											animate={{ opacity: 1, y: 0 }}
											transition={{ delay: 0.1 }}
										>
											<Button
												variant="ghost"
												className="h-1.5 bg-muted rounded-full cursor-pointer group relative overflow-hidden"
												onClick={(e) => {
													e.stopPropagation();
													if (!audioRef.current || !duration) return;
													const container = e.currentTarget;
													const rect = container.getBoundingClientRect();
													const x = e.clientX - rect.left;
													const percentage = Math.max(0, Math.min(1, x / rect.width));
													const newTime = percentage * duration;
													handleSeek([newTime]);
												}}
											>
												{/* NOTE(review): width is NaN% while duration is still 0
												    (before metadata loads); browsers ignore it, but a
												    duration guard would be cleaner. */}
												<motion.div
													className="h-full bg-primary rounded-full relative"
													style={{
														width: `${(currentTime / duration) * 100}%`,
													}}
													transition={{ ease: "linear" }}
												>
													<motion.div
														className="absolute right-0 top-1/2 -translate-y-1/2 w-3 h-3
                              bg-primary rounded-full shadow-md transform scale-0
                              group-hover:scale-100 transition-transform"
														whileHover={{ scale: 1.5 }}
													/>
												</motion.div>
											</Button>
											<div className="flex justify-between mt-1.5 text-xs text-muted-foreground">
												<span>{formatTime(currentTime)}</span>
												<span>{formatTime(duration)}</span>
											</div>
										</motion.div>
									)}

									{currentPodcast?.id === podcast.id && !isAudioLoading && (
										<motion.div
											className="flex items-center justify-between px-2 mt-1"
											initial={{ opacity: 0, y: 5 }}
											animate={{ opacity: 1, y: 0 }}
											transition={{ delay: 0.2 }}
										>
											<motion.div whileHover={{ scale: 1.2 }} whileTap={{ scale: 0.95 }}>
												<Button
													variant="ghost"
													size="icon"
													onClick={(e) => {
														e.stopPropagation();
														skipBackward();
													}}
													className="w-9 h-9 text-muted-foreground hover:text-primary transition-colors"
													title="Rewind 10 seconds"
													disabled={!duration}
												>
													<SkipBack className="w-5 h-5" />
												</Button>
											</motion.div>

											<motion.div whileHover={{ scale: 1.2 }} whileTap={{ scale: 0.95 }}>
												<Button
													variant="ghost"
													size="icon"
													onClick={(e) => {
														e.stopPropagation();
														togglePlayPause();
													}}
													className="w-10 h-10 text-primary hover:bg-primary/10 rounded-full transition-colors"
													disabled={!duration}
												>
													{isPlaying ? (
														<Pause className="w-6 h-6" />
													) : (
														<Play className="w-6 h-6 ml-0.5" />
													)}
												</Button>
											</motion.div>

											<motion.div whileHover={{ scale: 1.2 }} whileTap={{ scale: 0.95 }}>
												<Button
													variant="ghost"
													size="icon"
													onClick={(e) => {
														e.stopPropagation();
														skipForward();
													}}
													className="w-9 h-9 text-muted-foreground hover:text-primary transition-colors"
													title="Forward 10 seconds"
													disabled={!duration}
												>
													<SkipForward className="w-5 h-5" />
												</Button>
											</motion.div>
										</motion.div>
									)}

									<div className="absolute top-2 right-2 z-20">
										<DropdownMenu>
											<DropdownMenuTrigger asChild>
												<Button
													variant="ghost"
													size="icon"
													className="h-7 w-7 bg-background/50 hover:bg-background/80 rounded-full backdrop-blur-sm"
													onClick={(e) => e.stopPropagation()}
												>
													<MoreHorizontal className="h-4 w-4" />
													<span className="sr-only">Open menu</span>
												</Button>
											</DropdownMenuTrigger>
											<DropdownMenuContent align="end">
												<DropdownMenuItem
													className="text-destructive focus:text-destructive"
													onClick={(e) => {
														e.stopPropagation();
														setPodcastToDelete({
															id: podcast.id,
															title: podcast.title,
														});
														setDeleteDialogOpen(true);
													}}
												>
													<Trash2 className="mr-2 h-4 w-4" />
													<span>Delete Podcast</span>
												</DropdownMenuItem>
											</DropdownMenuContent>
										</DropdownMenu>
									</div>
								</MotionCard>
							))}
						</motion.div>
					</AnimatePresence>
				)}

				{/* Current Podcast Player (Fixed at bottom) */}
				{currentPodcast && !isAudioLoading && audioSrc && (
					<motion.div
						initial={{ y: 100, opacity: 0 }}
						animate={{ y: 0, opacity: 1 }}
						exit={{ y: 100, opacity: 0 }}
						transition={{ type: "spring", stiffness: 300, damping: 30 }}
						className="fixed bottom-0 left-0 right-0 bg-background/95 backdrop-blur-sm border-t p-4 shadow-lg z-50"
					>
						<div className="container mx-auto">
							<div className="flex flex-col md:flex-row items-center gap-4">
								<div className="flex-shrink-0">
									<motion.div
										className="w-12 h-12 bg-primary/20 rounded-md flex items-center justify-center"
										animate={{ scale: isPlaying ? [1, 1.05, 1] : 1 }}
										transition={{
											repeat: isPlaying ? Infinity : 0,
											duration: 2,
										}}
									>
										<PodcastIcon className="h-6 w-6 text-primary" />
									</motion.div>
								</div>

								<div className="flex-grow min-w-0">
									<h4 className="font-medium text-sm line-clamp-1">{currentPodcast.title}</h4>
									<div className="flex items-center gap-2 mt-2">
										<div className="flex-grow relative">
											<Slider
												value={[currentTime]}
												min={0}
												max={duration || 100}
												step={0.1}
												onValueChange={handleSeek}
												className="relative z-10"
											/>
											<motion.div
												className="absolute left-0 top-1/2 h-2 bg-primary/25 rounded-full -translate-y-1/2"
												style={{
													width: `${(currentTime / (duration || 100)) * 100}%`,
												}}
												transition={{ ease: "linear" }}
											/>
										</div>
										<div className="flex-shrink-0 text-xs text-muted-foreground whitespace-nowrap">
											{formatTime(currentTime)} / {formatTime(duration)}
										</div>
									</div>
								</div>

								<div className="flex items-center gap-2">
									<motion.div whileHover={{ scale: 1.1 }} whileTap={{ scale: 0.95 }}>
										<Button variant="ghost" size="icon" onClick={skipBackward} className="h-8 w-8">
											<SkipBack className="h-4 w-4" />
										</Button>
									</motion.div>

									<motion.div whileHover={{ scale: 1.1 }} whileTap={{ scale: 0.95 }}>
										<Button
											variant="default"
											size="icon"
											onClick={togglePlayPause}
											className="h-10 w-10 rounded-full"
										>
											{isPlaying ? (
												<Pause className="h-5 w-5" />
											) : (
												<Play className="h-5 w-5 ml-0.5" />
											)}
										</Button>
									</motion.div>

									<motion.div whileHover={{ scale: 1.1 }} whileTap={{ scale: 0.95 }}>
										<Button variant="ghost" size="icon" onClick={skipForward} className="h-8 w-8">
											<SkipForward className="h-4 w-4" />
										</Button>
									</motion.div>

									<div className="hidden md:flex items-center gap-2 ml-4 w-32">
										<motion.div whileHover={{ scale: 1.1 }} whileTap={{ scale: 0.95 }}>
											<Button
												variant="ghost"
												size="icon"
												onClick={toggleMute}
												className={`h-8 w-8 ${isMuted ? "text-muted-foreground" : "text-primary"}`}
											>
												{isMuted ? (
													<VolumeX className="h-4 w-4" />
												) : (
													<Volume2 className="h-4 w-4" />
												)}
											</Button>
										</motion.div>
										<div className="relative w-full">
											<Slider
												value={[isMuted ? 0 : volume]}
												min={0}
												max={1}
												step={0.01}
												onValueChange={handleVolumeChange}
												className="w-full"
												disabled={isMuted}
											/>
											<motion.div
												className={`absolute left-0 bottom-0 h-1 bg-primary/30 rounded-full ${isMuted ? "opacity-50" : ""}`}
												initial={false}
												animate={{ width: `${(isMuted ? 0 : volume) * 100}%` }}
											/>
										</div>
									</div>

									<motion.div whileHover={{ scale: 1.1 }} whileTap={{ scale: 0.95 }}>
										<Button
											variant="default"
											size="icon"
											onClick={closePlayer}
											className="h-10 w-10 rounded-full"
										>
											<X />
										</Button>
									</motion.div>
								</div>
							</div>
						</div>
					</motion.div>
				)}
			</div>

			{/* Delete Confirmation Dialog */}
			<Dialog open={deleteDialogOpen} onOpenChange={setDeleteDialogOpen}>
				<DialogContent className="sm:max-w-md">
					<DialogHeader>
						<DialogTitle className="flex items-center gap-2">
							<Trash2 className="h-5 w-5 text-destructive" />
							<span>Delete Podcast</span>
						</DialogTitle>
						<DialogDescription>
							Are you sure you want to delete{" "}
							<span className="font-medium">{podcastToDelete?.title}</span>? This action cannot be
							undone.
						</DialogDescription>
					</DialogHeader>
					<DialogFooter className="flex gap-2 sm:justify-end">
						<Button
							variant="outline"
							onClick={() => setDeleteDialogOpen(false)}
							disabled={isDeletingPodcast}
						>
							Cancel
						</Button>
						<Button
							variant="destructive"
							onClick={handleDeletePodcast}
							disabled={isDeletingPodcast}
							className="gap-2"
						>
							{isDeletingPodcast ? (
								<>
									<span className="h-4 w-4 animate-spin rounded-full border-2 border-current border-t-transparent" />
									Deleting...
								</>
							) : (
								<>
									<Trash2 className="h-4 w-4" />
									Delete
								</>
							)}
						</Button>
					</DialogFooter>
				</DialogContent>
			</Dialog>

			{/* Hidden audio element for playback */}
			<audio
				ref={audioRef}
				src={audioSrc}
				preload="auto"
				onTimeUpdate={handleTimeUpdate}
				onLoadedMetadata={handleMetadataLoaded}
				onLoadedData={() => {
					// Only auto-play when audio is fully loaded
					if (audioRef.current && currentPodcast && audioSrc) {
						// Small delay to ensure browser is ready to play
						setTimeout(() => {
							if (audioRef.current) {
								audioRef.current
									.play()
									.then(() => {
										setIsPlaying(true);
									})
									.catch((error) => {
										console.error("Error playing audio:", error);
										// Don't show error if it's just the user navigating away
										if (error.name !== "AbortError") {
											toast.error("Failed to play audio.");
										}
										setIsPlaying(false);
									});
							}
						}, 100);
					}
				}}
				onEnded={() => setIsPlaying(false)}
				onError={(e) => {
					console.error("Audio error:", e);
					if (audioRef.current?.error) {
						// Log the specific error code for debugging
						console.error("Audio error code:", audioRef.current.error.code);
						// Don't show error message for aborted loads
						if (audioRef.current.error.code !== audioRef.current.error.MEDIA_ERR_ABORTED) {
							toast.error("Error playing audio. Please try again.");
						}
					}
					// Reset playing state on error
					setIsPlaying(false);
				}}
			>
				<track kind="captions" />
			</audio>
		</motion.div>
	);
}

View file

@ -1,291 +0,0 @@
"use client";
import { type CreateMessage, type Message, useChat } from "@ai-sdk/react";
import { useAtom, useAtomValue } from "jotai";
import { useParams, useRouter } from "next/navigation";
import { useEffect, useMemo, useRef } from "react";
import { createChatMutationAtom, updateChatMutationAtom } from "@/atoms/chats/chat-mutation.atoms";
import { activeChatAtom } from "@/atoms/chats/chat-query.atoms";
import { activeChatIdAtom } from "@/atoms/chats/ui.atoms";
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
import ChatInterface from "@/components/chat/ChatInterface";
import type { Document } from "@/contracts/types/document.types";
import { useChatState } from "@/hooks/use-chat";
import { useSearchSourceConnectors } from "@/hooks/use-search-source-connectors";
/**
 * Researcher chat page.
 *
 * Drives a streaming chat session against the FastAPI backend via
 * `useChat`. For a brand-new chat (no active chat id) the first `append`
 * is intercepted to create a chat record first, persist the UI selection
 * state to localStorage, and navigate to the new chat's URL; the loaded
 * page then restores that state and triggers the assistant's first reply.
 * Existing chats are auto-saved whenever the assistant finishes a turn.
 */
export default function ResearcherPage() {
	const { search_space_id } = useParams();
	const router = useRouter();
	// Guards so one-time initialization doesn't repeat across re-renders.
	const hasSetInitialConnectors = useRef(false);
	const hasInitiatedResponse = useRef<string | null>(null);

	const activeChatId = useAtomValue(activeChatIdAtom);
	const { data: activeChatState, isFetching: isChatLoading } = useAtomValue(activeChatAtom);
	const { mutateAsync: createChat } = useAtomValue(createChatMutationAtom);
	const { mutateAsync: updateChat } = useAtomValue(updateChatMutationAtom);

	const isNewChat = !activeChatId;

	// Reset the flag when chat ID changes (but not hasInitiatedResponse - we need to remember if we already initiated)
	useEffect(() => {
		hasSetInitialConnectors.current = false;
	}, [activeChatId]);

	const {
		token,
		researchMode,
		selectedConnectors,
		setSelectedConnectors,
		selectedDocuments,
		setSelectedDocuments,
		topK,
		setTopK,
	} = useChatState({
		search_space_id: search_space_id as string,
		chat_id: activeChatId ?? undefined,
	});

	// Fetch all available sources (document types + live search connectors)
	// Use the documentTypeCountsAtom for fetching document types
	const [documentTypeCountsQuery] = useAtom(documentTypeCountsAtom);
	const { data: documentTypeCountsData } = documentTypeCountsQuery;

	// Transform the response into the expected format
	// (presumably a { [docType]: count } map from the backend — verify against the atom)
	const documentTypes = useMemo(() => {
		if (!documentTypeCountsData) return [];
		return Object.entries(documentTypeCountsData).map(([type, count]) => ({
			type,
			count,
		}));
	}, [documentTypeCountsData]);

	const { connectors: searchConnectors } = useSearchSourceConnectors(
		false,
		Number(search_space_id)
	);

	// Filter for non-indexable connectors (live search)
	const liveSearchConnectors = useMemo(
		() => searchConnectors.filter((connector) => !connector.is_indexable),
		[searchConnectors]
	);

	// Memoize document IDs to prevent infinite re-renders
	const documentIds = useMemo(() => {
		return selectedDocuments.map((doc) => doc.id);
	}, [selectedDocuments]);

	// Memoize connector types to prevent infinite re-renders
	const connectorTypes = useMemo(() => {
		return selectedConnectors;
	}, [selectedConnectors]);

	// Unified localStorage management for chat state
	// NOTE(review): this interface and the three helpers below are re-declared
	// on every render; hoisting them to module scope would be cleaner.
	interface ChatState {
		selectedDocuments: Document[];
		selectedConnectors: string[];
		researchMode: "QNA"; // Always QNA mode
		topK: number;
	}

	// Per-chat localStorage key for handing selection state across navigation.
	const getChatStateStorageKey = (searchSpaceId: string, chatId: string) =>
		`surfsense_chat_state_${searchSpaceId}_${chatId}`;

	// Persist selection state just before navigating to a newly created chat.
	const storeChatState = (searchSpaceId: string, chatId: string, state: ChatState) => {
		const key = getChatStateStorageKey(searchSpaceId, chatId);
		localStorage.setItem(key, JSON.stringify(state));
	};

	// One-shot read: returns the stored state (or null) and removes the key.
	const restoreChatState = (searchSpaceId: string, chatId: string): ChatState | null => {
		const key = getChatStateStorageKey(searchSpaceId, chatId);
		const stored = localStorage.getItem(key);
		if (stored) {
			localStorage.removeItem(key); // Clean up after restoration
			try {
				return JSON.parse(stored);
			} catch (error) {
				console.error("Error parsing stored chat state:", error);
				return null;
			}
		}
		return null;
	};

	// Streaming chat handler; request body carries the current source selection.
	const handler = useChat({
		api: `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/chat`,
		streamProtocol: "data",
		initialMessages: [],
		headers: {
			...(token && { Authorization: `Bearer ${token}` }),
		},
		body: {
			data: {
				search_space_id: search_space_id,
				selected_connectors: connectorTypes,
				research_mode: researchMode,
				document_ids_to_add_in_context: documentIds,
				top_k: topK,
			},
		},
		onError: (error) => {
			console.error("Chat error:", error);
		},
	});

	// For a new chat: create the chat record, stash UI state, navigate to it.
	// The assistant's first reply is triggered after navigation (see effect below).
	const customHandlerAppend = async (
		message: Message | CreateMessage,
		chatRequestOptions?: { data?: any }
	) => {
		// Use the first message content as the chat title (truncated to 100 chars)
		const messageContent = typeof message.content === "string" ? message.content : "";
		const chatTitle = messageContent.slice(0, 100) || "Untitled Chat";

		const newChat = await createChat({
			type: researchMode,
			title: chatTitle,
			initial_connectors: selectedConnectors,
			messages: [
				{
					role: "user",
					content: message.content,
				},
			],
			search_space_id: Number(search_space_id),
		});

		if (newChat) {
			// Store chat state before navigation
			storeChatState(search_space_id as string, String(newChat.id), {
				selectedDocuments,
				selectedConnectors,
				researchMode,
				topK,
			});

			router.replace(`/dashboard/${search_space_id}/researcher/${newChat.id}`);
		}

		// NOTE(review): reached even when newChat is falsy, which would throw
		// on `newChat.id` — confirm createChat can't resolve to undefined.
		return String(newChat.id);
	};

	// Load an existing chat: apply its saved connectors, then either replay the
	// stored transcript or kick off the assistant's first reply when the chat
	// has only the initial user message.
	useEffect(() => {
		if (token && !isNewChat && activeChatId) {
			const chatData = activeChatState?.chatDetails;

			if (!chatData) return;

			// Update configuration from chat data
			// researchMode is always "QNA", no need to set from chat data

			if (chatData.initial_connectors && Array.isArray(chatData.initial_connectors)) {
				setSelectedConnectors(chatData.initial_connectors);
			}

			// Load existing messages
			if (chatData.messages && Array.isArray(chatData.messages)) {
				if (chatData.messages.length === 1 && chatData.messages[0].role === "user") {
					// Single user message - append to trigger LLM response
					// Only if we haven't already initiated for this chat and handler doesn't have messages yet
					if (hasInitiatedResponse.current !== activeChatId && handler.messages.length === 0) {
						hasInitiatedResponse.current = activeChatId;
						handler.append({
							role: "user",
							content: chatData.messages[0].content,
						});
					}
				} else if (chatData.messages.length > 1) {
					// Multiple messages - set them all
					handler.setMessages(chatData.messages);
				}
			}
		}
		// NOTE(review): deps intentionally omit `handler`/`activeChatState` to
		// avoid re-running on every stream chunk — confirm before "fixing".
	}, [token, isNewChat, activeChatId, isChatLoading]);

	// Restore chat state from localStorage on page load
	useEffect(() => {
		if (activeChatId && search_space_id) {
			const restoredState = restoreChatState(search_space_id as string, activeChatId);
			if (restoredState) {
				setSelectedDocuments(restoredState.selectedDocuments);
				setSelectedConnectors(restoredState.selectedConnectors);
				setTopK(restoredState.topK);
				// researchMode is always "QNA", no need to restore
			}
		}
	}, [
		activeChatId,
		isChatLoading,
		search_space_id,
		setSelectedDocuments,
		setSelectedConnectors,
		setTopK,
	]);

	// Set all sources as default for new chats (only once on initial mount)
	useEffect(() => {
		if (
			isNewChat &&
			!hasSetInitialConnectors.current &&
			selectedConnectors.length === 0 &&
			documentTypes.length > 0
		) {
			// Combine all document types and live search connectors
			const allSourceTypes = [
				...documentTypes.map((dt) => dt.type),
				...liveSearchConnectors.map((c) => c.connector_type),
			];
			if (allSourceTypes.length > 0) {
				setSelectedConnectors(allSourceTypes);
				hasSetInitialConnectors.current = true;
			}
		}
	}, [
		isNewChat,
		documentTypes,
		liveSearchConnectors,
		selectedConnectors.length,
		setSelectedConnectors,
	]);

	// Auto-update chat when messages change (only for existing chats)
	// Fires once per completed assistant turn: status "ready" + assistant last.
	useEffect(() => {
		if (
			!isNewChat &&
			activeChatId &&
			handler.status === "ready" &&
			handler.messages.length > 0 &&
			handler.messages[handler.messages.length - 1]?.role === "assistant"
		) {
			const userMessages = handler.messages.filter((msg) => msg.role === "user");
			if (userMessages.length === 0) return;

			const title = userMessages[0].content;

			updateChat({
				type: researchMode,
				title: title,
				initial_connectors: selectedConnectors,
				messages: handler.messages,
				search_space_id: Number(search_space_id),
				id: Number(activeChatId),
			});
		}
	}, [handler.messages, handler.status, activeChatId, isNewChat, isChatLoading]);

	if (isChatLoading) {
		return (
			<div className="flex items-center justify-center h-full">
				<div>Loading...</div>
			</div>
		);
	}

	return (
		<ChatInterface
			handler={{
				...handler,
				append: isNewChat ? customHandlerAppend : handler.append,
			}}
			onDocumentSelectionChange={setSelectedDocuments}
			selectedDocuments={selectedDocuments}
			onConnectorSelectionChange={setSelectedConnectors}
			selectedConnectors={selectedConnectors}
			topK={topK}
			onTopKChange={setTopK}
		/>
	);
}

View file

@ -30,20 +30,20 @@ interface SettingsNavItem {
const settingsNavItems: SettingsNavItem[] = [
{
id: "models",
label: "Model Configs",
description: "Configure AI models and providers",
label: "Agent Configs",
description: "LLM models with prompts & citations",
icon: Bot,
},
{
id: "roles",
label: "LLM Roles",
description: "Manage language model roles",
label: "Role Assignments",
description: "Assign configs to agent roles",
icon: Brain,
},
{
id: "prompts",
label: "System Instructions",
description: "Customize system prompts",
description: "SearchSpace-wide AI instructions",
icon: MessageSquare,
},
];
@ -236,9 +236,6 @@ function SettingsContent({
<h1 className="text-xl md:text-2xl font-bold tracking-tight truncate">
{activeItem?.label}
</h1>
<p className="text-sm text-muted-foreground mt-0.5 truncate">
{activeItem?.description}
</p>
</div>
</div>
</motion.div>
@ -275,7 +272,7 @@ export default function SettingsPage() {
const [isSidebarOpen, setIsSidebarOpen] = useState(false);
const handleBackToApp = useCallback(() => {
router.push(`/dashboard/${searchSpaceId}/researcher`);
router.push(`/dashboard/${searchSpaceId}/new-chat`);
}, [router, searchSpaceId]);
return (

View file

@ -807,7 +807,6 @@ function RolesTab({
<DropdownMenuItem
onClick={() => {
// TODO: Implement edit role dialog/modal
console.log("Edit role not yet implemented", role);
}}
>
<Edit2 className="h-4 w-4 mr-2" />

View file

@ -244,7 +244,7 @@ const DashboardPage = () => {
/>
<div className="flex flex-col h-full justify-between overflow-hidden rounded-xl border bg-muted/30 backdrop-blur-sm transition-all hover:border-primary/50">
<div className="relative h-32 w-full overflow-hidden">
<Link href={`/dashboard/${space.id}/researcher`} key={space.id}>
<Link href={`/dashboard/${space.id}/new-chat`} key={space.id}>
<Image
src="https://images.unsplash.com/photo-1519389950473-47ba0277781c?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1740&q=80"
alt={space.name}
@ -289,7 +289,7 @@ const DashboardPage = () => {
</div>
<Link
className="flex flex-1 flex-col p-4 cursor-pointer"
href={`/dashboard/${space.id}/researcher`}
href={`/dashboard/${space.id}/new-chat`}
key={space.id}
>
<div className="flex flex-1 flex-col justify-between p-1">

View file

@ -158,3 +158,4 @@ button {
}
@source '../node_modules/@llamaindex/chat-ui/**/*.{ts,tsx}';
@source '../node_modules/streamdown/dist/*.js';

View file

@ -1,7 +1,7 @@
import type { Metadata } from "next";
import "./globals.css";
import { GoogleAnalytics } from "@next/third-parties/google";
import { RootProvider } from "fumadocs-ui/provider";
import { RootProvider } from "fumadocs-ui/provider/next";
import { Roboto } from "next/font/google";
import { I18nProvider } from "@/components/providers/I18nProvider";
import { ThemeProvider } from "@/components/theme/theme-provider";

View file

@ -1,93 +0,0 @@
import { atomWithMutation } from "jotai-tanstack-query";
import { toast } from "sonner";
import type {
ChatSummary,
CreateChatRequest,
DeleteChatRequest,
UpdateChatRequest,
} from "@/contracts/types/chat.types";
import { chatsApiService } from "@/lib/apis/chats-api.service";
import { getBearerToken } from "@/lib/auth-utils";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { queryClient } from "@/lib/query-client/client";
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
import { globalChatsQueryParamsAtom } from "./ui.atoms";
/**
 * Mutation atom for deleting a chat.
 * On success: toasts, optimistically removes the chat from the currently
 * parameterized listing, then invalidates every chat-related query.
 */
export const deleteChatMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	const authToken = getBearerToken();
	const chatsQueryParams = get(globalChatsQueryParamsAtom);
	return {
		// NOTE(review): the mutation is keyed by the chats *query* cache key —
		// confirm this overlap is intentional.
		mutationKey: cacheKeys.chats.globalQueryParams(chatsQueryParams),
		// NOTE(review): `enabled` is a query-side option in TanStack Query;
		// mutations do not document it — confirm it has any effect here.
		enabled: !!searchSpaceId && !!authToken,
		mutationFn: async (request: DeleteChatRequest) => {
			return chatsApiService.deleteChat(request);
		},
		onSuccess: (_, request: DeleteChatRequest) => {
			toast.success("Chat deleted successfully");
			// Optimistically update the current query
			// (null-safe: the listing may not be cached yet).
			queryClient.setQueryData(
				cacheKeys.chats.globalQueryParams(chatsQueryParams),
				(oldData: ChatSummary[]) => {
					return oldData?.filter((chat) => chat.id !== request.id) ?? [];
				}
			);
			// Invalidate all chat queries to ensure consistency across components
			queryClient.invalidateQueries({
				queryKey: ["chats"],
			});
			// Also invalidate the "all-chats" query used by AllChatsSidebar
			queryClient.invalidateQueries({
				queryKey: ["all-chats"],
			});
		},
	};
});
/**
 * Mutation atom for creating a chat.
 * On success, broadly refreshes every chat listing so the sidebar and other
 * consumers pick up the new chat regardless of their query parameters.
 */
export const createChatMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	const authToken = getBearerToken();
	const chatsQueryParams = get(globalChatsQueryParamsAtom);
	// Partial-key invalidation avoids stale closures over specific params.
	const refreshChatLists = () => {
		queryClient.invalidateQueries({ queryKey: ["chats"] });
		// The AllChatsSidebar keeps a separate "all-chats" cache entry.
		queryClient.invalidateQueries({ queryKey: ["all-chats"] });
	};
	return {
		mutationKey: cacheKeys.chats.globalQueryParams(chatsQueryParams),
		enabled: !!searchSpaceId && !!authToken,
		mutationFn: async (request: CreateChatRequest) => chatsApiService.createChat(request),
		onSuccess: refreshChatLists,
	};
});
/**
 * Mutation atom for persisting chat edits (title, messages, connectors).
 * On success, invalidates the currently parameterized chats listing.
 */
export const updateChatMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	const authToken = getBearerToken();
	const chatsQueryParams = get(globalChatsQueryParamsAtom);
	const listKey = cacheKeys.chats.globalQueryParams(chatsQueryParams);
	return {
		mutationKey: listKey,
		enabled: !!searchSpaceId && !!authToken,
		mutationFn: async (request: UpdateChatRequest) => chatsApiService.updateChat(request),
		onSuccess: () => {
			queryClient.invalidateQueries({ queryKey: listKey });
		},
	};
});

View file

@ -1,48 +0,0 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { chatsApiService } from "@/lib/apis/chats-api.service";
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
import { getBearerToken } from "@/lib/auth-utils";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { activeChatIdAtom, globalChatsQueryParamsAtom } from "./ui.atoms";
/**
 * Query atom that loads the active chat's details together with its podcast.
 * Disabled until both an active chat id and a bearer token are available.
 */
export const activeChatAtom = atomWithQuery((get) => {
	const activeChatId = get(activeChatIdAtom);
	const authToken = getBearerToken();
	return {
		queryKey: cacheKeys.chats.activeChat(activeChatId ?? ""),
		enabled: !!activeChatId && !!authToken,
		queryFn: async () => {
			// Defensive re-checks: `enabled` should already gate these cases.
			if (!authToken) {
				throw new Error("No authentication token found");
			}
			if (!activeChatId) {
				throw new Error("No active chat id found");
			}
			const numericChatId = Number(activeChatId);
			// Fetch both halves concurrently; either rejection fails the query.
			const [podcast, chatDetails] = await Promise.all([
				podcastsApiService.getPodcastByChatId({ chat_id: numericChatId }),
				chatsApiService.getChatDetails({ id: numericChatId }),
			]);
			return { chatId: activeChatId, chatDetails, podcast };
		},
	};
});
/**
 * Query atom for the paginated chats listing of the active search space.
 * Keyed by the global query parameters so paging changes refetch.
 */
export const chatsAtom = atomWithQuery((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	const authToken = getBearerToken();
	const queryParams = get(globalChatsQueryParamsAtom);
	return {
		queryKey: cacheKeys.chats.globalQueryParams(queryParams),
		enabled: !!searchSpaceId && !!authToken,
		queryFn: async () => chatsApiService.getChats({ queryParams }),
	};
});

View file

@ -1,17 +0,0 @@
import { atom } from "jotai";
import type { GetChatsRequest } from "@/contracts/types/chat.types";
// UI-only state for the active chat panel.
// NOTE(review): "Chathat" and "Pannel" are typos, but these identifiers are
// part of the exported atom API — renaming would break importers, so they are
// preserved as-is.
type ActiveChathatUIState = {
	isChatPannelOpen: boolean;
};
export const activeChathatUIAtom = atom<ActiveChathatUIState>({
	isChatPannelOpen: false,
});
// Id of the chat currently open in the researcher view; null when none.
export const activeChatIdAtom = atom<string | null>(null);
// Default paging parameters for the global chats listing (first page of 5).
export const globalChatsQueryParamsAtom = atom<GetChatsRequest["queryParams"]>({
	limit: 5,
	skip: 0,
});

View file

@ -1,110 +0,0 @@
import { atomWithMutation } from "jotai-tanstack-query";
import { toast } from "sonner";
import type {
CreateLLMConfigRequest,
DeleteLLMConfigRequest,
GetLLMConfigsResponse,
UpdateLLMConfigRequest,
UpdateLLMConfigResponse,
UpdateLLMPreferencesRequest,
} from "@/contracts/types/llm-config.types";
import { llmConfigApiService } from "@/lib/apis/llm-config-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { queryClient } from "@/lib/query-client/client";
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
/**
 * Mutation atom for creating an LLM configuration.
 * On success, toasts and refreshes both the space-scoped and global config lists.
 */
export const createLLMConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	// Both listings can contain the new config, so refresh each of them.
	const invalidateConfigLists = () => {
		queryClient.invalidateQueries({
			queryKey: cacheKeys.llmConfigs.all(searchSpaceId!),
		});
		queryClient.invalidateQueries({
			queryKey: cacheKeys.llmConfigs.global(),
		});
	};
	return {
		mutationKey: cacheKeys.llmConfigs.all(searchSpaceId!),
		enabled: !!searchSpaceId,
		mutationFn: async (request: CreateLLMConfigRequest) =>
			llmConfigApiService.createLLMConfig(request),
		onSuccess: () => {
			toast.success("LLM configuration created successfully");
			invalidateConfigLists();
		},
	};
});
/**
 * Mutation atom for updating an LLM configuration.
 * On success, toasts and refreshes the space list, the per-config entry,
 * and the global list.
 */
export const updateLLMConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		mutationKey: cacheKeys.llmConfigs.all(searchSpaceId!),
		enabled: !!searchSpaceId,
		mutationFn: async (request: UpdateLLMConfigRequest) =>
			llmConfigApiService.updateLLMConfig(request),
		onSuccess: (_: UpdateLLMConfigResponse, request: UpdateLLMConfigRequest) => {
			toast.success("LLM configuration updated successfully");
			// The edited config can appear in three cached places; refresh all.
			const staleKeys = [
				cacheKeys.llmConfigs.all(searchSpaceId!),
				cacheKeys.llmConfigs.byId(String(request.id)),
				cacheKeys.llmConfigs.global(),
			];
			for (const queryKey of staleKeys) {
				queryClient.invalidateQueries({ queryKey });
			}
		},
	};
});
/**
 * Mutation atom for deleting an LLM configuration.
 * On success, toasts, optimistically drops the config from the space listing,
 * and invalidates the per-config and global caches.
 *
 * Fix: the bearer token was read with a bare `localStorage.getItem(...)` at
 * atom-factory time; under server-side rendering `localStorage` is undefined
 * and that throws a ReferenceError. The access is now guarded (and no sibling
 * atom in this file reads localStorage directly at all).
 */
export const deleteLLMConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	// Guarded: atom factories can run during SSR where localStorage is absent.
	const authToken =
		typeof localStorage !== "undefined"
			? localStorage.getItem("surfsense_bearer_token")
			: null;
	return {
		mutationKey: cacheKeys.llmConfigs.all(searchSpaceId!),
		enabled: !!searchSpaceId && !!authToken,
		mutationFn: async (request: DeleteLLMConfigRequest) => {
			return llmConfigApiService.deleteLLMConfig(request);
		},
		onSuccess: (_, request: DeleteLLMConfigRequest) => {
			toast.success("LLM configuration deleted successfully");
			// Optimistic removal from the cached space listing (null-safe).
			queryClient.setQueryData(
				cacheKeys.llmConfigs.all(searchSpaceId!),
				(oldData: GetLLMConfigsResponse | undefined) => {
					if (!oldData) return oldData;
					return oldData.filter((config) => config.id !== request.id);
				}
			);
			queryClient.invalidateQueries({
				queryKey: cacheKeys.llmConfigs.byId(String(request.id)),
			});
			queryClient.invalidateQueries({
				queryKey: cacheKeys.llmConfigs.global(),
			});
		},
	};
});
/**
 * Mutation atom for saving LLM role preferences.
 * On success, toasts and refreshes the cached preferences for the space.
 */
export const updateLLMPreferencesMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		mutationKey: cacheKeys.llmConfigs.preferences(searchSpaceId!),
		enabled: !!searchSpaceId,
		mutationFn: async (request: UpdateLLMPreferencesRequest) =>
			llmConfigApiService.updateLLMPreferences(request),
		onSuccess: () => {
			toast.success("LLM preferences updated successfully");
			queryClient.invalidateQueries({
				queryKey: cacheKeys.llmConfigs.preferences(searchSpaceId!),
			});
		},
	};
});

View file

@ -1,46 +0,0 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { llmConfigApiService } from "@/lib/apis/llm-config-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
/**
 * Query atom for all LLM configurations of the active search space.
 * Results are considered fresh for five minutes.
 */
export const llmConfigsAtom = atomWithQuery((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		queryKey: cacheKeys.llmConfigs.all(searchSpaceId!),
		enabled: !!searchSpaceId,
		staleTime: 5 * 60 * 1000, // 5 minutes
		queryFn: async () =>
			llmConfigApiService.getLLMConfigs({
				queryParams: { search_space_id: searchSpaceId! },
			}),
	};
});
/**
 * Query atom for globally-defined LLM configurations.
 * Cached for ten minutes since these rarely change.
 */
export const globalLLMConfigsAtom = atomWithQuery(() => ({
	queryKey: cacheKeys.llmConfigs.global(),
	staleTime: 10 * 60 * 1000, // 10 minutes
	queryFn: async () => llmConfigApiService.getGlobalLLMConfigs(),
}));
/**
 * Query atom for the LLM role preferences of the active search space.
 *
 * Fix: the query key previously used `String(searchSpaceId)` while
 * `updateLLMPreferencesMutationAtom` invalidates
 * `cacheKeys.llmConfigs.preferences(searchSpaceId!)` with the raw id, so a
 * numeric id produced mismatched keys ("5" vs 5) and the invalidation never
 * reached this query, leaving stale preferences on screen. Both sides now
 * use the raw id.
 */
export const llmPreferencesAtom = atomWithQuery((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		queryKey: cacheKeys.llmConfigs.preferences(searchSpaceId!),
		enabled: !!searchSpaceId,
		staleTime: 5 * 60 * 1000, // 5 minutes
		queryFn: async () => {
			return llmConfigApiService.getLLMPreferences({
				search_space_id: Number(searchSpaceId),
			});
		},
	};
});

View file

@ -0,0 +1,116 @@
import { atomWithMutation } from "jotai-tanstack-query";
import { toast } from "sonner";
import type {
CreateNewLLMConfigRequest,
DeleteNewLLMConfigRequest,
GetNewLLMConfigsResponse,
UpdateLLMPreferencesRequest,
UpdateNewLLMConfigRequest,
UpdateNewLLMConfigResponse,
} from "@/contracts/types/new-llm-config.types";
import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { queryClient } from "@/lib/query-client/client";
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
/**
 * Mutation atom for creating a new NewLLMConfig.
 * On success, toasts and invalidates the space's config list; failures
 * surface through an error toast.
 */
export const createNewLLMConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	const refreshConfigs = () =>
		queryClient.invalidateQueries({
			queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)),
		});
	return {
		mutationKey: ["new-llm-configs", "create"],
		enabled: !!searchSpaceId,
		mutationFn: async (request: CreateNewLLMConfigRequest) =>
			newLLMConfigApiService.createConfig(request),
		onSuccess: () => {
			toast.success("Configuration created successfully");
			refreshConfigs();
		},
		onError: (error: Error) => {
			toast.error(error.message || "Failed to create configuration");
		},
	};
});
/**
 * Mutation atom for updating an existing NewLLMConfig.
 * On success, toasts and refreshes both the space listing and the cached
 * per-config entry; failures surface through an error toast.
 */
export const updateNewLLMConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		mutationKey: ["new-llm-configs", "update"],
		enabled: !!searchSpaceId,
		mutationFn: async (request: UpdateNewLLMConfigRequest) =>
			newLLMConfigApiService.updateConfig(request),
		onSuccess: (_: UpdateNewLLMConfigResponse, request: UpdateNewLLMConfigRequest) => {
			toast.success("Configuration updated successfully");
			const staleKeys = [
				cacheKeys.newLLMConfigs.all(Number(searchSpaceId)),
				cacheKeys.newLLMConfigs.byId(request.id),
			];
			for (const queryKey of staleKeys) {
				queryClient.invalidateQueries({ queryKey });
			}
		},
		onError: (error: Error) => {
			toast.error(error.message || "Failed to update configuration");
		},
	};
});
/**
 * Mutation atom for deleting a NewLLMConfig.
 * On success, toasts and optimistically drops the config from the cached
 * space listing; failures surface through an error toast.
 */
export const deleteNewLLMConfigMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		mutationKey: ["new-llm-configs", "delete"],
		enabled: !!searchSpaceId,
		mutationFn: async (request: DeleteNewLLMConfigRequest) =>
			newLLMConfigApiService.deleteConfig(request),
		onSuccess: (_, request: DeleteNewLLMConfigRequest) => {
			toast.success("Configuration deleted successfully");
			// Optimistic removal; a missing cache entry is left untouched.
			queryClient.setQueryData(
				cacheKeys.newLLMConfigs.all(Number(searchSpaceId)),
				(oldData: GetNewLLMConfigsResponse | undefined) =>
					oldData ? oldData.filter((config) => config.id !== request.id) : oldData
			);
		},
		onError: (error: Error) => {
			toast.error(error.message || "Failed to delete configuration");
		},
	};
});
/**
 * Mutation atom for updating LLM preferences (role assignments).
 * On success, silently refreshes the cached preferences; failures surface
 * through an error toast.
 */
export const updateLLMPreferencesMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		mutationKey: ["llm-preferences", "update"],
		enabled: !!searchSpaceId,
		mutationFn: async (request: UpdateLLMPreferencesRequest) =>
			newLLMConfigApiService.updateLLMPreferences(request),
		onSuccess: () => {
			queryClient.invalidateQueries({
				queryKey: cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)),
			});
		},
		onError: (error: Error) => {
			toast.error(error.message || "Failed to update LLM preferences");
		},
	};
});

View file

@ -0,0 +1,64 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { newLLMConfigApiService } from "@/lib/apis/new-llm-config-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { activeSearchSpaceIdAtom } from "../search-spaces/search-space-query.atoms";
/**
 * Query atom for fetching all NewLLMConfigs for the active search space.
 * Results are considered fresh for five minutes.
 */
export const newLLMConfigsAtom = atomWithQuery((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		queryKey: cacheKeys.newLLMConfigs.all(Number(searchSpaceId)),
		enabled: !!searchSpaceId,
		staleTime: 5 * 60 * 1000, // 5 minutes
		queryFn: async () =>
			newLLMConfigApiService.getConfigs({ search_space_id: Number(searchSpaceId) }),
	};
});
/**
 * Query atom for fetching global NewLLMConfigs (from YAML, negative IDs).
 * Cached for ten minutes — global configs rarely change.
 */
export const globalNewLLMConfigsAtom = atomWithQuery(() => ({
	queryKey: cacheKeys.newLLMConfigs.global(),
	staleTime: 10 * 60 * 1000, // 10 minutes
	queryFn: async () => newLLMConfigApiService.getGlobalConfigs(),
}));
/**
 * Query atom for fetching LLM preferences (role assignments) for the active
 * search space. Results are considered fresh for five minutes.
 */
export const llmPreferencesAtom = atomWithQuery((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	return {
		queryKey: cacheKeys.newLLMConfigs.preferences(Number(searchSpaceId)),
		enabled: !!searchSpaceId,
		staleTime: 5 * 60 * 1000, // 5 minutes
		queryFn: async () => newLLMConfigApiService.getLLMPreferences(Number(searchSpaceId)),
	};
});
/**
 * Query atom for the default system-instructions template.
 * Cached for a full hour — this template rarely changes.
 */
export const defaultSystemInstructionsAtom = atomWithQuery(() => ({
	queryKey: cacheKeys.newLLMConfigs.defaultInstructions(),
	staleTime: 60 * 60 * 1000, // 1 hour
	queryFn: async () => newLLMConfigApiService.getDefaultSystemInstructions(),
}));

View file

@ -1,51 +0,0 @@
import { atomWithMutation } from "jotai-tanstack-query";
import { toast } from "sonner";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import type {
DeletePodcastRequest,
GeneratePodcastRequest,
Podcast,
} from "@/contracts/types/podcast.types";
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
import { getBearerToken } from "@/lib/auth-utils";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { queryClient } from "@/lib/query-client/client";
import { globalPodcastsQueryParamsAtom } from "./ui.atoms";
/**
 * Mutation atom for deleting a podcast.
 * On success, toasts and optimistically removes the podcast from the
 * currently parameterized listing.
 *
 * Fix: the cache updater previously called `oldData.filter(...)` without a
 * null guard and threw when the podcasts query had no cached data yet; it now
 * mirrors the null-safe pattern used by `deleteChatMutationAtom`.
 */
export const deletePodcastMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	const authToken = getBearerToken();
	const podcastsQueryParams = get(globalPodcastsQueryParamsAtom);
	return {
		mutationKey: cacheKeys.podcasts.globalQueryParams(podcastsQueryParams),
		enabled: !!searchSpaceId && !!authToken,
		mutationFn: async (request: DeletePodcastRequest) => {
			return podcastsApiService.deletePodcast(request);
		},
		onSuccess: (_, request: DeletePodcastRequest) => {
			toast.success("Podcast deleted successfully");
			queryClient.setQueryData(
				cacheKeys.podcasts.globalQueryParams(podcastsQueryParams),
				(oldData: Podcast[] | undefined) => {
					// Cache may be empty if the listing was never fetched.
					return oldData?.filter((podcast) => podcast.id !== request.id) ?? [];
				}
			);
		},
	};
});
/**
 * Mutation atom that kicks off podcast generation for a chat.
 * Fire-and-forget: no cache updates are performed on completion.
 */
export const generatePodcastMutationAtom = atomWithMutation((get) => {
	const searchSpaceId = get(activeSearchSpaceIdAtom);
	const authToken = getBearerToken();
	const podcastsQueryParams = get(globalPodcastsQueryParamsAtom);
	return {
		mutationKey: cacheKeys.podcasts.globalQueryParams(podcastsQueryParams),
		enabled: !!searchSpaceId && !!authToken,
		mutationFn: async (request: GeneratePodcastRequest) =>
			podcastsApiService.generatePodcast(request),
	};
});

View file

@ -1,17 +0,0 @@
import { atomWithQuery } from "jotai-tanstack-query";
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
import { cacheKeys } from "@/lib/query-client/cache-keys";
import { globalPodcastsQueryParamsAtom } from "./ui.atoms";
/**
 * Query atom for the paginated podcasts listing,
 * keyed by the global podcast query parameters.
 */
export const podcastsAtom = atomWithQuery((get) => {
	const queryParams = get(globalPodcastsQueryParamsAtom);
	return {
		queryKey: cacheKeys.podcasts.globalQueryParams(queryParams),
		queryFn: async () => podcastsApiService.getPodcasts({ queryParams }),
	};
});

View file

@ -1,7 +0,0 @@
import { atom } from "jotai";
import type { GetPodcastsRequest } from "@/contracts/types/podcast.types";
// Default paging parameters for the podcasts listing (first page of 5).
export const globalPodcastsQueryParamsAtom = atom<GetPodcastsRequest["queryParams"]>({
	limit: 5,
	skip: 0,
});

View file

@ -25,13 +25,3 @@ export const searchSpacesAtom = atomWithQuery((get) => {
},
};
});
/**
 * Query atom for community-shared prompts.
 * Cached for thirty minutes.
 */
export const communityPromptsAtom = atomWithQuery(() => ({
	queryKey: cacheKeys.searchSpaces.communityPrompts,
	staleTime: 30 * 60 * 1000,
	queryFn: async () => searchSpacesApiService.getCommunityPrompts(),
}));

View file

@ -0,0 +1,252 @@
"use client";
import {
AttachmentPrimitive,
ComposerPrimitive,
MessagePrimitive,
useAssistantApi,
useAssistantState,
} from "@assistant-ui/react";
import { FileText, Loader2, PlusIcon, XIcon } from "lucide-react";
import Image from "next/image";
import { type FC, type PropsWithChildren, useEffect, useState } from "react";
import { useShallow } from "zustand/shallow";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { Avatar, AvatarFallback, AvatarImage } from "@/components/ui/avatar";
import { Dialog, DialogContent, DialogTitle, DialogTrigger } from "@/components/ui/dialog";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { cn } from "@/lib/utils";
/**
 * Hook that exposes a blob: URL for the given File.
 * The URL is revoked when the file changes or the component unmounts;
 * returns undefined while no file is present.
 */
const useFileSrc = (file: File | undefined) => {
	const [objectUrl, setObjectUrl] = useState<string | undefined>(undefined);
	useEffect(() => {
		if (!file) {
			setObjectUrl(undefined);
			return;
		}
		const url = URL.createObjectURL(file);
		setObjectUrl(url);
		// Cleanup prevents leaking blob URLs across file changes/unmount.
		return () => URL.revokeObjectURL(url);
	}, [file]);
	return objectUrl;
};
/**
 * Resolves a displayable image URL for the current attachment context.
 * Prefers an object URL built from the raw File, falling back to an inline
 * image part already present on the attachment content.
 * Returns undefined for non-image attachments.
 */
const useAttachmentSrc = () => {
	const { file, src } = useAssistantState(
		useShallow(({ attachment }): { file?: File; src?: string } => {
			if (!attachment || attachment.type !== "image") return {};
			if (attachment.file) return { file: attachment.file };
			// Only try to filter if content is an array (standard assistant-ui format)
			// Our custom ChatAttachment has content as a string, so skip this
			if (Array.isArray(attachment.content)) {
				const src = attachment.content.filter((c) => c.type === "image")[0]?.image;
				if (src) return { src };
			}
			return {};
		})
	);
	// useFileSrc runs unconditionally (stable hook order); it yields undefined
	// when no File exists, letting the inline `src` act as the fallback.
	return useFileSrc(file) ?? src;
};
type AttachmentPreviewProps = {
	// Image URL (blob: or inline) to render at full size.
	src: string;
};

/**
 * Full-size image preview shown inside the attachment dialog.
 * The image stays hidden until it has loaded to avoid flicker; width/height
 * are 1 because next/image requires them while the className controls the
 * rendered size.
 */
const AttachmentPreview: FC<AttachmentPreviewProps> = ({ src }) => {
	const [isLoaded, setIsLoaded] = useState(false);
	return (
		<Image
			src={src}
			alt="Image Preview"
			width={1}
			height={1}
			className={
				isLoaded
					? "aui-attachment-preview-image-loaded block h-auto max-h-[80vh] w-auto max-w-full object-contain"
					: "aui-attachment-preview-image-loading hidden"
			}
			// NOTE(review): onLoadingComplete is deprecated in newer next/image
			// versions in favor of onLoad — confirm against the project's Next.js.
			onLoadingComplete={() => setIsLoaded(true)}
			priority={false}
		/>
	);
};
/**
 * Wraps its children in a dialog trigger that opens a full-size image preview.
 * When no preview source is available (e.g. non-image attachments), the
 * children are rendered untouched with no dialog.
 */
const AttachmentPreviewDialog: FC<PropsWithChildren> = ({ children }) => {
	const src = useAttachmentSrc();
	if (!src) return children;
	return (
		<Dialog>
			<DialogTrigger
				className="aui-attachment-preview-trigger cursor-pointer transition-colors hover:bg-accent/50"
				asChild
			>
				{children}
			</DialogTrigger>
			<DialogContent className="aui-attachment-preview-dialog-content p-2 sm:max-w-3xl [&>button]:rounded-full [&>button]:bg-foreground/60 [&>button]:p-1 [&>button]:opacity-100 [&>button]:ring-0! [&_svg]:text-background [&>button]:hover:[&_svg]:text-destructive">
				{/* Screen-reader-only title keeps the dialog accessible. */}
				<DialogTitle className="aui-sr-only sr-only">Image Attachment Preview</DialogTitle>
				<div className="aui-attachment-preview relative mx-auto flex max-h-[80dvh] w-full items-center justify-center overflow-hidden bg-background">
					<AttachmentPreview src={src} />
				</div>
			</DialogContent>
		</Dialog>
	);
};
/**
 * Thumbnail for an attachment tile: a spinner while the attachment is being
 * processed, otherwise the image preview (with a document-icon fallback for
 * non-image attachments).
 */
const AttachmentThumb: FC = () => {
	const isImage = useAssistantState(({ attachment }) => attachment?.type === "image");
	// Check if actively processing (running AND progress < 100)
	// When progress is 100, processing is done but waiting for send()
	const isProcessing = useAssistantState(({ attachment }) => {
		const status = attachment?.status;
		if (status?.type !== "running") return false;
		// If progress is defined and equals 100, processing is complete
		const progress = (status as { type: "running"; progress?: number }).progress;
		return progress === undefined || progress < 100;
	});
	const src = useAttachmentSrc();
	// Show loading spinner only when actively processing (not when done and waiting for send)
	if (isProcessing) {
		return (
			<div className="flex h-full w-full items-center justify-center bg-muted">
				<Loader2 className="size-6 animate-spin text-muted-foreground" />
			</div>
		);
	}
	return (
		<Avatar className="aui-attachment-tile-avatar h-full w-full rounded-none">
			<AvatarImage
				src={src}
				alt="Attachment preview"
				className="aui-attachment-tile-image object-cover"
			/>
			{/* Delay the fallback for images so the real preview has a chance to load. */}
			<AvatarFallback delayMs={isImage ? 200 : 0}>
				<FileText className="aui-attachment-tile-fallback-icon size-8 text-muted-foreground" />
			</AvatarFallback>
		</Avatar>
	);
};
/**
 * Single attachment tile with tooltip and (image) preview dialog.
 * In the composer it also shows a remove button once processing finishes;
 * the tooltip shows a spinner while processing, else the attachment name.
 */
const AttachmentUI: FC = () => {
	const api = useAssistantApi();
	// Composer attachments are removable; message attachments are read-only.
	const isComposer = api.attachment.source === "composer";
	const isImage = useAssistantState(({ attachment }) => attachment?.type === "image");
	// Check if actively processing (running AND progress < 100)
	// When progress is 100, processing is done but waiting for send()
	const isProcessing = useAssistantState(({ attachment }) => {
		const status = attachment?.status;
		if (status?.type !== "running") return false;
		const progress = (status as { type: "running"; progress?: number }).progress;
		return progress === undefined || progress < 100;
	});
	// Human-readable label used for the accessible aria-label.
	const typeLabel = useAssistantState(({ attachment }) => {
		const type = attachment?.type;
		switch (type) {
			case "image":
				return "Image";
			case "document":
				return "Document";
			case "file":
				return "File";
			default:
				return "File"; // Default fallback for unknown types
		}
	});
	return (
		<Tooltip>
			<AttachmentPrimitive.Root
				className={cn(
					"aui-attachment-root relative",
					isImage && "aui-attachment-root-composer only:[&>#attachment-tile]:size-24"
				)}
			>
				<AttachmentPreviewDialog>
					<TooltipTrigger asChild>
						<div
							className={cn(
								"aui-attachment-tile size-14 cursor-pointer overflow-hidden rounded-[14px] border bg-muted transition-opacity hover:opacity-75",
								isComposer && "aui-attachment-tile-composer border-foreground/20",
								isProcessing && "animate-pulse"
							)}
							role="button"
							id="attachment-tile"
							aria-label={isProcessing ? "Processing attachment..." : `${typeLabel} attachment`}
						>
							<AttachmentThumb />
						</div>
					</TooltipTrigger>
				</AttachmentPreviewDialog>
				{/* Remove button only in the composer and only after processing. */}
				{isComposer && !isProcessing && <AttachmentRemove />}
			</AttachmentPrimitive.Root>
			<TooltipContent side="top">
				{isProcessing ? (
					<span className="flex items-center gap-1.5">
						<Loader2 className="size-3 animate-spin" />
						Processing...
					</span>
				) : (
					<AttachmentPrimitive.Name />
				)}
			</TooltipContent>
		</Tooltip>
	);
};
/**
 * Small "×" button overlaid on a composer attachment tile that removes the
 * attachment (via AttachmentPrimitive.Remove).
 */
const AttachmentRemove: FC = () => {
	return (
		<AttachmentPrimitive.Remove asChild>
			<TooltipIconButton
				tooltip="Remove file"
				className="aui-attachment-tile-remove absolute top-1.5 right-1.5 size-3.5 rounded-full bg-white text-muted-foreground opacity-100 shadow-sm hover:bg-white! [&_svg]:text-black hover:[&_svg]:text-destructive"
				side="top"
			>
				<XIcon className="aui-attachment-remove-icon size-3 dark:stroke-[2.5px]" />
			</TooltipIconButton>
		</AttachmentPrimitive.Remove>
	);
};
/** Right-aligned row of attachment tiles rendered inside a user message. */
export const UserMessageAttachments: FC = () => {
	return (
		<div className="aui-user-message-attachments-end col-span-full col-start-1 row-start-1 flex w-full flex-row justify-end gap-2">
			<MessagePrimitive.Attachments components={{ Attachment: AttachmentUI }} />
		</div>
	);
};
/** Horizontally scrolling strip of pending attachments above the composer; hidden when empty. */
export const ComposerAttachments: FC = () => {
	return (
		<div className="aui-composer-attachments mb-2 flex w-full flex-row items-center gap-2 overflow-x-auto px-1.5 pt-0.5 pb-1 empty:hidden">
			<ComposerPrimitive.Attachments components={{ Attachment: AttachmentUI }} />
		</div>
	);
};
/**
 * "+" button in the composer that opens the attachment picker
 * (via ComposerPrimitive.AddAttachment).
 */
export const ComposerAddAttachment: FC = () => {
	return (
		<ComposerPrimitive.AddAttachment asChild>
			<TooltipIconButton
				tooltip="Add Attachment"
				side="bottom"
				variant="ghost"
				size="icon"
				className="aui-composer-add-attachment size-[34px] rounded-full p-1 font-semibold text-xs hover:bg-muted-foreground/15 dark:border-muted-foreground/15 dark:hover:bg-muted-foreground/30"
				aria-label="Add Attachment"
			>
				<PlusIcon className="aui-attachment-add-icon size-5 stroke-[1.5px]" />
			</TooltipIconButton>
		</ComposerPrimitive.AddAttachment>
	);
};

View file

@ -0,0 +1,41 @@
"use client";
import type { FC } from "react";
import { useState } from "react";
import { SourceDetailPanel } from "@/components/new-chat/source-detail-panel";
interface InlineCitationProps {
	// Database id of the document chunk this citation points at.
	chunkId: number;
	// Sequential number displayed in the badge (per-message ordering).
	citationNumber: number;
}

/**
 * Inline citation component for the new chat.
 * Renders a clickable numbered badge that opens the SourceDetailPanel with document chunk details.
 * Keyboard-accessible: Enter opens the panel; role="button"/tabIndex make the
 * span focusable.
 */
export const InlineCitation: FC<InlineCitationProps> = ({ chunkId, citationNumber }) => {
	const [isOpen, setIsOpen] = useState(false);
	return (
		<SourceDetailPanel
			open={isOpen}
			onOpenChange={setIsOpen}
			chunkId={chunkId}
			sourceType=""
			title="Source"
			description=""
			url=""
		>
			<span
				onClick={() => setIsOpen(true)}
				onKeyDown={(e) => e.key === "Enter" && setIsOpen(true)}
				className="text-[10px] font-bold bg-primary/80 hover:bg-primary text-primary-foreground rounded-full min-w-4 h-4 px-1 inline-flex items-center justify-center align-super cursor-pointer transition-colors ml-0.5"
				title={`View source #${citationNumber}`}
				role="button"
				tabIndex={0}
			>
				{citationNumber}
			</span>
		</SourceDetailPanel>
	);
};

View file

@ -0,0 +1,325 @@
"use client";
import "@assistant-ui/react-markdown/styles/dot.css";
import {
type CodeHeaderProps,
MarkdownTextPrimitive,
unstable_memoizeMarkdownComponents as memoizeMarkdownComponents,
useIsMarkdownCodeBlock,
} from "@assistant-ui/react-markdown";
import { CheckIcon, CopyIcon } from "lucide-react";
import { type FC, memo, type ReactNode, useState } from "react";
import remarkGfm from "remark-gfm";
import { InlineCitation } from "@/components/assistant-ui/inline-citation";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import { cn } from "@/lib/utils";
// Matches inline citation markers of the form [citation:CHUNK_ID].
const CITATION_REGEX = /\[citation:(\d+)\]/g;

// Stable chunk-id -> citation-number assignment for the message currently
// being rendered. Reset whenever a new message starts rendering so numbering
// restarts at 1.
let chunkIdToCitationNumber: Map<number, number> = new Map();
let nextCitationNumber = 1;

/**
 * Resets the citation counter - should be called at the start of each message
 */
export function resetCitationCounter() {
	chunkIdToCitationNumber = new Map();
	nextCitationNumber = 1;
}

/**
 * Gets or assigns a citation number for a chunk ID.
 * First sighting of a chunk receives the next sequential number; repeated
 * references reuse the number already assigned.
 */
function getCitationNumber(chunkId: number): number {
	const assigned = chunkIdToCitationNumber.get(chunkId);
	if (assigned !== undefined) return assigned;
	const fresh = nextCitationNumber;
	nextCitationNumber += 1;
	chunkIdToCitationNumber.set(chunkId, fresh);
	return fresh;
}
/**
 * Parses text and replaces [citation:XXX] patterns with InlineCitation components.
 * Returns the original text as a single-element array when no citations are found.
 */
function parseTextWithCitations(text: string): ReactNode[] {
	const segments: ReactNode[] = [];
	let cursor = 0;
	let occurrence = 0;

	// Reset shared regex state before iterating (the regex is module-global).
	CITATION_REGEX.lastIndex = 0;
	for (const match of text.matchAll(CITATION_REGEX)) {
		const start = match.index ?? 0;

		// Plain text preceding this citation marker.
		if (start > cursor) {
			segments.push(text.substring(cursor, start));
		}

		// Replace the marker with a citation badge component.
		const chunkId = Number.parseInt(match[1], 10);
		segments.push(
			<InlineCitation
				key={`citation-${chunkId}-${occurrence}`}
				chunkId={chunkId}
				citationNumber={getCitationNumber(chunkId)}
			/>
		);

		cursor = start + match[0].length;
		occurrence++;
	}

	// Trailing text after the final citation.
	if (cursor < text.length) {
		segments.push(text.substring(cursor));
	}

	return segments.length > 0 ? segments : [text];
}
const MarkdownTextImpl = () => {
	// Reset citation counter at the start of each render
	// This ensures consistent numbering as the message streams in
	// NOTE(review): this mutates module-level state during render and relies on
	// messages rendering in document order in a single pass — confirm behavior
	// under React StrictMode / concurrent rendering.
	resetCitationCounter();

	return (
		<MarkdownTextPrimitive
			remarkPlugins={[remarkGfm]}
			className="aui-md"
			components={defaultComponents}
		/>
	);
};

// Memoized: the component takes no props, so memo() suppresses parent-driven re-renders.
export const MarkdownText = memo(MarkdownTextImpl);
const CodeHeader: FC<CodeHeaderProps> = ({ language, code }) => {
const { isCopied, copyToClipboard } = useCopyToClipboard();
const onCopy = () => {
if (!code || isCopied) return;
copyToClipboard(code);
};
return (
<div className="aui-code-header-root mt-4 flex items-center justify-between gap-4 rounded-t-lg bg-muted-foreground/15 px-4 py-2 font-semibold text-foreground text-sm dark:bg-muted-foreground/20">
<span className="aui-code-header-language lowercase [&>span]:text-xs">{language}</span>
<TooltipIconButton tooltip="Copy" onClick={onCopy}>
{!isCopied && <CopyIcon />}
{isCopied && <CheckIcon />}
</TooltipIconButton>
</div>
);
};
/**
 * Hook wrapping navigator.clipboard with a transient "copied" flag.
 *
 * @param copiedDuration how long (ms) `isCopied` stays true after a successful copy
 * @returns `isCopied` plus a `copyToClipboard(value)` callback (no-op on empty value)
 */
const useCopyToClipboard = ({ copiedDuration = 3000 }: { copiedDuration?: number } = {}) => {
	const [isCopied, setIsCopied] = useState<boolean>(false);

	const copyToClipboard = (value: string) => {
		if (!value) return;
		navigator.clipboard
			.writeText(value)
			.then(() => {
				setIsCopied(true);
				setTimeout(() => setIsCopied(false), copiedDuration);
			})
			.catch(() => {
				// Clipboard access can be rejected (permissions, insecure context);
				// swallow the rejection so it doesn't surface as an unhandled error.
			});
	};

	return { isCopied, copyToClipboard };
};
/**
* Helper to process children and replace citation patterns with components
*/
function processChildrenWithCitations(children: ReactNode): ReactNode {
if (typeof children === "string") {
const parsed = parseTextWithCitations(children);
return parsed.length === 1 && typeof parsed[0] === "string" ? children : <>{parsed}</>;
}
if (Array.isArray(children)) {
return children.map((child, index) => {
if (typeof child === "string") {
const parsed = parseTextWithCitations(child);
return parsed.length === 1 && typeof parsed[0] === "string" ? (
child
) : (
<span key={index}>{parsed}</span>
);
}
return child;
});
}
return children;
}
/**
 * Markdown element renderers for assistant messages.
 *
 * Text-bearing elements (headings, p, a, li, td, strong, em, ...) route their
 * children through processChildrenWithCitations so [citation:ID] markers become
 * inline citation badges; purely structural elements (ul, ol, hr, table, tr,
 * pre, sup) pass props straight through. Memoized via assistant-ui's helper so
 * the renderer map is not re-created on every render.
 */
const defaultComponents = memoizeMarkdownComponents({
	h1: ({ className, children, ...props }) => (
		<h1
			className={cn(
				"aui-md-h1 mb-8 scroll-m-20 font-extrabold text-4xl tracking-tight last:mb-0",
				className
			)}
			{...props}
		>
			{processChildrenWithCitations(children)}
		</h1>
	),
	h2: ({ className, children, ...props }) => (
		<h2
			className={cn(
				"aui-md-h2 mt-8 mb-4 scroll-m-20 font-semibold text-3xl tracking-tight first:mt-0 last:mb-0",
				className
			)}
			{...props}
		>
			{processChildrenWithCitations(children)}
		</h2>
	),
	h3: ({ className, children, ...props }) => (
		<h3
			className={cn(
				"aui-md-h3 mt-6 mb-4 scroll-m-20 font-semibold text-2xl tracking-tight first:mt-0 last:mb-0",
				className
			)}
			{...props}
		>
			{processChildrenWithCitations(children)}
		</h3>
	),
	h4: ({ className, children, ...props }) => (
		<h4
			className={cn(
				"aui-md-h4 mt-6 mb-4 scroll-m-20 font-semibold text-xl tracking-tight first:mt-0 last:mb-0",
				className
			)}
			{...props}
		>
			{processChildrenWithCitations(children)}
		</h4>
	),
	h5: ({ className, children, ...props }) => (
		<h5
			className={cn("aui-md-h5 my-4 font-semibold text-lg first:mt-0 last:mb-0", className)}
			{...props}
		>
			{processChildrenWithCitations(children)}
		</h5>
	),
	h6: ({ className, children, ...props }) => (
		<h6 className={cn("aui-md-h6 my-4 font-semibold first:mt-0 last:mb-0", className)} {...props}>
			{processChildrenWithCitations(children)}
		</h6>
	),
	p: ({ className, children, ...props }) => (
		<p className={cn("aui-md-p mt-5 mb-5 leading-7 first:mt-0 last:mb-0", className)} {...props}>
			{processChildrenWithCitations(children)}
		</p>
	),
	a: ({ className, children, ...props }) => (
		<a
			className={cn("aui-md-a font-medium text-primary underline underline-offset-4", className)}
			{...props}
		>
			{processChildrenWithCitations(children)}
		</a>
	),
	blockquote: ({ className, children, ...props }) => (
		<blockquote className={cn("aui-md-blockquote border-l-2 pl-6 italic", className)} {...props}>
			{processChildrenWithCitations(children)}
		</blockquote>
	),
	// Structural containers below don't carry raw text directly, so no citation processing.
	ul: ({ className, ...props }) => (
		<ul className={cn("aui-md-ul my-5 ml-6 list-disc [&>li]:mt-2", className)} {...props} />
	),
	ol: ({ className, ...props }) => (
		<ol className={cn("aui-md-ol my-5 ml-6 list-decimal [&>li]:mt-2", className)} {...props} />
	),
	li: ({ className, children, ...props }) => (
		<li className={cn("aui-md-li", className)} {...props}>
			{processChildrenWithCitations(children)}
		</li>
	),
	hr: ({ className, ...props }) => (
		<hr className={cn("aui-md-hr my-5 border-b", className)} {...props} />
	),
	table: ({ className, ...props }) => (
		<table
			className={cn(
				"aui-md-table my-5 w-full border-separate border-spacing-0 overflow-y-auto",
				className
			)}
			{...props}
		/>
	),
	th: ({ className, children, ...props }) => (
		<th
			className={cn(
				"aui-md-th bg-muted px-4 py-2 text-left font-bold first:rounded-tl-lg last:rounded-tr-lg [[align=center]]:text-center [[align=right]]:text-right",
				className
			)}
			{...props}
		>
			{processChildrenWithCitations(children)}
		</th>
	),
	td: ({ className, children, ...props }) => (
		<td
			className={cn(
				"aui-md-td border-b border-l px-4 py-2 text-left last:border-r [[align=center]]:text-center [[align=right]]:text-right",
				className
			)}
			{...props}
		>
			{processChildrenWithCitations(children)}
		</td>
	),
	tr: ({ className, ...props }) => (
		<tr
			className={cn(
				"aui-md-tr m-0 border-b p-0 first:border-t [&:last-child>td:first-child]:rounded-bl-lg [&:last-child>td:last-child]:rounded-br-lg",
				className
			)}
			{...props}
		/>
	),
	sup: ({ className, ...props }) => (
		<sup className={cn("aui-md-sup [&>a]:text-xs [&>a]:no-underline", className)} {...props} />
	),
	pre: ({ className, ...props }) => (
		<pre
			className={cn(
				"aui-md-pre overflow-x-auto rounded-t-none! rounded-b-lg bg-black p-4 text-white",
				className
			)}
			{...props}
		/>
	),
	// Inline code gets a bordered style; fenced-block code is left to <pre>/CodeHeader.
	code: function Code({ className, ...props }) {
		const isCodeBlock = useIsMarkdownCodeBlock();
		return (
			<code
				className={cn(
					!isCodeBlock && "aui-md-inline-code rounded border bg-muted font-semibold",
					className
				)}
				{...props}
			/>
		);
	},
	strong: ({ className, children, ...props }) => (
		<strong className={cn("aui-md-strong font-semibold", className)} {...props}>
			{processChildrenWithCitations(children)}
		</strong>
	),
	em: ({ className, children, ...props }) => (
		<em className={cn("aui-md-em", className)} {...props}>
			{processChildrenWithCitations(children)}
		</em>
	),
	CodeHeader,
});

View file

@ -0,0 +1,299 @@
"use client";
import {
ArchiveIcon,
MessageSquareIcon,
MoreVerticalIcon,
PlusIcon,
RotateCcwIcon,
TrashIcon,
} from "lucide-react";
import { useRouter } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import { Button } from "@/components/ui/button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import {
createThreadListManager,
type ThreadListItem,
type ThreadListState,
} from "@/lib/chat/thread-persistence";
import { cn } from "@/lib/utils";
interface ThreadListProps {
	// Search space whose threads are listed; also used to build navigation URLs.
	searchSpaceId: number;
	// Currently open thread — used for row highlighting and post-delete redirect.
	currentThreadId?: number;
	className?: string;
}

/**
 * Sidebar list of chat threads for a search space, with Active/Archived tabs,
 * a "new chat" button, and per-thread archive/unarchive/delete actions.
 */
export function ThreadList({ searchSpaceId, currentThreadId, className }: ThreadListProps) {
	const router = useRouter();
	const [state, setState] = useState<ThreadListState>({
		threads: [],
		archivedThreads: [],
		isLoading: true,
		error: null,
	});
	const [showArchived, setShowArchived] = useState(false);

	// Create the thread list manager
	// NOTE(review): `manager` is a factory — each `manager()` call constructs a
	// fresh manager via createThreadListManager; confirm that is intended rather
	// than memoizing a single instance.
	const manager = useCallback(
		() =>
			createThreadListManager({
				searchSpaceId,
				currentThreadId: currentThreadId ?? null,
				onThreadSwitch: (threadId) => {
					router.push(`/dashboard/${searchSpaceId}/new-chat/${threadId}`);
				},
				onNewThread: (threadId) => {
					router.push(`/dashboard/${searchSpaceId}/new-chat/${threadId}`);
				},
			}),
		[searchSpaceId, currentThreadId, router]
	);

	// Load threads on mount and when searchSpaceId changes
	const loadThreads = useCallback(async () => {
		setState((prev) => ({ ...prev, isLoading: true }));
		const newState = await manager().loadThreads();
		setState(newState);
	}, [manager]);

	useEffect(() => {
		loadThreads();
	}, [loadThreads]);

	// Handle new thread creation
	const handleNewThread = async () => {
		await manager().createNewThread();
		await loadThreads();
	};

	// Handle thread actions — each reloads the list only when the action succeeded.
	const handleArchive = async (threadId: number) => {
		const success = await manager().archiveThread(threadId);
		if (success) await loadThreads();
	};

	const handleUnarchive = async (threadId: number) => {
		const success = await manager().unarchiveThread(threadId);
		if (success) await loadThreads();
	};

	const handleDelete = async (threadId: number) => {
		const success = await manager().deleteThread(threadId);
		if (success) {
			await loadThreads();
			// If we deleted the current thread, redirect to new chat
			if (threadId === currentThreadId) {
				router.push(`/dashboard/${searchSpaceId}/new-chat`);
			}
		}
	};

	const handleSwitchToThread = (threadId: number) => {
		manager().switchToThread(threadId);
	};

	// Which tab's rows are rendered below.
	const displayedThreads = showArchived ? state.archivedThreads : state.threads;

	if (state.isLoading) {
		return (
			<div className={cn("flex h-full flex-col", className)}>
				<div className="flex items-center justify-center p-4">
					<span className="text-muted-foreground text-sm">Loading threads...</span>
				</div>
			</div>
		);
	}

	if (state.error) {
		return (
			<div className={cn("flex h-full flex-col", className)}>
				<div className="p-4 text-center">
					<span className="text-destructive text-sm">{state.error}</span>
					<Button variant="ghost" size="sm" className="mt-2" onClick={loadThreads}>
						Retry
					</Button>
				</div>
			</div>
		);
	}

	return (
		<div className={cn("flex h-full flex-col", className)}>
			{/* Header with New Chat button */}
			<div className="flex items-center justify-between border-b p-3">
				<h2 className="font-semibold text-sm">Conversations</h2>
				<Button
					variant="ghost"
					size="icon"
					className="size-8"
					onClick={handleNewThread}
					title="New Chat"
				>
					<PlusIcon className="size-4" />
				</Button>
			</div>

			{/* Tab toggle for active/archived */}
			<div className="flex border-b">
				<button
					type="button"
					onClick={() => setShowArchived(false)}
					className={cn(
						"flex-1 px-3 py-2 text-center text-xs font-medium transition-colors",
						!showArchived
							? "border-b-2 border-primary text-primary"
							: "text-muted-foreground hover:text-foreground"
					)}
				>
					Active ({state.threads.length})
				</button>
				<button
					type="button"
					onClick={() => setShowArchived(true)}
					className={cn(
						"flex-1 px-3 py-2 text-center text-xs font-medium transition-colors",
						showArchived
							? "border-b-2 border-primary text-primary"
							: "text-muted-foreground hover:text-foreground"
					)}
				>
					Archived ({state.archivedThreads.length})
				</button>
			</div>

			{/* Thread list */}
			<div className="flex-1 overflow-y-auto">
				{displayedThreads.length === 0 ? (
					<div className="flex flex-col items-center justify-center p-6 text-center">
						<MessageSquareIcon className="mb-2 size-8 text-muted-foreground/50" />
						<p className="text-muted-foreground text-sm">
							{showArchived ? "No archived conversations" : "No conversations yet"}
						</p>
						{!showArchived && (
							<Button variant="outline" size="sm" className="mt-3" onClick={handleNewThread}>
								<PlusIcon className="mr-1 size-3" />
								Start a conversation
							</Button>
						)}
					</div>
				) : (
					<div className="space-y-1 p-2">
						{displayedThreads.map((thread) => (
							<ThreadListItemComponent
								key={thread.id}
								thread={thread}
								isActive={thread.id === currentThreadId}
								isArchived={showArchived}
								onClick={() => handleSwitchToThread(thread.id)}
								onArchive={() => handleArchive(thread.id)}
								onUnarchive={() => handleUnarchive(thread.id)}
								onDelete={() => handleDelete(thread.id)}
							/>
						))}
					</div>
				)}
			</div>
		</div>
	);
}
interface ThreadListItemComponentProps {
	thread: ThreadListItem;
	isActive: boolean;
	isArchived: boolean;
	onClick: () => void;
	onArchive: () => void;
	onUnarchive: () => void;
	onDelete: () => void;
}

/**
 * One row in the thread list: title, relative timestamp, and a hover-revealed
 * dropdown with archive/unarchive and delete actions. The whole row behaves
 * like a button (click / Enter / Space) that switches to the thread.
 */
function ThreadListItemComponent({
	thread,
	isActive,
	isArchived,
	onClick,
	onArchive,
	onUnarchive,
	onDelete,
}: ThreadListItemComponentProps) {
	return (
		<div
			className={cn(
				"group flex items-center gap-2 rounded-lg px-3 py-2 transition-colors cursor-pointer",
				isActive ? "bg-accent text-accent-foreground" : "hover:bg-muted/50"
			)}
			onClick={onClick}
			onKeyDown={(e) => {
				if (e.key === "Enter" || e.key === " ") onClick();
			}}
			role="button"
			tabIndex={0}
		>
			<MessageSquareIcon className="size-4 shrink-0 text-muted-foreground" />
			<div className="flex-1 min-w-0">
				<p className="truncate text-sm font-medium">{thread.title || "New Chat"}</p>
				<p className="truncate text-xs text-muted-foreground">
					{formatRelativeTime(new Date(thread.updatedAt))}
				</p>
			</div>
			<DropdownMenu>
				<DropdownMenuTrigger asChild>
					<Button
						variant="ghost"
						size="icon"
						className="size-7 opacity-0 group-hover:opacity-100 transition-opacity"
						// Keep menu-trigger clicks from also firing the row's onClick.
						onClick={(e) => e.stopPropagation()}
					>
						<MoreVerticalIcon className="size-4" />
					</Button>
				</DropdownMenuTrigger>
				<DropdownMenuContent align="end">
					{isArchived ? (
						<DropdownMenuItem onClick={onUnarchive}>
							<RotateCcwIcon className="mr-2 size-4" />
							Unarchive
						</DropdownMenuItem>
					) : (
						<DropdownMenuItem onClick={onArchive}>
							<ArchiveIcon className="mr-2 size-4" />
							Archive
						</DropdownMenuItem>
					)}
					<DropdownMenuSeparator />
					<DropdownMenuItem onClick={onDelete} className="text-destructive focus:text-destructive">
						<TrashIcon className="mr-2 size-4" />
						Delete
					</DropdownMenuItem>
				</DropdownMenuContent>
			</DropdownMenu>
		</div>
	);
}
/**
 * Format a date as relative time (e.g., "2 hours ago", "Yesterday").
 * Falls back to the locale date string for anything a week old or more.
 */
function formatRelativeTime(date: Date): string {
	const elapsedMs = Date.now() - date.getTime();
	const seconds = Math.floor(elapsedMs / 1000);
	if (seconds < 60) return "Just now";

	const minutes = Math.floor(seconds / 60);
	if (minutes < 60) return `${minutes} min${minutes === 1 ? "" : "s"} ago`;

	const hours = Math.floor(minutes / 60);
	if (hours < 24) return `${hours} hour${hours === 1 ? "" : "s"} ago`;

	const days = Math.floor(hours / 24);
	if (days === 1) return "Yesterday";
	if (days < 7) return `${days} days ago`;
	return date.toLocaleDateString();
}

View file

@ -0,0 +1,862 @@
import {
ActionBarPrimitive,
AssistantIf,
BranchPickerPrimitive,
ComposerPrimitive,
ErrorPrimitive,
MessagePrimitive,
ThreadPrimitive,
useAssistantState,
useMessage,
useThreadViewport,
} from "@assistant-ui/react";
import { useAtomValue } from "jotai";
import {
AlertCircle,
ArrowDownIcon,
ArrowUpIcon,
Brain,
CheckCircle2,
CheckIcon,
ChevronLeftIcon,
ChevronRightIcon,
CopyIcon,
DownloadIcon,
Loader2,
PencilIcon,
Plug2,
Plus,
RefreshCwIcon,
Search,
Sparkles,
SquareIcon,
} from "lucide-react";
import Link from "next/link";
import { type FC, useCallback, useEffect, useMemo, useRef, useState } from "react";
import { getDocumentTypeLabel } from "@/app/dashboard/[search_space_id]/documents/(manage)/components/DocumentTypeIcon";
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
import {
globalNewLLMConfigsAtom,
llmPreferencesAtom,
newLLMConfigsAtom,
} from "@/atoms/new-llm-config/new-llm-config-query.atoms";
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
import { currentUserAtom } from "@/atoms/user/user-query.atoms";
import {
ComposerAddAttachment,
ComposerAttachments,
UserMessageAttachments,
} from "@/components/assistant-ui/attachment";
import { MarkdownText } from "@/components/assistant-ui/markdown-text";
import { ToolFallback } from "@/components/assistant-ui/tool-fallback";
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import {
ChainOfThought,
ChainOfThoughtContent,
ChainOfThoughtItem,
ChainOfThoughtStep,
ChainOfThoughtTrigger,
} from "@/components/prompt-kit/chain-of-thought";
import type { ThinkingStep } from "@/components/tool-ui/deepagent-thinking";
import { Button } from "@/components/ui/button";
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import { useSearchSourceConnectors } from "@/hooks/use-search-source-connectors";
import { cn } from "@/lib/utils";
/**
 * Props for the Thread component
 */
interface ThreadProps {
	// Thinking steps keyed by message id (see ThinkingStepsScrollHandler, which
	// iterates the map as (steps, msgId)).
	messageThinkingSteps?: Map<string, ThinkingStep[]>;
}

// Context to pass thinking steps to AssistantMessage
// NOTE(review): mid-file import; conventionally belongs with the react imports
// at the top of the module.
import { createContext, useContext } from "react";

// Default is an empty map so consumers render nothing when no provider is present.
const ThinkingStepsContext = createContext<Map<string, ThinkingStep[]>>(new Map());
/**
 * Get icon based on step status and title.
 * Running/finished steps get a status icon; pending steps get a themed icon
 * inferred from keywords in the title, defaulting to a sparkle.
 */
function getStepIcon(status: "pending" | "in_progress" | "completed", title: string) {
	// Status takes precedence over any title-based theming.
	switch (status) {
		case "in_progress":
			return <Loader2 className="size-4 animate-spin text-primary" />;
		case "completed":
			return <CheckCircle2 className="size-4 text-emerald-500" />;
		default:
			break;
	}

	const lowered = title.toLowerCase();
	const mentionsAny = (...keywords: string[]) => keywords.some((k) => lowered.includes(k));

	if (mentionsAny("search", "knowledge")) {
		return <Search className="size-4 text-muted-foreground" />;
	}
	if (mentionsAny("analy", "understand")) {
		return <Brain className="size-4 text-muted-foreground" />;
	}
	return <Sparkles className="size-4 text-muted-foreground" />;
}
/**
 * Chain of thought display component with smart expand/collapse behavior.
 *
 * Auto behavior: the in-progress step is open; once nothing is in progress the
 * last completed step stays open. A user toggle overrides the auto behavior
 * for that step until the step's status next changes.
 */
const ThinkingStepsDisplay: FC<{ steps: ThinkingStep[]; isThreadRunning?: boolean }> = ({
	steps,
	isThreadRunning = true,
}) => {
	// Track which steps the user has manually toggled (overrides auto behavior)
	const [manualOverrides, setManualOverrides] = useState<Record<string, boolean>>({});

	// Track previous step statuses to detect changes
	const prevStatusesRef = useRef<Record<string, string>>({});

	// Derive effective status: if thread stopped and step is in_progress, treat as completed
	const getEffectiveStatus = (step: ThinkingStep): "pending" | "in_progress" | "completed" => {
		if (step.status === "in_progress" && !isThreadRunning) {
			return "completed"; // Thread was stopped, so mark as completed
		}
		return step.status;
	};

	// Check if any step is effectively in progress
	const hasInProgressStep = steps.some((step) => getEffectiveStatus(step) === "in_progress");

	// Find the last completed step index (using effective status)
	const lastCompletedIndex = steps
		.map((s, i) => (getEffectiveStatus(s) === "completed" ? i : -1))
		.filter((i) => i !== -1)
		.pop();

	// Clear manual overrides when a step's status changes
	useEffect(() => {
		const currentStatuses: Record<string, string> = {};
		steps.forEach((step) => {
			currentStatuses[step.id] = step.status;
			// If status changed, clear any manual override for this step
			if (prevStatusesRef.current[step.id] && prevStatusesRef.current[step.id] !== step.status) {
				setManualOverrides((prev) => {
					const next = { ...prev };
					delete next[step.id];
					return next;
				});
			}
		});
		prevStatusesRef.current = currentStatuses;
	}, [steps]);

	if (steps.length === 0) return null;

	// Resolve the open/closed state for one step, in priority order:
	// manual override > in-progress > last-completed-when-idle > collapsed.
	const getStepOpenState = (step: ThinkingStep, index: number): boolean => {
		const effectiveStatus = getEffectiveStatus(step);

		// If user has manually toggled, respect that
		if (manualOverrides[step.id] !== undefined) {
			return manualOverrides[step.id];
		}

		// Auto behavior: open if in progress
		if (effectiveStatus === "in_progress") {
			return true;
		}

		// Auto behavior: keep last completed step open if no in-progress step
		if (!hasInProgressStep && index === lastCompletedIndex) {
			return true;
		}

		// Default: collapsed
		return false;
	};

	// Record the user's toggle as an override of the auto behavior.
	const handleToggle = (stepId: string, currentOpen: boolean) => {
		setManualOverrides((prev) => ({
			...prev,
			[stepId]: !currentOpen,
		}));
	};

	return (
		<div className="mx-auto w-full max-w-(--thread-max-width) px-2 py-2">
			<ChainOfThought>
				{steps.map((step, index) => {
					const effectiveStatus = getEffectiveStatus(step);
					const icon = getStepIcon(effectiveStatus, step.title);
					const isOpen = getStepOpenState(step, index);

					return (
						<ChainOfThoughtStep
							key={step.id}
							open={isOpen}
							onOpenChange={() => handleToggle(step.id, isOpen)}
						>
							<ChainOfThoughtTrigger
								leftIcon={icon}
								swapIconOnHover={effectiveStatus !== "in_progress"}
								className={cn(
									effectiveStatus === "in_progress" && "text-foreground font-medium",
									effectiveStatus === "completed" && "text-muted-foreground"
								)}
							>
								{step.title}
							</ChainOfThoughtTrigger>
							{step.items && step.items.length > 0 && (
								<ChainOfThoughtContent>
									{step.items.map((item, idx) => (
										<ChainOfThoughtItem key={`${step.id}-item-${idx}`}>{item}</ChainOfThoughtItem>
									))}
								</ChainOfThoughtContent>
							)}
						</ChainOfThoughtStep>
					);
				})}
			</ChainOfThought>
		</div>
	);
};
/**
 * Component that handles auto-scroll when thinking steps update.
 * Uses useThreadViewport to scroll to bottom when thinking steps change,
 * ensuring the user always sees the latest content during streaming.
 * Renders nothing; it exists only for the effect and must live inside the
 * ThreadPrimitive.Viewport so the viewport hook resolves.
 */
const ThinkingStepsScrollHandler: FC = () => {
	const thinkingStepsMap = useContext(ThinkingStepsContext);
	const viewport = useThreadViewport();
	const isRunning = useAssistantState(({ thread }) => thread.isRunning);

	// Track the serialized state to detect any changes
	const prevStateRef = useRef<string>("");

	useEffect(() => {
		// Only act during streaming
		if (!isRunning) {
			prevStateRef.current = "";
			return;
		}

		// Serialize the thinking steps state to detect any changes.
		// This catches new steps, status changes, and item additions.
		let stateString = "";
		thinkingStepsMap.forEach((steps, msgId) => {
			steps.forEach((step) => {
				stateString += `${msgId}:${step.id}:${step.status}:${step.items?.length || 0};`;
			});
		});

		// If state changed at all during streaming, scroll
		if (stateString !== prevStateRef.current && stateString !== "") {
			prevStateRef.current = stateString;

			// Multiple attempts to ensure scroll happens after DOM updates
			const scrollAttempt = () => {
				try {
					viewport.scrollToBottom();
				} catch {
					// Ignore errors - viewport might not be ready.
					// (Bare catch: the error value was previously bound but unused.)
				}
			};

			// Delayed attempts to handle async DOM updates
			// NOTE(review): the 100ms timeout is not cancelled on unmount; failures
			// inside scrollAttempt are swallowed, so this looks harmless — confirm.
			requestAnimationFrame(scrollAttempt);
			setTimeout(scrollAttempt, 100);
		}
	}, [thinkingStepsMap, viewport, isRunning]);

	return null; // This component doesn't render anything
};
/**
 * Top-level chat thread view: welcome screen when empty, the message list,
 * and a sticky composer footer. Thinking steps are distributed to descendants
 * via ThinkingStepsContext.
 */
export const Thread: FC<ThreadProps> = ({ messageThinkingSteps = new Map() }) => {
	return (
		<ThinkingStepsContext.Provider value={messageThinkingSteps}>
			<ThreadPrimitive.Root
				className="aui-root aui-thread-root @container flex h-full flex-col bg-background"
				style={{
					// Shared width cap consumed by descendants via var(--thread-max-width).
					["--thread-max-width" as string]: "44rem",
				}}
			>
				<ThreadPrimitive.Viewport
					turnAnchor="top"
					className="aui-thread-viewport relative flex flex-1 flex-col overflow-x-auto overflow-y-scroll scroll-smooth px-4 pt-4"
				>
					{/* Auto-scroll handler for thinking steps - must be inside Viewport */}
					<ThinkingStepsScrollHandler />

					<AssistantIf condition={({ thread }) => thread.isEmpty}>
						<ThreadWelcome />
					</AssistantIf>

					<ThreadPrimitive.Messages
						components={{
							UserMessage,
							EditComposer,
							AssistantMessage,
						}}
					/>

					<ThreadPrimitive.ViewportFooter className="aui-thread-viewport-footer sticky bottom-0 mx-auto mt-auto flex w-full max-w-(--thread-max-width) flex-col gap-4 overflow-visible rounded-t-3xl bg-background pb-4 md:pb-6">
						<ThreadScrollToBottom />
						{/* Composer only shows here once the thread has messages; the empty
						    state renders its own centered composer in ThreadWelcome. */}
						<AssistantIf condition={({ thread }) => !thread.isEmpty}>
							<div className="fade-in slide-in-from-bottom-4 animate-in duration-500 ease-out fill-mode-both">
								<Composer />
							</div>
						</AssistantIf>
					</ThreadPrimitive.ViewportFooter>
				</ThreadPrimitive.Viewport>
			</ThreadPrimitive.Root>
		</ThinkingStepsContext.Provider>
	);
};
/**
 * Floating "scroll to bottom" button anchored above the composer.
 * The `disabled:invisible` class hides it whenever the primitive disables it
 * (presumably when the viewport is already at the bottom — confirm against
 * assistant-ui docs).
 */
const ThreadScrollToBottom: FC = () => {
	return (
		<ThreadPrimitive.ScrollToBottom asChild>
			<TooltipIconButton
				tooltip="Scroll to bottom"
				variant="outline"
				className="aui-thread-scroll-to-bottom -top-12 absolute z-10 self-center rounded-full p-4 disabled:invisible dark:bg-background dark:hover:bg-accent"
			>
				<ArrowDownIcon />
			</TooltipIconButton>
		</ThreadPrimitive.ScrollToBottom>
	);
};
/**
 * Builds a randomized, time-of-day greeting, optionally personalized with a
 * first name derived from the email's local part (text before the first dot).
 */
const getTimeBasedGreeting = (userEmail?: string): string => {
	const hour = new Date().getHours();

	// Derive a display name from the email, e.g. "ada.lovelace@x" -> "Ada".
	const localPart = userEmail ? userEmail.split("@")[0].split(".")[0] : "";
	const firstName = localPart ? localPart.charAt(0).toUpperCase() + localPart.slice(1) : null;

	// Greeting pools per time-of-day band.
	const morningGreetings = ["Good morning", "Rise and shine", "Morning", "Hey there"];
	const afternoonGreetings = ["Good afternoon", "Afternoon", "Hey there", "Hi there"];
	const eveningGreetings = ["Good evening", "Evening", "Hey there", "Hi there"];
	const nightGreetings = ["Good night", "Evening", "Hey there", "Winding down"];
	const lateNightGreetings = ["Still up", "Night owl mode", "The night is young", "Hi there"];

	const pickRandom = (pool: string[]) => pool[Math.floor(Math.random() * pool.length)];

	// Bands: <5 late night, <12 morning, <18 afternoon, <22 evening, else night.
	let pool: string[];
	if (hour < 5) {
		pool = lateNightGreetings;
	} else if (hour < 12) {
		pool = morningGreetings;
	} else if (hour < 18) {
		pool = afternoonGreetings;
	} else if (hour < 22) {
		pool = eveningGreetings;
	} else {
		pool = nightGreetings;
	}
	const greeting = pickRandom(pool);

	return firstName ? `${greeting}, ${firstName}!` : `${greeting}!`;
};
/**
 * Empty-thread welcome screen: time-based greeting fixed above a centered
 * composer (composer's top edge is pinned so it only expands downward).
 */
const ThreadWelcome: FC = () => {
	const { data: user } = useAtomValue(currentUserAtom);

	return (
		<div className="aui-thread-welcome-root mx-auto flex w-full max-w-(--thread-max-width) grow flex-col items-center px-4 relative">
			{/* Greeting positioned above the composer - fixed position */}
			<div className="aui-thread-welcome-message absolute bottom-[calc(50%+5rem)] left-0 right-0 flex flex-col items-center text-center z-10">
				<h1 className="aui-thread-welcome-message-inner fade-in slide-in-from-bottom-2 animate-in text-5xl delay-100 duration-500 ease-out fill-mode-both">
					{getTimeBasedGreeting(user?.email)}
				</h1>
			</div>
			{/* Composer - top edge fixed, expands downward only */}
			<div className="fade-in slide-in-from-bottom-3 animate-in delay-200 duration-500 ease-out fill-mode-both w-full flex items-start justify-center absolute top-[calc(50%-3.5rem)] left-0 right-0">
				<Composer />
			</div>
		</div>
	);
};
/**
 * Message composer: attachment dropzone, text input, and action row.
 * The input is disabled until an agent LLM is configured and resolvable.
 */
const Composer: FC = () => {
	// Check if a model is configured - needed to disable input
	const { data: userConfigs } = useAtomValue(newLLMConfigsAtom);
	const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom);
	const { data: preferences } = useAtomValue(llmPreferencesAtom);

	// NOTE(review): identical computation exists in ComposerAction — consider
	// extracting a shared hook so the two cannot drift.
	const hasModelConfigured = useMemo(() => {
		if (!preferences) return false;
		const agentLlmId = preferences.agent_llm_id;
		if (agentLlmId === null || agentLlmId === undefined) return false;
		// Check if the configured model actually exists
		// (negative ids are looked up in global configs, others in user configs).
		if (agentLlmId < 0) {
			return globalConfigs?.some((c) => c.id === agentLlmId) ?? false;
		}
		return userConfigs?.some((c) => c.id === agentLlmId) ?? false;
	}, [preferences, globalConfigs, userConfigs]);

	return (
		<ComposerPrimitive.Root className="aui-composer-root relative flex w-full flex-col">
			<ComposerPrimitive.AttachmentDropzone className="aui-composer-attachment-dropzone flex w-full flex-col rounded-2xl border-input bg-muted px-1 pt-2 outline-none transition-shadow data-[dragging=true]:border-ring data-[dragging=true]:border-dashed data-[dragging=true]:bg-accent/50">
				<ComposerAttachments />
				<ComposerPrimitive.Input
					placeholder={
						hasModelConfigured
							? "Ask SurfSense"
							: "Select a model from the header to start chatting..."
					}
					className={cn(
						"aui-composer-input mb-1 max-h-32 min-h-14 w-full resize-none bg-transparent px-4 pt-2 pb-3 text-sm outline-none placeholder:text-muted-foreground focus-visible:ring-0",
						!hasModelConfigured && "cursor-not-allowed opacity-60"
					)}
					rows={1}
					autoFocus={hasModelConfigured}
					aria-label="Message input"
					disabled={!hasModelConfigured}
				/>
				<ComposerAction />
			</ComposerPrimitive.AttachmentDropzone>
		</ComposerPrimitive.Root>
	);
};
/**
 * Composer-row badge summarizing connected sources (document types with
 * documents plus search-source connectors). Opens a hover popover listing
 * the sources, or a call-to-action to add the first connector.
 */
const ConnectorIndicator: FC = () => {
	const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom);
	const { connectors, isLoading: connectorsLoading } = useSearchSourceConnectors(
		false,
		searchSpaceId ? Number(searchSpaceId) : undefined
	);
	const { data: documentTypeCounts, isLoading: documentTypesLoading } =
		useAtomValue(documentTypeCountsAtom);
	const [isOpen, setIsOpen] = useState(false);
	const closeTimeoutRef = useRef<NodeJS.Timeout | null>(null);

	const isLoading = connectorsLoading || documentTypesLoading;

	// Get document types that have documents in the search space
	const activeDocumentTypes = documentTypeCounts
		? Object.entries(documentTypeCounts).filter(([_, count]) => count > 0)
		: [];

	const hasConnectors = connectors.length > 0;
	const hasSources = hasConnectors || activeDocumentTypes.length > 0;
	const totalSourceCount = connectors.length + activeDocumentTypes.length;

	const handleMouseEnter = useCallback(() => {
		// Clear any pending close timeout
		if (closeTimeoutRef.current) {
			clearTimeout(closeTimeoutRef.current);
			closeTimeoutRef.current = null;
		}
		setIsOpen(true);
	}, []);

	const handleMouseLeave = useCallback(() => {
		// Delay closing by 150ms for better UX
		closeTimeoutRef.current = setTimeout(() => {
			setIsOpen(false);
		}, 150);
	}, []);

	// Fix: clear any pending close timer on unmount so the delayed setIsOpen
	// cannot fire after the component is gone.
	useEffect(() => {
		return () => {
			if (closeTimeoutRef.current) {
				clearTimeout(closeTimeoutRef.current);
			}
		};
	}, []);

	if (!searchSpaceId) return null;

	return (
		<Popover open={isOpen} onOpenChange={setIsOpen}>
			<PopoverTrigger asChild>
				<button
					type="button"
					className={cn(
						"size-[34px] rounded-full p-1 flex items-center justify-center transition-colors relative",
						"hover:bg-muted-foreground/15 dark:hover:bg-muted-foreground/30",
						"outline-none focus:outline-none focus-visible:outline-none",
						"border-0 ring-0 focus:ring-0 shadow-none focus:shadow-none",
						"data-[state=open]:bg-transparent data-[state=open]:shadow-none data-[state=open]:ring-0",
						"text-muted-foreground"
					)}
					aria-label={
						hasSources ? `View ${totalSourceCount} connected sources` : "Add your first connector"
					}
					onMouseEnter={handleMouseEnter}
					onMouseLeave={handleMouseLeave}
				>
					{isLoading ? (
						<Loader2 className="size-4 animate-spin" />
					) : (
						<>
							<Plug2 className="size-4" />
							{totalSourceCount > 0 ? (
								<span className="absolute -top-0.5 -right-0.5 flex items-center justify-center min-w-[16px] h-4 px-1 text-[10px] font-medium rounded-full bg-primary text-primary-foreground shadow-sm">
									{totalSourceCount > 99 ? "99+" : totalSourceCount}
								</span>
							) : (
								<span className="absolute -top-0.5 -right-0.5 flex items-center justify-center size-3 rounded-full bg-muted-foreground/30 border border-background">
									<span className="size-1.5 rounded-full bg-muted-foreground/60" />
								</span>
							)}
						</>
					)}
				</button>
			</PopoverTrigger>
			<PopoverContent
				side="bottom"
				align="start"
				className="w-64 p-3"
				onMouseEnter={handleMouseEnter}
				onMouseLeave={handleMouseLeave}
			>
				{hasSources ? (
					<div className="space-y-3">
						<div className="flex items-center justify-between">
							<p className="text-xs font-medium text-muted-foreground">Connected Sources</p>
							<span className="text-xs font-medium bg-muted px-1.5 py-0.5 rounded">
								{totalSourceCount}
							</span>
						</div>
						<div className="flex flex-wrap gap-2">
							{/* Document types from the search space */}
							{activeDocumentTypes.map(([docType, count]) => (
								<div
									key={docType}
									className="flex items-center gap-1.5 rounded-md bg-muted/80 px-2.5 py-1.5 text-xs border border-border/50"
								>
									{getConnectorIcon(docType, "size-3.5")}
									<span className="truncate max-w-[100px]">{getDocumentTypeLabel(docType)}</span>
								</div>
							))}
							{/* Search source connectors */}
							{connectors.map((connector) => (
								<div
									key={`connector-${connector.id}`}
									className="flex items-center gap-1.5 rounded-md bg-muted/80 px-2.5 py-1.5 text-xs border border-border/50"
								>
									{getConnectorIcon(connector.connector_type, "size-3.5")}
									<span className="truncate max-w-[100px]">{connector.name}</span>
								</div>
							))}
						</div>
						<div className="pt-1 border-t border-border/50">
							<Link
								href={`/dashboard/${searchSpaceId}/connectors/add`}
								className="inline-flex items-center gap-1.5 text-xs text-muted-foreground hover:text-foreground transition-colors"
							>
								<Plus className="size-3" />
								Add more sources
								<ChevronRightIcon className="size-3" />
							</Link>
						</div>
					</div>
				) : (
					<div className="space-y-2">
						<p className="text-sm font-medium">No sources yet</p>
						<p className="text-xs text-muted-foreground">
							Add documents or connect data sources to enhance search results.
						</p>
						<Link
							href={`/dashboard/${searchSpaceId}/connectors/add`}
							className="inline-flex items-center gap-1.5 rounded-md bg-primary px-3 py-1.5 text-xs font-medium text-primary-foreground hover:bg-primary/90 transition-colors mt-1"
						>
							<Plus className="size-3" />
							Add Connector
						</Link>
					</div>
				)}
			</PopoverContent>
		</Popover>
	);
};
/**
 * Composer footer row: attachment button and connector indicator on the left,
 * a transient status indicator (attachment processing / missing model warning)
 * in the middle, and a Send or Cancel control on the right depending on
 * whether the thread is currently running.
 *
 * Send is disabled while attachments are still processing, the input text is
 * empty, or no agent LLM is configured in the user's preferences.
 */
const ComposerAction: FC = () => {
	// Check if any attachments are still being processed (running AND progress < 100)
	// When progress is 100, processing is done but waiting for send()
	const hasProcessingAttachments = useAssistantState(({ composer }) =>
		composer.attachments?.some((att) => {
			const status = att.status;
			if (status?.type !== "running") return false;
			// "running" status may carry an optional progress percentage.
			const progress = (status as { type: "running"; progress?: number }).progress;
			return progress === undefined || progress < 100;
		})
	);
	// Check if composer text is empty (whitespace-only counts as empty)
	const isComposerEmpty = useAssistantState(({ composer }) => {
		const text = composer.text?.trim() || "";
		return text.length === 0;
	});
	// Check if a model is configured
	const { data: userConfigs } = useAtomValue(newLLMConfigsAtom);
	const { data: globalConfigs } = useAtomValue(globalNewLLMConfigsAtom);
	const { data: preferences } = useAtomValue(llmPreferencesAtom);
	const hasModelConfigured = useMemo(() => {
		if (!preferences) return false;
		const agentLlmId = preferences.agent_llm_id;
		if (agentLlmId === null || agentLlmId === undefined) return false;
		// Check if the configured model actually exists.
		// Negative ids are looked up in the global configs, non-negative ids in
		// the user's own configs — presumably a global-vs-user id convention;
		// TODO confirm against the atoms' backing API.
		if (agentLlmId < 0) {
			return globalConfigs?.some((c) => c.id === agentLlmId) ?? false;
		}
		return userConfigs?.some((c) => c.id === agentLlmId) ?? false;
	}, [preferences, globalConfigs, userConfigs]);
	const isSendDisabled = hasProcessingAttachments || isComposerEmpty || !hasModelConfigured;
	return (
		<div className="aui-composer-action-wrapper relative mx-2 mb-2 flex items-center justify-between">
			<div className="flex items-center gap-1">
				<ComposerAddAttachment />
				<ConnectorIndicator />
			</div>
			{/* Show processing indicator when attachments are being processed */}
			{hasProcessingAttachments && (
				<div className="flex items-center gap-1.5 text-muted-foreground text-xs">
					<Loader2 className="size-3 animate-spin" />
					<span>Processing...</span>
				</div>
			)}
			{/* Show warning when no model is configured (processing takes precedence) */}
			{!hasModelConfigured && !hasProcessingAttachments && (
				<div className="flex items-center gap-1.5 text-amber-600 dark:text-amber-400 text-xs">
					<AlertCircle className="size-3" />
					<span>Select a model</span>
				</div>
			)}
			{/* Idle thread: show Send with a tooltip explaining why it may be disabled */}
			<AssistantIf condition={({ thread }) => !thread.isRunning}>
				<ComposerPrimitive.Send asChild disabled={isSendDisabled}>
					<TooltipIconButton
						tooltip={
							!hasModelConfigured
								? "Please select a model from the header to start chatting"
								: hasProcessingAttachments
									? "Wait for attachments to process"
									: isComposerEmpty
										? "Enter a message to send"
										: "Send message"
						}
						side="bottom"
						type="submit"
						variant="default"
						size="icon"
						className={cn(
							"aui-composer-send size-8 rounded-full",
							isSendDisabled && "cursor-not-allowed opacity-50"
						)}
						aria-label="Send message"
						disabled={isSendDisabled}
					>
						<ArrowUpIcon className="aui-composer-send-icon size-4" />
					</TooltipIconButton>
				</ComposerPrimitive.Send>
			</AssistantIf>
			{/* Running thread: replace Send with a stop/cancel button */}
			<AssistantIf condition={({ thread }) => thread.isRunning}>
				<ComposerPrimitive.Cancel asChild>
					<Button
						type="button"
						variant="default"
						size="icon"
						className="aui-composer-cancel size-8 rounded-full"
						aria-label="Stop generating"
					>
						<SquareIcon className="aui-composer-cancel-icon size-3 fill-current" />
					</Button>
				</ComposerPrimitive.Cancel>
			</AssistantIf>
		</div>
	);
};
/**
 * Renders a message's error (e.g. a failed generation) inside a
 * destructive-styled callout, clamped to two lines. The surrounding
 * MessagePrimitive.Error renders nothing when the message has no error.
 */
const MessageError: FC = () => {
	return (
		<MessagePrimitive.Error>
			<ErrorPrimitive.Root className="aui-message-error-root mt-2 rounded-md border border-destructive bg-destructive/10 p-3 text-destructive text-sm dark:bg-destructive/5 dark:text-red-200">
				<ErrorPrimitive.Message className="aui-message-error-message line-clamp-2" />
			</ErrorPrimitive.Root>
		</MessagePrimitive.Error>
	);
};
/**
 * Custom component to render thinking steps from Context.
 *
 * Looks up the current message's thinking steps in the ThinkingStepsContext
 * map (keyed by message id) and renders them above the message body.
 * Renders nothing when the message has no recorded steps.
 */
const ThinkingStepsPart: FC = () => {
	const thinkingStepsMap = useContext(ThinkingStepsContext);
	// Get the current message ID to look up thinking steps
	const messageId = useMessage((m) => m.id);
	const thinkingSteps = thinkingStepsMap.get(messageId) || [];
	// Check if thread is still running (for stopping the spinner when cancelled)
	const isThreadRunning = useAssistantState(({ thread }) => thread.isRunning);
	if (thinkingSteps.length === 0) return null;
	return (
		<div className="mb-3">
			<ThinkingStepsDisplay steps={thinkingSteps} isThreadRunning={isThreadRunning} />
		</div>
	);
};
/**
 * Body of an assistant message: thinking steps, then the message parts
 * (markdown text and tool calls), an inline error (if any), and a footer
 * with branch navigation and the copy/export/reload action bar.
 */
const AssistantMessageInner: FC = () => {
	return (
		<>
			{/* Render thinking steps from message content - this ensures proper scroll tracking */}
			<ThinkingStepsPart />
			<div className="aui-assistant-message-content wrap-break-word px-2 text-foreground leading-relaxed">
				<MessagePrimitive.Parts
					components={{
						Text: MarkdownText,
						tools: { Fallback: ToolFallback },
					}}
				/>
				<MessageError />
			</div>
			<div className="aui-assistant-message-footer mt-1 ml-2 flex">
				<BranchPicker />
				<AssistantActionBar />
			</div>
		</>
	);
};
/**
 * Root wrapper for an assistant message: centers it within the thread's
 * max width and applies the fade/slide entrance animation.
 */
const AssistantMessage: FC = () => {
	return (
		<MessagePrimitive.Root
			className="aui-assistant-message-root fade-in slide-in-from-bottom-1 relative mx-auto w-full max-w-(--thread-max-width) animate-in py-3 duration-150"
			data-role="assistant"
		>
			<AssistantMessageInner />
		</MessagePrimitive.Root>
	);
};
/**
 * Action bar under an assistant message with Copy (check icon while the
 * content is in the clipboard), Export-as-Markdown, and Reload buttons.
 * Hidden while the thread is running; auto-hides (floats) for non-last
 * messages per the primitive's autohide settings.
 */
const AssistantActionBar: FC = () => {
	return (
		<ActionBarPrimitive.Root
			hideWhenRunning
			autohide="not-last"
			autohideFloat="single-branch"
			className="aui-assistant-action-bar-root -ml-1 col-start-3 row-start-2 flex gap-1 text-muted-foreground data-floating:absolute data-floating:rounded-md data-floating:border data-floating:bg-background data-floating:p-1 data-floating:shadow-sm"
		>
			<ActionBarPrimitive.Copy asChild>
				<TooltipIconButton tooltip="Copy">
					{/* Swap the icon to a check mark while the copy is "fresh" */}
					<AssistantIf condition={({ message }) => message.isCopied}>
						<CheckIcon />
					</AssistantIf>
					<AssistantIf condition={({ message }) => !message.isCopied}>
						<CopyIcon />
					</AssistantIf>
				</TooltipIconButton>
			</ActionBarPrimitive.Copy>
			<ActionBarPrimitive.ExportMarkdown asChild>
				<TooltipIconButton tooltip="Export as Markdown">
					<DownloadIcon />
				</TooltipIconButton>
			</ActionBarPrimitive.ExportMarkdown>
			<ActionBarPrimitive.Reload asChild>
				<TooltipIconButton tooltip="Refresh">
					<RefreshCwIcon />
				</TooltipIconButton>
			</ActionBarPrimitive.Reload>
		</ActionBarPrimitive.Root>
	);
};
/**
 * A user message bubble: attachments above, the right-aligned message
 * content in a muted rounded bubble, an edit action bar floated to the
 * left of the bubble, and a branch picker below when edits created
 * multiple branches.
 */
const UserMessage: FC = () => {
	return (
		<MessagePrimitive.Root
			className="aui-user-message-root fade-in slide-in-from-bottom-1 mx-auto grid w-full max-w-(--thread-max-width) animate-in auto-rows-auto grid-cols-[minmax(72px,1fr)_auto] content-start gap-y-2 px-2 py-3 duration-150 [&:where(>*)]:col-start-2"
			data-role="user"
		>
			<UserMessageAttachments />
			<div className="aui-user-message-content-wrapper relative col-start-2 min-w-0">
				<div className="aui-user-message-content wrap-break-word rounded-2xl bg-muted px-4 py-2.5 text-foreground">
					<MessagePrimitive.Parts />
				</div>
				{/* Edit button positioned just left of the bubble, vertically centered */}
				<div className="aui-user-action-bar-wrapper -translate-x-full -translate-y-1/2 absolute top-1/2 left-0 pr-2">
					<UserActionBar />
				</div>
			</div>
			<BranchPicker className="aui-user-branch-picker -mr-1 col-span-full col-start-1 row-start-3 justify-end" />
		</MessagePrimitive.Root>
	);
};
/**
 * Action bar next to a user message, currently exposing only the Edit
 * action. Hidden while the thread is running and auto-hidden for
 * non-last messages.
 */
const UserActionBar: FC = () => {
	return (
		<ActionBarPrimitive.Root
			hideWhenRunning
			autohide="not-last"
			className="aui-user-action-bar-root flex flex-col items-end"
		>
			<ActionBarPrimitive.Edit asChild>
				<TooltipIconButton tooltip="Edit" className="aui-user-action-edit p-4">
					<PencilIcon />
				</TooltipIconButton>
			</ActionBarPrimitive.Edit>
		</ActionBarPrimitive.Root>
	);
};
/**
 * In-place editor shown when a user message is being edited: an
 * auto-focused textarea pre-filled by the composer primitive, with
 * Cancel / Update controls. Sending creates a new branch of the thread.
 */
const EditComposer: FC = () => {
	return (
		<MessagePrimitive.Root className="aui-edit-composer-wrapper mx-auto flex w-full max-w-(--thread-max-width) flex-col px-2 py-3">
			<ComposerPrimitive.Root className="aui-edit-composer-root ml-auto flex w-full max-w-[85%] flex-col rounded-2xl bg-muted">
				<ComposerPrimitive.Input
					className="aui-edit-composer-input min-h-14 w-full resize-none bg-transparent p-4 text-foreground text-sm outline-none"
					autoFocus
				/>
				<div className="aui-edit-composer-footer mx-3 mb-3 flex items-center gap-2 self-end">
					<ComposerPrimitive.Cancel asChild>
						<Button variant="ghost" size="sm">
							Cancel
						</Button>
					</ComposerPrimitive.Cancel>
					<ComposerPrimitive.Send asChild>
						<Button size="sm">Update</Button>
					</ComposerPrimitive.Send>
				</div>
			</ComposerPrimitive.Root>
		</MessagePrimitive.Root>
	);
};
/**
 * Previous/next navigation between message branches (created by edits or
 * reloads), showing the current position as "n / total". Hidden entirely
 * when a message has only a single branch.
 *
 * Extra props (including className) are forwarded to the primitive root.
 */
const BranchPicker: FC<BranchPickerPrimitive.Root.Props> = ({ className, ...rest }) => {
	return (
		<BranchPickerPrimitive.Root
			hideWhenSingleBranch
			className={cn(
				"aui-branch-picker-root -ml-2 mr-2 inline-flex items-center text-muted-foreground text-xs",
				className
			)}
			{...rest}
		>
			<BranchPickerPrimitive.Previous asChild>
				<TooltipIconButton tooltip="Previous">
					<ChevronLeftIcon />
				</TooltipIconButton>
			</BranchPickerPrimitive.Previous>
			<span className="aui-branch-picker-state font-medium">
				<BranchPickerPrimitive.Number /> / <BranchPickerPrimitive.Count />
			</span>
			<BranchPickerPrimitive.Next asChild>
				<TooltipIconButton tooltip="Next">
					<ChevronRightIcon />
				</TooltipIconButton>
			</BranchPickerPrimitive.Next>
		</BranchPickerPrimitive.Root>
	);
};

View file

@ -0,0 +1,76 @@
import type { ToolCallMessagePartComponent } from "@assistant-ui/react";
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XCircleIcon } from "lucide-react";
import { useState } from "react";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
/**
 * Generic renderer for tool calls that have no dedicated UI.
 *
 * Shows a collapsible card with the tool name; expanding it reveals the
 * raw arguments and (when the call completed) the result. Calls that
 * were cancelled (status incomplete with reason "cancelled") are rendered
 * struck-through with the cancellation reason and no result section.
 */
export const ToolFallback: ToolCallMessagePartComponent = ({
	toolName,
	argsText,
	result,
	status,
}) => {
	// Collapsed by default; details are opt-in.
	const [isCollapsed, setIsCollapsed] = useState(true);
	const isCancelled = status?.type === "incomplete" && status.reason === "cancelled";
	// status.error may be a string or an arbitrary value; stringify the latter.
	const cancelledReason =
		isCancelled && status.error
			? typeof status.error === "string"
				? status.error
				: JSON.stringify(status.error)
			: null;
	return (
		<div
			className={cn(
				"aui-tool-fallback-root mb-4 flex w-full flex-col gap-3 rounded-lg border py-3",
				isCancelled && "border-muted-foreground/30 bg-muted/30"
			)}
		>
			<div className="aui-tool-fallback-header flex items-center gap-2 px-4">
				{isCancelled ? (
					<XCircleIcon className="aui-tool-fallback-icon size-4 text-muted-foreground" />
				) : (
					<CheckIcon className="aui-tool-fallback-icon size-4" />
				)}
				<p
					className={cn(
						"aui-tool-fallback-title grow",
						isCancelled && "text-muted-foreground line-through"
					)}
				>
					{isCancelled ? "Cancelled tool: " : "Used tool: "}
					<b>{toolName}</b>
				</p>
				<Button onClick={() => setIsCollapsed(!isCollapsed)}>
					{isCollapsed ? <ChevronUpIcon /> : <ChevronDownIcon />}
				</Button>
			</div>
			{!isCollapsed && (
				<div className="aui-tool-fallback-content flex flex-col gap-2 border-t pt-2">
					{cancelledReason && (
						<div className="aui-tool-fallback-cancelled-root px-4">
							<p className="aui-tool-fallback-cancelled-header font-semibold text-muted-foreground">
								Cancelled reason:
							</p>
							<p className="aui-tool-fallback-cancelled-reason text-muted-foreground">
								{cancelledReason}
							</p>
						</div>
					)}
					{/* Raw arguments, dimmed when the call never ran to completion */}
					<div className={cn("aui-tool-fallback-args-root px-4", isCancelled && "opacity-60")}>
						<pre className="aui-tool-fallback-args-value whitespace-pre-wrap">{argsText}</pre>
					</div>
					{!isCancelled && result !== undefined && (
						<div className="aui-tool-fallback-result-root border-t border-dashed px-4 pt-2">
							<p className="aui-tool-fallback-result-header font-semibold">Result:</p>
							<pre className="aui-tool-fallback-result-content whitespace-pre-wrap">
								{typeof result === "string" ? result : JSON.stringify(result, null, 2)}
							</pre>
						</div>
					)}
				</div>
			)}
		</div>
	);
};

View file

@ -0,0 +1,36 @@
"use client";
import { Slottable } from "@radix-ui/react-slot";
import { type ComponentPropsWithRef, forwardRef } from "react";
import { Button } from "@/components/ui/button";
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
import { cn } from "@/lib/utils";
/**
 * Props for TooltipIconButton: all Button props plus the tooltip text and
 * an optional placement side (defaults to "bottom").
 */
export type TooltipIconButtonProps = ComponentPropsWithRef<typeof Button> & {
	tooltip: string;
	side?: "top" | "bottom" | "left" | "right";
};
/**
 * A small ghost icon button wrapped in a tooltip. The tooltip text is also
 * rendered as a visually-hidden label for screen readers, so callers only
 * need to pass an icon child. The ref is forwarded to the underlying button.
 */
export const TooltipIconButton = forwardRef<HTMLButtonElement, TooltipIconButtonProps>(
	({ children, tooltip, side = "bottom", className, ...rest }, ref) => {
		return (
			<Tooltip>
				<TooltipTrigger asChild>
					<Button
						variant="ghost"
						size="icon"
						{...rest}
						className={cn("aui-button-icon size-6 p-1", className)}
						ref={ref}
					>
						{/* Slottable lets asChild consumers compose the children correctly */}
						<Slottable>{children}</Slottable>
						<span className="aui-sr-only sr-only">{tooltip}</span>
					</Button>
				</TooltipTrigger>
				<TooltipContent side={side}>{tooltip}</TooltipContent>
			</Tooltip>
		);
	}
);
TooltipIconButton.displayName = "TooltipIconButton";

View file

@ -1,151 +0,0 @@
"use client";
import { useInView } from "motion/react";
import { Manrope } from "next/font/google";
import { useEffect, useMemo, useReducer, useRef } from "react";
import { RoughNotation, RoughNotationGroup } from "react-rough-notation";
import { useSidebar } from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
// Font configuration - could be moved to a global font config file
const manrope = Manrope({
	subsets: ["latin"],
	weight: ["400", "700"],
	display: "swap", // Optimize font loading
	variable: "--font-manrope",
});
// Constants for timing - makes it easier to adjust and more maintainable
const TIMING = {
	SIDEBAR_TRANSITION: 300, // Wait for sidebar transition + buffer (ms)
	LAYOUT_SETTLE: 100, // Small delay to ensure layout is fully settled (ms)
} as const;
// Animation configuration for the RoughNotation annotations below.
// Spread directly onto <RoughNotation> as props.
const ANIMATION_CONFIG = {
	HIGHLIGHT: {
		type: "highlight" as const,
		animationDuration: 2000,
		iterations: 3,
		color: "#3b82f680", // translucent blue
		multiline: true,
	},
	UNDERLINE: {
		type: "underline" as const,
		animationDuration: 2000,
		iterations: 3,
		color: "#10b981", // emerald green
	},
} as const;
// State management with useReducer for better organization
interface HighlightState {
	shouldShowHighlight: boolean; // whether the RoughNotation annotations are visible
	layoutStable: boolean; // true once the sidebar transition has settled
}
// Actions driving the highlight lifecycle (see highlightReducer).
type HighlightAction =
	| { type: "SIDEBAR_CHANGED" }
	| { type: "LAYOUT_STABILIZED" }
	| { type: "SHOW_HIGHLIGHT" }
	| { type: "HIDE_HIGHLIGHT" };
/**
 * Reducer for the empty-state highlight animation.
 *
 * A sidebar change resets everything (hide the highlight, mark layout
 * unstable); stabilization and show/hide actions each update a single
 * flag. Unknown actions return the state unchanged.
 */
const highlightReducer = (state: HighlightState, action: HighlightAction): HighlightState => {
	if (action.type === "SIDEBAR_CHANGED") {
		// A sidebar toggle invalidates both the highlight and layout stability.
		return { shouldShowHighlight: false, layoutStable: false };
	}
	if (action.type === "LAYOUT_STABILIZED") {
		return { ...state, layoutStable: true };
	}
	if (action.type === "SHOW_HIGHLIGHT" || action.type === "HIDE_HIGHLIGHT") {
		return { ...state, shouldShowHighlight: action.type === "SHOW_HIGHLIGHT" };
	}
	// Unknown action: leave state untouched.
	return state;
};
// Layout is assumed stable on first render; the highlight starts hidden.
const initialState: HighlightState = {
	shouldShowHighlight: false,
	layoutStable: true,
};
/**
 * Animated empty-state hero ("SurfSense" with rough-notation highlight and
 * underline). The annotations are deliberately delayed until the sidebar
 * transition has finished and the section is in view, because RoughNotation
 * measures the DOM when it draws and would otherwise annotate stale layout.
 */
export function AnimatedEmptyState() {
	const ref = useRef<HTMLDivElement>(null);
	const isInView = useInView(ref);
	const [{ shouldShowHighlight, layoutStable }, dispatch] = useReducer(
		highlightReducer,
		initialState
	);
	// Memoize class names to prevent unnecessary recalculations
	const headingClassName = useMemo(
		() =>
			cn(
				"text-3xl sm:text-4xl md:text-5xl lg:text-6xl font-bold tracking-tight text-neutral-900 dark:text-neutral-50 mb-6",
				manrope.className
			),
		[]
	);
	const paragraphClassName = useMemo(
		() => "text-lg sm:text-xl text-neutral-600 dark:text-neutral-300 mb-8 max-w-2xl mx-auto",
		[]
	);
	// Handle sidebar state changes
	// NOTE(review): the dependency array is empty, so this runs once on mount
	// only — it does not actually observe sidebar state (useSidebar is
	// imported but unused here); confirm whether that is intentional.
	useEffect(() => {
		dispatch({ type: "SIDEBAR_CHANGED" });
		const stabilizeTimer = setTimeout(() => {
			dispatch({ type: "LAYOUT_STABILIZED" });
		}, TIMING.SIDEBAR_TRANSITION);
		return () => clearTimeout(stabilizeTimer);
	}, []);
	// Handle highlight visibility based on layout stability and viewport visibility
	useEffect(() => {
		if (!layoutStable || !isInView) {
			dispatch({ type: "HIDE_HIGHLIGHT" });
			return;
		}
		const showTimer = setTimeout(() => {
			dispatch({ type: "SHOW_HIGHLIGHT" });
		}, TIMING.LAYOUT_SETTLE);
		return () => clearTimeout(showTimer);
	}, [layoutStable, isInView]);
	return (
		<div ref={ref} className="flex-1 flex items-center justify-center w-full min-h-fit">
			<div className="max-w-4xl mx-auto px-4 py-10 text-center">
				<RoughNotationGroup show={shouldShowHighlight}>
					<h1 className={headingClassName}>
						<RoughNotation {...ANIMATION_CONFIG.HIGHLIGHT}>
							<span>SurfSense</span>
						</RoughNotation>
					</h1>
					<p className={paragraphClassName}>
						<RoughNotation {...ANIMATION_CONFIG.UNDERLINE}>Let's Start Surfing</RoughNotation>{" "}
						through your knowledge base.
					</p>
				</RoughNotationGroup>
			</div>
		</div>
	);
}

View file

@ -1,30 +0,0 @@
"use client";
import type React from "react";
import { useState } from "react";
import { SheetTrigger } from "@/components/ui/sheet";
import { SourceDetailSheet } from "./SourceDetailSheet";
/**
 * Inline numbered citation badge (1-based). Clicking it opens a
 * SourceDetailSheet with the cited chunk's metadata and text.
 *
 * NOTE(review): `node` is untyped (`any`) — presumably a citation node
 * from @llamaindex/chat-ui with id/metadata/text/url; confirm and type it.
 * If `node.id` is missing, Number(undefined) yields NaN for chunkId.
 */
export const CitationDisplay: React.FC<{ index: number; node: any }> = ({ index, node }) => {
	const chunkId = Number(node?.id);
	const sourceType = node?.metadata?.source_type;
	const [isOpen, setIsOpen] = useState(false);
	return (
		<SourceDetailSheet
			open={isOpen}
			onOpenChange={setIsOpen}
			chunkId={chunkId}
			sourceType={sourceType}
			title={node?.metadata?.title || node?.metadata?.group_name || "Source"}
			description={node?.text}
			url={node?.url}
		>
			<SheetTrigger asChild>
				{/* Superscript badge showing the 1-based citation number */}
				<span className="text-[10px] font-bold bg-slate-500 hover:bg-slate-600 text-white rounded-full w-4 h-4 inline-flex items-center justify-center align-super cursor-pointer transition-colors">
					{index + 1}
				</span>
			</SheetTrigger>
		</SourceDetailSheet>
	);
};

View file

@ -1,36 +0,0 @@
"use client";
import { getAnnotationData, type Message, useChatUI } from "@llamaindex/chat-ui";
import { SuggestedQuestions } from "@llamaindex/chat-ui/widgets";
import {
Accordion,
AccordionContent,
AccordionItem,
AccordionTrigger,
} from "@/components/ui/accordion";
/**
 * Collapsible "Further Suggested Questions" accordion under a chat message.
 *
 * Reads the FURTHER_QUESTIONS annotation from the message; renders nothing
 * unless exactly one non-empty annotation group is present. Selecting a
 * question appends it to the chat via the chat-ui context.
 */
export const ChatFurtherQuestions: React.FC<{ message: Message }> = ({ message }) => {
	const annotations: string[][] = getAnnotationData(message, "FURTHER_QUESTIONS");
	const { append, requestData } = useChatUI();
	// Only render for a single, non-empty group of suggested questions.
	if (annotations.length !== 1 || annotations[0].length === 0) {
		return null;
	}
	return (
		<Accordion type="single" collapsible className="w-full border rounded-md bg-card shadow-sm">
			<AccordionItem value="suggested-questions" className="border-0">
				<AccordionTrigger className="px-4 py-3 text-sm font-medium text-foreground transition-colors">
					Further Suggested Questions
				</AccordionTrigger>
				<AccordionContent className="px-4 pb-4 pt-0">
					<SuggestedQuestions
						questions={annotations[0]}
						append={append}
						requestData={requestData}
					/>
				</AccordionContent>
			</AccordionItem>
		</Accordion>
	);
};

View file

@ -1,851 +0,0 @@
"use client";
import { ChatInput } from "@llamaindex/chat-ui";
import { useAtom, useAtomValue } from "jotai";
import { Brain, Check, FolderOpen, Minus, Plus, PlusCircle, Zap } from "lucide-react";
import { useParams, useRouter } from "next/navigation";
import React, { Suspense, useCallback, useMemo, useState } from "react";
import { documentTypeCountsAtom } from "@/atoms/documents/document-query.atoms";
import { updateLLMPreferencesMutationAtom } from "@/atoms/llm-config/llm-config-mutation.atoms";
import {
globalLLMConfigsAtom,
llmConfigsAtom,
llmPreferencesAtom,
} from "@/atoms/llm-config/llm-config-query.atoms";
import { DocumentsDataTable } from "@/components/chat/DocumentsDataTable";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogTitle,
DialogTrigger,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import type { Document } from "@/contracts/types/document.types";
import { useSearchSourceConnectors } from "@/hooks/use-search-source-connectors";
/**
 * Dialog-based picker for scoping chat to specific documents.
 *
 * Selection state is owned by the parent (controlled): the current list is
 * passed via `selectedDocuments`, and changes are reported through
 * `onSelectionChange`. Memoized so unrelated parent re-renders are cheap.
 */
const DocumentSelector = React.memo(
	({
		onSelectionChange,
		selectedDocuments = [],
	}: {
		onSelectionChange?: (documents: Document[]) => void;
		selectedDocuments?: Document[];
	}) => {
		const { search_space_id } = useParams();
		const [isOpen, setIsOpen] = useState(false);
		const handleOpenChange = useCallback((open: boolean) => {
			setIsOpen(open);
		}, []);
		// Forward table selection changes to the parent.
		const handleSelectionChange = useCallback(
			(documents: Document[]) => {
				onSelectionChange?.(documents);
			},
			[onSelectionChange]
		);
		// "Done" just closes the dialog; selection was already propagated.
		const handleDone = useCallback(() => {
			setIsOpen(false);
		}, []);
		const selectedCount = React.useMemo(() => selectedDocuments.length, [selectedDocuments.length]);
		return (
			<Dialog open={isOpen} onOpenChange={handleOpenChange}>
				<DialogTrigger asChild>
					<Button
						variant="outline"
						size="sm"
						className="h-9 gap-2 px-3 border-dashed hover:border-solid hover:bg-accent/50 transition-all"
					>
						<FolderOpen className="h-4 w-4 text-muted-foreground" />
						<span className="text-xs font-medium">
							{selectedCount > 0 ? `Selected` : "Documents"}
						</span>
						{selectedCount > 0 && (
							<Badge variant="secondary" className="h-5 px-1.5 text-xs font-medium">
								{selectedCount}
							</Badge>
						)}
					</Button>
				</DialogTrigger>
				<DialogContent className="max-w-[95vw] md:max-w-5xl h-[90vh] md:h-[85vh] p-0 flex flex-col">
					<div className="flex flex-col h-full">
						<div className="px-4 md:px-6 py-4 border-b flex-shrink-0 bg-muted/30">
							<DialogTitle className="text-lg md:text-xl font-semibold">
								Select Documents
							</DialogTitle>
							<DialogDescription className="mt-1.5 text-sm">
								Choose specific documents to include in your research context
							</DialogDescription>
						</div>
						<div className="flex-1 min-h-0 p-4 md:p-6">
							<DocumentsDataTable
								searchSpaceId={Number(search_space_id)}
								onSelectionChange={handleSelectionChange}
								onDone={handleDone}
								initialSelectedDocuments={selectedDocuments}
							/>
						</div>
					</div>
				</DialogContent>
			</Dialog>
		);
	}
);
DocumentSelector.displayName = "DocumentSelector";
/**
 * Dialog-based picker for search sources, combining two kinds of entries:
 * indexed document types (fetched via documentTypeCountsAtom) and live,
 * non-indexable search connectors (via useSearchSourceConnectors).
 *
 * Controlled component: the selected connector-type strings come in via
 * `selectedConnectors` and every change is reported through
 * `onSelectionChange`. The trigger button previews up to three selected
 * source icons; an empty search space offers a link to add sources.
 */
const ConnectorSelector = React.memo(
	({
		onSelectionChange,
		selectedConnectors = [],
	}: {
		onSelectionChange?: (connectorTypes: string[]) => void;
		selectedConnectors?: string[];
	}) => {
		const { search_space_id } = useParams();
		const router = useRouter();
		const [isOpen, setIsOpen] = useState(false);
		// Use the documentTypeCountsAtom for fetching document types
		const [documentTypeCountsQuery] = useAtom(documentTypeCountsAtom);
		const {
			data: documentTypeCountsData,
			isLoading,
			refetch: fetchDocumentTypes, // NOTE(review): currently unused
		} = documentTypeCountsQuery;
		// Transform the { type: count } response into an array of entries.
		const documentTypes = useMemo(() => {
			if (!documentTypeCountsData) return [];
			return Object.entries(documentTypeCountsData).map(([type, count]) => ({
				type,
				count,
			}));
		}, [documentTypeCountsData]);
		const isLoaded = !!documentTypeCountsData; // NOTE(review): currently unused
		// Fetch live search connectors immediately (non-indexable)
		const {
			connectors: searchConnectors,
			isLoading: connectorsLoading,
			isLoaded: connectorsLoaded, // NOTE(review): currently unused
			fetchConnectors, // NOTE(review): currently unused
		} = useSearchSourceConnectors(false, Number(search_space_id));
		// Filter for non-indexable connectors (live search)
		const liveSearchConnectors = React.useMemo(
			() => searchConnectors.filter((connector) => !connector.is_indexable),
			[searchConnectors]
		);
		const handleOpenChange = useCallback((open: boolean) => {
			setIsOpen(open);
			// Data is already loaded on mount, no need to fetch again
		}, []);
		// Toggle a single connector type in or out of the selection.
		const handleConnectorToggle = useCallback(
			(connectorType: string) => {
				const isSelected = selectedConnectors.includes(connectorType);
				const newSelection = isSelected
					? selectedConnectors.filter((type) => type !== connectorType)
					: [...selectedConnectors, connectorType];
				onSelectionChange?.(newSelection);
			},
			[selectedConnectors, onSelectionChange]
		);
		// Select every document type and every live connector at once.
		const handleSelectAll = useCallback(() => {
			const allTypes = [
				...documentTypes.map((dt) => dt.type),
				...liveSearchConnectors.map((c) => c.connector_type),
			];
			onSelectionChange?.(allTypes);
		}, [documentTypes, liveSearchConnectors, onSelectionChange]);
		const handleClearAll = useCallback(() => {
			onSelectionChange?.([]);
		}, [onSelectionChange]);
		// Get display name for connector type, e.g. "GITHUB_CONNECTOR" -> "Github Connector".
		// NOTE(review): charAt(0) is not upper-cased, so this relies on the
		// input being an UPPER_SNAKE_CASE enum value — confirm; lowercase
		// input would not be capitalized.
		const getDisplayName = (type: string) => {
			return type
				.split("_")
				.map((word) => word.charAt(0) + word.slice(1).toLowerCase())
				.join(" ");
		};
		// Get selected document types with their counts
		const selectedDocTypes = documentTypes.filter((dt) => selectedConnectors.includes(dt.type));
		const selectedLiveConnectors = liveSearchConnectors.filter((c) =>
			selectedConnectors.includes(c.connector_type)
		);
		// Total selected count
		const totalSelectedCount = selectedDocTypes.length + selectedLiveConnectors.length;
		const totalAvailableCount = documentTypes.length + liveSearchConnectors.length;
		return (
			<Dialog open={isOpen} onOpenChange={handleOpenChange}>
				<DialogTrigger asChild>
					<Button
						variant="outline"
						size="sm"
						className="relative h-9 gap-2 px-3 border-dashed hover:border-solid hover:bg-accent/50 transition-all"
					>
						<div className="flex items-center gap-1.5">
							{totalSelectedCount > 0 ? (
								<>
									{/* Preview up to 3 selected source icons (doc types first) */}
									<div className="flex items-center -space-x-2">
										{selectedDocTypes.slice(0, 2).map((docType) => (
											<div
												key={docType.type}
												className="flex h-6 w-6 items-center justify-center rounded-full border-2 border-background bg-muted"
											>
												{getConnectorIcon(docType.type, "h-3 w-3")}
											</div>
										))}
										{selectedLiveConnectors
											.slice(0, 3 - selectedDocTypes.slice(0, 2).length)
											.map((connector) => (
												<div
													key={connector.id}
													className="flex h-6 w-6 items-center justify-center rounded-full border-2 border-background bg-muted"
												>
													{getConnectorIcon(connector.connector_type, "h-3 w-3")}
												</div>
											))}
									</div>
									<span className="text-xs font-medium">
										{totalSelectedCount} {totalSelectedCount === 1 ? "source" : "sources"}
									</span>
								</>
							) : (
								<>
									<Brain className="h-4 w-4 text-muted-foreground" />
									<span className="text-xs font-medium">Sources</span>
								</>
							)}
						</div>
					</Button>
				</DialogTrigger>
				<DialogContent className="sm:max-w-2xl max-h-[85vh] flex flex-col">
					<div className="space-y-4 flex-1 overflow-y-auto pr-2">
						<div>
							<DialogTitle className="text-xl">Select Sources</DialogTitle>
							<DialogDescription className="mt-1.5">
								Choose indexed document types and live search connectors to include in your search
							</DialogDescription>
						</div>
						{isLoading || connectorsLoading ? (
							<div className="flex justify-center py-8">
								<div className="animate-spin h-8 w-8 border-3 border-primary border-t-transparent rounded-full" />
							</div>
						) : totalAvailableCount === 0 ? (
							/* Empty search space: point the user at the add-sources page */
							<div className="flex flex-col items-center justify-center py-12 text-center">
								<div className="rounded-full bg-muted p-4 mb-4">
									<Brain className="h-8 w-8 text-muted-foreground" />
								</div>
								<h4 className="text-sm font-medium mb-1">No sources found</h4>
								<p className="text-xs text-muted-foreground max-w-xs mb-4">
									Add documents or configure search connectors for this search space
								</p>
								<Button
									onClick={() => {
										setIsOpen(false);
										router.push(`/dashboard/${search_space_id}/sources/add`);
									}}
									className="gap-2"
								>
									<PlusCircle className="h-4 w-4" />
									Add Sources
								</Button>
							</div>
						) : (
							<>
								{/* Live Search Connectors Section */}
								{liveSearchConnectors.length > 0 && (
									<div className="space-y-2">
										<div className="flex items-center gap-2 pb-2">
											<Zap className="h-4 w-4 text-primary" />
											<h3 className="text-sm font-semibold">Live Search Connectors</h3>
											<Badge variant="outline" className="text-xs">
												Real-time
											</Badge>
										</div>
										<div className="grid grid-cols-2 gap-3">
											{liveSearchConnectors.map((connector) => {
												const isSelected = selectedConnectors.includes(connector.connector_type);
												return (
													<button
														key={connector.id}
														onClick={() => handleConnectorToggle(connector.connector_type)}
														type="button"
														className={`group relative flex items-center gap-3 p-3 rounded-lg border-2 transition-all ${
															isSelected
																? "border-primary bg-primary/5 shadow-sm"
																: "border-border hover:border-primary/50 hover:bg-accent/50"
														}`}
													>
														<div
															className={`flex h-10 w-10 items-center justify-center rounded-md transition-colors ${
																isSelected ? "bg-primary/10" : "bg-muted group-hover:bg-primary/5"
															}`}
														>
															{getConnectorIcon(
																connector.connector_type,
																`h-5 w-5 ${isSelected ? "text-primary" : "text-muted-foreground group-hover:text-primary"}`
															)}
														</div>
														<div className="flex-1 text-left min-w-0">
															<div className="flex items-center gap-2">
																<p className="text-sm font-medium truncate">{connector.name}</p>
																{isSelected && (
																	<div className="flex h-5 w-5 items-center justify-center rounded-full bg-primary">
																		<Check className="h-3 w-3 text-primary-foreground" />
																	</div>
																)}
															</div>
															<p className="text-xs text-muted-foreground mt-0.5 truncate">
																{getDisplayName(connector.connector_type)}
															</p>
														</div>
													</button>
												);
											})}
										</div>
									</div>
								)}
								{/* Document Types Section */}
								{documentTypes.length > 0 && (
									<div className="space-y-2">
										<div className="flex items-center gap-2 pb-2">
											<FolderOpen className="h-4 w-4 text-primary" />
											<h3 className="text-sm font-semibold">Indexed Document Types</h3>
											<Badge variant="outline" className="text-xs">
												Stored
											</Badge>
										</div>
										<div className="grid grid-cols-2 gap-3">
											{documentTypes.map((docType) => {
												const isSelected = selectedConnectors.includes(docType.type);
												return (
													<button
														key={docType.type}
														onClick={() => handleConnectorToggle(docType.type)}
														type="button"
														className={`group relative flex items-center gap-3 p-3 rounded-lg border-2 transition-all ${
															isSelected
																? "border-primary bg-primary/5 shadow-sm"
																: "border-border hover:border-primary/50 hover:bg-accent/50"
														}`}
													>
														<div
															className={`flex h-10 w-10 items-center justify-center rounded-md transition-colors ${
																isSelected ? "bg-primary/10" : "bg-muted group-hover:bg-primary/5"
															}`}
														>
															{getConnectorIcon(
																docType.type,
																`h-5 w-5 ${isSelected ? "text-primary" : "text-muted-foreground group-hover:text-primary"}`
															)}
														</div>
														<div className="flex-1 text-left min-w-0">
															<div className="flex items-center gap-2">
																<p className="text-sm font-medium truncate">
																	{getDisplayName(docType.type)}
																</p>
																{isSelected && (
																	<div className="flex h-5 w-5 items-center justify-center rounded-full bg-primary">
																		<Check className="h-3 w-3 text-primary-foreground" />
																	</div>
																)}
															</div>
															<p className="text-xs text-muted-foreground mt-0.5">
																{docType.count} {docType.count === 1 ? "document" : "documents"}
															</p>
														</div>
													</button>
												);
											})}
										</div>
									</div>
								)}
							</>
						)}
					</div>
					{totalAvailableCount > 0 && (
						<DialogFooter className="flex flex-row justify-between items-center gap-2 pt-4 border-t">
							<Button
								variant="ghost"
								size="sm"
								onClick={handleClearAll}
								disabled={selectedConnectors.length === 0}
								className="text-xs"
							>
								Clear All
							</Button>
							<Button
								size="sm"
								onClick={handleSelectAll}
								disabled={selectedConnectors.length === totalAvailableCount}
								className="text-xs"
							>
								Select All ({totalAvailableCount})
							</Button>
						</DialogFooter>
					)}
				</DialogContent>
			</Dialog>
		);
	}
);
ConnectorSelector.displayName = "ConnectorSelector";
/**
 * Stepper control for the number of results fetched per source ("top K").
 *
 * Renders decrement/increment buttons around a numeric input, wrapped in a
 * tooltip that explains the setting. The value is fully controlled by the
 * parent through `topK` / `onTopKChange` and is kept within [1, 100].
 *
 * @param topK - Current value; defaults to 10 when the parent passes nothing.
 * @param onTopKChange - Invoked with the new value on every accepted change.
 */
const TopKSelector = React.memo(
	({ topK = 10, onTopKChange }: { topK?: number; onTopKChange?: (topK: number) => void }) => {
		// Inclusive bounds for the results-per-source value.
		const MIN_VALUE = 1;
		const MAX_VALUE = 100;

		// Step up/down by one, saturating at the bounds (buttons are also
		// disabled at the bounds, so these guards are belt-and-suspenders).
		const handleIncrement = React.useCallback(() => {
			if (topK < MAX_VALUE) {
				onTopKChange?.(topK + 1);
			}
		}, [topK, onTopKChange]);
		const handleDecrement = React.useCallback(() => {
			if (topK > MIN_VALUE) {
				onTopKChange?.(topK - 1);
			}
		}, [topK, onTopKChange]);

		// Accept direct typing, but only propagate in-range values.
		// NOTE(review): the input is controlled by `topK`, so ignoring an edit
		// (including clearing the field) keeps the previous value displayed —
		// the user cannot transiently empty the field. Confirm this UX is intended.
		const handleInputChange = React.useCallback(
			(e: React.ChangeEvent<HTMLInputElement>) => {
				const raw = e.target.value;
				// Allow empty input for editing
				if (raw === "") {
					return;
				}
				const parsed = Number.parseInt(raw, 10);
				if (!Number.isNaN(parsed) && parsed >= MIN_VALUE && parsed <= MAX_VALUE) {
					onTopKChange?.(parsed);
				}
			},
			[onTopKChange]
		);

		// On blur, normalize anything invalid: empty -> default 10,
		// non-numeric or below range -> MIN_VALUE, above range -> MAX_VALUE.
		const handleInputBlur = React.useCallback(
			(e: React.FocusEvent<HTMLInputElement>) => {
				const raw = e.target.value;
				if (raw === "") {
					// Reset to default if empty
					onTopKChange?.(10);
					return;
				}
				const parsed = Number.parseInt(raw, 10);
				if (Number.isNaN(parsed) || parsed < MIN_VALUE) {
					onTopKChange?.(MIN_VALUE);
				} else if (parsed > MAX_VALUE) {
					onTopKChange?.(MAX_VALUE);
				}
			},
			[onTopKChange]
		);
		// NOTE(review): the separator <span> in the tooltip footer below renders
		// empty — a bullet glyph may have been lost in transit; confirm intended
		// character before changing it.
		return (
			<TooltipProvider>
				<Tooltip delayDuration={200}>
					<TooltipTrigger asChild>
						<div className="flex items-center h-8 border rounded-md bg-background hover:bg-accent/50 transition-colors">
							<Button
								type="button"
								variant="ghost"
								size="icon"
								className="h-full w-7 rounded-l-md rounded-r-none hover:bg-accent border-r"
								onClick={handleDecrement}
								disabled={topK <= MIN_VALUE}
							>
								<Minus className="h-3.5 w-3.5" />
							</Button>
							<div className="flex flex-col items-center justify-center px-2 min-w-[60px]">
								<Input
									type="number"
									value={topK}
									onChange={handleInputChange}
									onBlur={handleInputBlur}
									min={MIN_VALUE}
									max={MAX_VALUE}
									className="h-5 w-full px-1 text-center text-sm font-semibold border-0 bg-transparent focus-visible:ring-0 focus-visible:ring-offset-0 [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
								/>
								<span className="text-[10px] text-muted-foreground leading-none">Results</span>
							</div>
							<Button
								type="button"
								variant="ghost"
								size="icon"
								className="h-full w-7 rounded-r-md rounded-l-none hover:bg-accent border-l"
								onClick={handleIncrement}
								disabled={topK >= MAX_VALUE}
							>
								<Plus className="h-3.5 w-3.5" />
							</Button>
						</div>
					</TooltipTrigger>
					<TooltipContent side="top" className="max-w-xs">
						<div className="space-y-2">
							<p className="text-sm font-semibold">Results per Source</p>
							<p className="text-xs text-muted-foreground leading-relaxed">
								Control how many results to fetch from each data source. Set a higher number to get
								more information, or a lower number for faster, more focused results.
							</p>
							<div className="flex items-center gap-2 text-xs text-muted-foreground pt-1 border-t">
								<span>Recommended: 5-20</span>
								<span></span>
								<span>
									Range: {MIN_VALUE}-{MAX_VALUE}
								</span>
							</div>
						</div>
					</TooltipContent>
				</Tooltip>
			</TooltipProvider>
		);
	}
);
// Explicit display name so the memo-wrapped component reads correctly in React DevTools.
TopKSelector.displayName = "TopKSelector";
/**
 * Dropdown for choosing the "fast" LLM used in the current search space.
 *
 * Reads global and user-defined LLM configurations plus per-space preferences
 * from jotai atoms, and persists a new `fast_llm_id` via the update mutation
 * atom on selection. Renders a loading skeleton while any atom is fetching and
 * a disabled error button when either config fetch fails.
 *
 * NOTE(review): the shapes of the atom payloads (`id`, `name`, `provider`,
 * `model_name`, `fast_llm_id`) are inferred from usage here — confirm against
 * the atom definitions. The separator <span> in the trigger renders empty; a
 * bullet glyph may have been lost in transit.
 */
const LLMSelector = React.memo(() => {
	// Search-space id comes from the route; coerced to a number for the API payload.
	const { search_space_id } = useParams();
	const searchSpaceId = Number(search_space_id);
	const {
		data: llmConfigs = [],
		isFetching: llmLoading,
		isError: error,
	} = useAtomValue(llmConfigsAtom);
	const {
		data: globalConfigs = [],
		isFetching: globalConfigsLoading,
		isError: globalConfigsError,
	} = useAtomValue(globalLLMConfigsAtom);
	// Replace useLLMPreferences with jotai atoms
	const { data: preferences = {}, isFetching: preferencesLoading } =
		useAtomValue(llmPreferencesAtom);
	const { mutateAsync: updatePreferences } = useAtomValue(updateLLMPreferencesMutationAtom);
	// Any pending fetch shows the skeleton; error state is checked after loading.
	const isLoading = llmLoading || preferencesLoading || globalConfigsLoading;
	// Combine global and custom configs
	// (globals are tagged with `is_global: true` so the trigger can badge them).
	const allConfigs = React.useMemo(() => {
		return [...globalConfigs.map((config) => ({ ...config, is_global: true })), ...llmConfigs];
	}, [globalConfigs, llmConfigs]);
	// Memoize the selected config to avoid repeated lookups
	const selectedConfig = React.useMemo(() => {
		if (!preferences.fast_llm_id || !allConfigs.length) return null;
		return allConfigs.find((config) => config.id === preferences.fast_llm_id) || null;
	}, [preferences.fast_llm_id, allConfigs]);
	// Memoize the display value for the trigger
	const displayValue = React.useMemo(() => {
		if (!selectedConfig) return null;
		return (
			<div className="flex items-center gap-1">
				<span className="font-medium text-xs">{selectedConfig.provider}</span>
				<span className="text-muted-foreground"></span>
				<span className="hidden sm:inline text-muted-foreground text-xs truncate max-w-[60px]">
					{selectedConfig.name}
				</span>
				{"is_global" in selectedConfig && selectedConfig.is_global && (
					<span className="text-xs">🌐</span>
				)}
			</div>
		);
	}, [selectedConfig]);
	// Persist the chosen config id; an empty value clears the preference.
	const handleValueChange = React.useCallback(
		(value: string) => {
			const llmId = value ? parseInt(value, 10) : undefined;
			updatePreferences({
				search_space_id: searchSpaceId,
				data: { fast_llm_id: llmId },
			});
		},
		[updatePreferences, searchSpaceId]
	);
	// Loading skeleton
	if (isLoading) {
		return (
			<div className="h-8 min-w-[100px] sm:min-w-[120px]">
				<div className="h-8 rounded-md bg-muted animate-pulse flex items-center px-3">
					<div className="w-3 h-3 rounded bg-muted-foreground/20 mr-2" />
					<div className="h-3 w-16 rounded bg-muted-foreground/20" />
				</div>
			</div>
		);
	}
	// Error state
	if (error || globalConfigsError) {
		return (
			<div className="h-8 min-w-[100px] sm:min-w-[120px]">
				<Button
					variant="outline"
					size="sm"
					className="h-8 px-3 text-xs text-destructive border-destructive/50 hover:bg-destructive/10"
					disabled
				>
					<span className="text-xs">Error</span>
				</Button>
			</div>
		);
	}
	return (
		<div className="h-8 min-w-0">
			<Select
				value={preferences.fast_llm_id?.toString() || ""}
				onValueChange={handleValueChange}
				disabled={isLoading}
			>
				<SelectTrigger className="h-8 w-auto min-w-[100px] sm:min-w-[120px] px-3 text-xs border-border bg-background hover:bg-muted/50 transition-colors duration-200 focus:ring-2 focus:ring-primary/20">
					<div className="flex items-center gap-2 min-w-0">
						<Zap className="h-3 w-3 text-primary flex-shrink-0" />
						<SelectValue placeholder="Fast LLM" className="text-xs">
							{displayValue || <span className="text-muted-foreground">Select LLM</span>}
						</SelectValue>
					</div>
				</SelectTrigger>
				<SelectContent align="end" className="w-[300px] max-h-[400px]">
					<div className="px-3 py-2 text-xs font-medium text-muted-foreground border-b bg-muted/30">
						<div className="flex items-center gap-2">
							<Zap className="h-3 w-3" />
							Fast LLM Selection
						</div>
					</div>
					{allConfigs.length === 0 ? (
						<div className="px-4 py-6 text-center">
							<div className="mx-auto w-12 h-12 rounded-full bg-muted flex items-center justify-center mb-3">
								<Brain className="h-5 w-5 text-muted-foreground" />
							</div>
							<h4 className="text-sm font-medium mb-1">No LLM configurations</h4>
							<p className="text-xs text-muted-foreground mb-3">
								Configure AI models to get started
							</p>
							<Button
								variant="outline"
								size="sm"
								className="text-xs"
								onClick={() => window.open("/settings", "_blank")}
							>
								Open Settings
							</Button>
						</div>
					) : (
						<div className="py-1">
							{/* Global Configurations */}
							{globalConfigs.length > 0 && (
								<>
									<div className="px-3 py-1.5 text-xs font-semibold text-muted-foreground">
										Global Configurations
									</div>
									{globalConfigs.map((config) => (
										<SelectItem
											key={config.id}
											value={config.id.toString()}
											className="px-3 py-2 cursor-pointer hover:bg-accent/50 focus:bg-accent"
										>
											<div className="flex items-center justify-between w-full min-w-0">
												<div className="flex items-center gap-3 min-w-0 flex-1">
													<div className="flex h-8 w-8 items-center justify-center rounded-md bg-primary/10 flex-shrink-0">
														<Brain className="h-4 w-4 text-primary" />
													</div>
													<div className="min-w-0 flex-1">
														<div className="flex items-center gap-2 mb-1 flex-wrap">
															<span className="font-medium text-sm truncate">{config.name}</span>
															<Badge
																variant="outline"
																className="text-xs px-1.5 py-0.5 flex-shrink-0"
															>
																{config.provider}
															</Badge>
															<Badge
																variant="secondary"
																className="text-xs px-1.5 py-0.5 flex-shrink-0"
															>
																🌐 Global
															</Badge>
														</div>
														<p className="text-xs text-muted-foreground font-mono truncate">
															{config.model_name}
														</p>
													</div>
												</div>
											</div>
										</SelectItem>
									))}
								</>
							)}
							{/* Custom Configurations */}
							{llmConfigs.length > 0 && (
								<>
									<div className="px-3 py-1.5 text-xs font-semibold text-muted-foreground">
										Your Configurations
									</div>
									{llmConfigs.map((config) => (
										<SelectItem
											key={config.id}
											value={config.id.toString()}
											className="px-3 py-2 cursor-pointer hover:bg-accent/50 focus:bg-accent"
										>
											<div className="flex items-center justify-between w-full min-w-0">
												<div className="flex items-center gap-3 min-w-0 flex-1">
													<div className="flex h-8 w-8 items-center justify-center rounded-md bg-primary/10 flex-shrink-0">
														<Brain className="h-4 w-4 text-primary" />
													</div>
													<div className="min-w-0 flex-1">
														<div className="flex items-center gap-2 mb-1">
															<span className="font-medium text-sm truncate">{config.name}</span>
															<Badge
																variant="outline"
																className="text-xs px-1.5 py-0.5 flex-shrink-0"
															>
																{config.provider}
															</Badge>
														</div>
														<p className="text-xs text-muted-foreground font-mono truncate">
															{config.model_name}
														</p>
													</div>
												</div>
											</div>
										</SelectItem>
									))}
								</>
							)}
						</div>
					)}
				</SelectContent>
			</Select>
		</div>
	);
});
// Explicit display name so the memo-wrapped component reads correctly in React DevTools.
LLMSelector.displayName = "LLMSelector";
const CustomChatInputOptions = React.memo(
({
onDocumentSelectionChange,
selectedDocuments,
onConnectorSelectionChange,
selectedConnectors,
topK,
onTopKChange,
}: {
onDocumentSelectionChange?: (documents: Document[]) => void;
selectedDocuments?: Document[];
onConnectorSelectionChange?: (connectorTypes: string[]) => void;
selectedConnectors?: string[];
topK?: number;
onTopKChange?: (topK: number) => void;
}) => {
// Memoize the loading fallback to prevent recreation
const loadingFallback = React.useMemo(
() => <div className="h-9 w-24 animate-pulse bg-muted/50 rounded-md" />,
[]
);
return (
<div className="flex flex-wrap gap-2 items-center">
<div className="flex items-center gap-2">
<Suspense fallback={loadingFallback}>
<DocumentSelector
onSelectionChange={onDocumentSelectionChange}
selectedDocuments={selectedDocuments}
/>
</Suspense>
<Suspense fallback={loadingFallback}>
<ConnectorSelector
onSelectionChange={onConnectorSelectionChange}
selectedConnectors={selectedConnectors}
/>
</Suspense>
</div>
<div className="h-4 w-px bg-border hidden sm:block" />
<TopKSelector topK={topK} onTopKChange={onTopKChange} />
<div className="h-4 w-px bg-border hidden sm:block" />
<LLMSelector />
</div>
);
}
);
CustomChatInputOptions.displayName = "CustomChatInputOptions";
export const ChatInputUI = React.memo(
({
onDocumentSelectionChange,
selectedDocuments,
onConnectorSelectionChange,
selectedConnectors,
topK,
onTopKChange,
}: {
onDocumentSelectionChange?: (documents: Document[]) => void;
selectedDocuments?: Document[];
onConnectorSelectionChange?: (connectorTypes: string[]) => void;
selectedConnectors?: string[];
topK?: number;
onTopKChange?: (topK: number) => void;
}) => {
return (
<ChatInput className="p-2">
<ChatInput.Form className="flex gap-2">
<ChatInput.Field className="flex-1" />
<ChatInput.Submit />
</ChatInput.Form>
<CustomChatInputOptions
onDocumentSelectionChange={onDocumentSelectionChange}
selectedDocuments={selectedDocuments}
onConnectorSelectionChange={onConnectorSelectionChange}
selectedConnectors={selectedConnectors}
topK={topK}
onTopKChange={onTopKChange}
/>
</ChatInput>
);
}
);
ChatInputUI.displayName = "ChatInputUI";

Some files were not shown because too many files have changed in this diff Show more