diff --git a/surfsense_backend/alembic/versions/25_migrate_llm_configs_to_search_spaces.py b/surfsense_backend/alembic/versions/25_migrate_llm_configs_to_search_spaces.py new file mode 100644 index 000000000..116a3c687 --- /dev/null +++ b/surfsense_backend/alembic/versions/25_migrate_llm_configs_to_search_spaces.py @@ -0,0 +1,352 @@ +"""Migrate LLM configs to search spaces and add user preferences + +Revision ID: 25 +Revises: 24 +Create Date: 2025-01-10 14:00:00.000000 + +Changes: +1. Migrate llm_configs from user association to search_space association +2. Create user_search_space_preferences table for per-user LLM preferences +3. Migrate existing user LLM preferences to user_search_space_preferences +4. Remove LLM preference columns from user table +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "25" +down_revision: str | None = "24" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """ + Upgrade schema to support collaborative search spaces with per-user preferences. + + Migration steps: + 1. Add search_space_id to llm_configs + 2. Migrate existing llm_configs to first search space of their user + 3. Replace user_id with search_space_id in llm_configs + 4. Create user_search_space_preferences table + 5. Migrate user LLM preferences to user_search_space_preferences + 6. Remove LLM preference columns from user table + """ + + from sqlalchemy import inspect + + conn = op.get_bind() + inspector = inspect(conn) + + # Get existing columns + llm_config_columns = [col["name"] for col in inspector.get_columns("llm_configs")] + user_columns = [col["name"] for col in inspector.get_columns("user")] + + # ===== STEP 1: Add search_space_id to llm_configs ===== + if "search_space_id" not in llm_config_columns: + op.add_column( + "llm_configs", + sa.Column("search_space_id", sa.Integer(), nullable=True), + ) + + # ===== STEP 2: Populate search_space_id with user's first search space ===== + # This ensures existing LLM configs are assigned to a valid search space + op.execute( + """ + UPDATE llm_configs lc + SET search_space_id = ( + SELECT id + FROM searchspaces ss + WHERE ss.user_id = lc.user_id + ORDER BY ss.created_at ASC + LIMIT 1 + ) + WHERE search_space_id IS NULL AND user_id IS NOT NULL + """ + ) + + # ===== STEP 3: Make search_space_id NOT NULL and add FK constraint ===== + op.alter_column( + "llm_configs", + "search_space_id", + nullable=False, + ) + + # Add foreign key constraint + foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("llm_configs")] + if "fk_llm_configs_search_space_id" not in foreign_keys: + op.create_foreign_key( + "fk_llm_configs_search_space_id", + "llm_configs", + "searchspaces", + ["search_space_id"], + ["id"], + ondelete="CASCADE", + ) + + # Drop old user_id foreign key if it exists + if "fk_llm_configs_user_id_user" in foreign_keys: + op.drop_constraint( + "fk_llm_configs_user_id_user", + "llm_configs", + type_="foreignkey", + ) + + # Remove user_id column + if "user_id" in llm_config_columns: + op.drop_column("llm_configs", "user_id") + + # ===== STEP 4: Create user_search_space_preferences table ===== + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'user_search_space_preferences' + ) THEN + CREATE TABLE user_search_space_preferences ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE, + search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE, + long_context_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL, + fast_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL, + strategic_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL, + CONSTRAINT uq_user_searchspace UNIQUE (user_id, search_space_id) + ); + END IF; + END$$; + """ + ) + + # Create indexes + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_indexes + WHERE tablename = 'user_search_space_preferences' + AND indexname = 'ix_user_search_space_preferences_id' + ) THEN + CREATE INDEX ix_user_search_space_preferences_id + ON user_search_space_preferences(id); + END IF; + + IF NOT EXISTS ( + SELECT 1 FROM pg_indexes + WHERE tablename = 'user_search_space_preferences' + AND indexname = 'ix_user_search_space_preferences_created_at' + ) THEN + CREATE INDEX ix_user_search_space_preferences_created_at + ON user_search_space_preferences(created_at); + END IF; + END$$; + """ + ) + + # ===== STEP 5: Migrate user LLM preferences to user_search_space_preferences ===== + # For each user, create preferences for each of their search spaces + if all( + col in user_columns + for col in ["long_context_llm_id", "fast_llm_id", "strategic_llm_id"] + ): + op.execute( + """ + INSERT INTO user_search_space_preferences + (user_id, search_space_id, long_context_llm_id, fast_llm_id, strategic_llm_id, created_at) + SELECT + u.id as user_id, + ss.id as search_space_id, + u.long_context_llm_id, + u.fast_llm_id, + u.strategic_llm_id, + NOW() as created_at + FROM "user" u + CROSS JOIN searchspaces ss + WHERE ss.user_id = u.id + ON CONFLICT (user_id, search_space_id) DO NOTHING + """ + ) + + # ===== STEP 6: Remove LLM preference columns from user table ===== + # Get fresh list of foreign keys after previous operations + user_foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("user")] + + # Drop foreign key constraints if they exist + if "fk_user_long_context_llm_id_llm_configs" in user_foreign_keys: + op.drop_constraint( + "fk_user_long_context_llm_id_llm_configs", + "user", + type_="foreignkey", + ) + + if "fk_user_fast_llm_id_llm_configs" in user_foreign_keys: + op.drop_constraint( + "fk_user_fast_llm_id_llm_configs", + "user", + type_="foreignkey", + ) + + if "fk_user_strategic_llm_id_llm_configs" in user_foreign_keys: + op.drop_constraint( + "fk_user_strategic_llm_id_llm_configs", + "user", + type_="foreignkey", + ) + + # Drop columns from user table + if "long_context_llm_id" in user_columns: + op.drop_column("user", "long_context_llm_id") + + if "fast_llm_id" in user_columns: + op.drop_column("user", "fast_llm_id") + + if "strategic_llm_id" in user_columns: + op.drop_column("user", "strategic_llm_id") + + +def downgrade() -> None: + """ + Downgrade schema back to user-owned LLM configs. + + WARNING: This downgrade will result in data loss: + - LLM configs will be moved back to user ownership (first occurrence kept) + - Per-search-space user preferences will be consolidated to user level + - Additional LLM configs in search spaces beyond the first will be deleted + """ + + from sqlalchemy import inspect + + conn = op.get_bind() + inspector = inspect(conn) + + # Get existing columns and constraints + llm_config_columns = [col["name"] for col in inspector.get_columns("llm_configs")] + user_columns = [col["name"] for col in inspector.get_columns("user")] + + # ===== STEP 1: Add LLM preference columns back to user table ===== + if "long_context_llm_id" not in user_columns: + op.add_column( + "user", + sa.Column("long_context_llm_id", sa.Integer(), nullable=True), + ) + + if "fast_llm_id" not in user_columns: + op.add_column( + "user", + sa.Column("fast_llm_id", sa.Integer(), nullable=True), + ) + + if "strategic_llm_id" not in user_columns: + op.add_column( + "user", + sa.Column("strategic_llm_id", sa.Integer(), nullable=True), + ) + + # ===== STEP 2: Migrate preferences back to user table ===== + # Take the first preference for each user + op.execute( + """ + UPDATE "user" u + SET + long_context_llm_id = ussp.long_context_llm_id, + fast_llm_id = ussp.fast_llm_id, + strategic_llm_id = ussp.strategic_llm_id + FROM ( + SELECT DISTINCT ON (user_id) + user_id, + long_context_llm_id, + fast_llm_id, + strategic_llm_id + FROM user_search_space_preferences + ORDER BY user_id, created_at ASC + ) ussp + WHERE u.id = ussp.user_id + """ + ) + + # ===== STEP 3: Add foreign key constraints back to user table ===== + op.create_foreign_key( + "fk_user_long_context_llm_id_llm_configs", + "user", + "llm_configs", + ["long_context_llm_id"], + ["id"], + ondelete="SET NULL", + ) + + op.create_foreign_key( + "fk_user_fast_llm_id_llm_configs", + "user", + "llm_configs", + ["fast_llm_id"], + ["id"], + ondelete="SET NULL", + ) + + op.create_foreign_key( + "fk_user_strategic_llm_id_llm_configs", + "user", + "llm_configs", + ["strategic_llm_id"], + ["id"], + ondelete="SET NULL", + ) + + # ===== STEP 4: Drop user_search_space_preferences table ===== + op.execute("DROP TABLE IF EXISTS user_search_space_preferences CASCADE") + + # ===== STEP 5: Add user_id back to llm_configs ===== + if "user_id" not in llm_config_columns: + op.add_column( + "llm_configs", + sa.Column("user_id", postgresql.UUID(), nullable=True), + ) + + # Populate user_id from search_space + op.execute( + """ + UPDATE llm_configs lc + SET user_id = ss.user_id + FROM searchspaces ss + WHERE lc.search_space_id = ss.id + """ + ) + + # Make user_id NOT NULL + op.alter_column( + "llm_configs", + "user_id", + nullable=False, + ) + + # Add foreign key constraint for user_id + op.create_foreign_key( + "fk_llm_configs_user_id_user", + "llm_configs", + "user", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + + # ===== STEP 6: Remove search_space_id from llm_configs ===== + # Drop foreign key constraint + foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("llm_configs")] + if "fk_llm_configs_search_space_id" in foreign_keys: + op.drop_constraint( + "fk_llm_configs_search_space_id", + "llm_configs", + type_="foreignkey", + ) + + # Drop search_space_id column + if "search_space_id" in llm_config_columns: + op.drop_column("llm_configs", "search_space_id") diff --git a/surfsense_backend/app/agents/podcaster/configuration.py b/surfsense_backend/app/agents/podcaster/configuration.py index c4c5f9e9c..453f12676 100644 --- a/surfsense_backend/app/agents/podcaster/configuration.py +++ b/surfsense_backend/app/agents/podcaster/configuration.py @@ -17,6 +17,7 @@ class Configuration: # and when you invoke the graph podcast_title: str user_id: str + search_space_id: int @classmethod def from_runnable_config( diff --git a/surfsense_backend/app/agents/podcaster/nodes.py b/surfsense_backend/app/agents/podcaster/nodes.py index 891928d90..bce9882d6 100644 --- a/surfsense_backend/app/agents/podcaster/nodes.py +++ b/surfsense_backend/app/agents/podcaster/nodes.py @@ -28,11 +28,12 @@ async def create_podcast_transcript( # Get configuration from runnable config configuration = Configuration.from_runnable_config(config) user_id = configuration.user_id + search_space_id = configuration.search_space_id # Get user's long context LLM - llm = await get_user_long_context_llm(state.db_session, user_id) + llm = await get_user_long_context_llm(state.db_session, user_id, search_space_id) if not llm: - error_message = f"No long context LLM configured for user {user_id}" + error_message = f"No long context LLM configured for user {user_id} in search space {search_space_id}" print(error_message) raise RuntimeError(error_message) diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py index 5cd314d6d..0835fb861 100644 --- a/surfsense_backend/app/agents/researcher/nodes.py +++ b/surfsense_backend/app/agents/researcher/nodes.py @@ -577,6 +577,7 @@ async def write_answer_outline( user_query = configuration.user_query num_sections = configuration.num_sections user_id = configuration.user_id + search_space_id = configuration.search_space_id writer( { @@ -587,9 +588,9 @@ async def write_answer_outline( ) # Get user's strategic LLM - llm = await get_user_strategic_llm(state.db_session, user_id) + llm = await get_user_strategic_llm(state.db_session, user_id, search_space_id) if not llm: - error_message = f"No strategic LLM configured for user {user_id}" + error_message = f"No strategic LLM configured for user {user_id} in search space {search_space_id}" writer({"yield_value": streaming_service.format_error(error_message)}) raise RuntimeError(error_message) @@ -1854,6 +1855,7 @@ async def reformulate_user_query( user_query=user_query, session=state.db_session, user_id=configuration.user_id, + search_space_id=configuration.search_space_id, chat_history_str=chat_history_str, ) @@ -2093,6 +2095,7 @@ async def generate_further_questions( configuration = Configuration.from_runnable_config(config) chat_history = state.chat_history user_id = configuration.user_id + search_space_id = configuration.search_space_id streaming_service = state.streaming_service # Get reranked documents from the state (will be populated by sub-agents) @@ -2107,9 +2110,9 @@ async def generate_further_questions( ) # Get user's fast LLM - llm = await get_user_fast_llm(state.db_session, user_id) + llm = await get_user_fast_llm(state.db_session, user_id, search_space_id) if not llm: - error_message = f"No fast LLM configured for user {user_id}" + error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}" print(error_message) writer({"yield_value": streaming_service.format_error(error_message)}) diff --git a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py index 4e01bbb58..fd6861efb 100644 --- a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py +++ b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py @@ -101,11 +101,12 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any documents = state.reranked_documents user_query = configuration.user_query user_id = configuration.user_id + search_space_id = configuration.search_space_id # Get user's fast LLM - llm = await get_user_fast_llm(state.db_session, user_id) + llm = await get_user_fast_llm(state.db_session, user_id, search_space_id) if not llm: - error_message = f"No fast LLM configured for user {user_id}" + error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}" print(error_message) raise RuntimeError(error_message) diff --git a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py index 91a8bf84e..153cafac5 100644 --- a/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py +++ b/surfsense_backend/app/agents/researcher/sub_section_writer/nodes.py @@ -107,11 +107,12 @@ async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, A configuration = Configuration.from_runnable_config(config) documents = state.reranked_documents user_id = configuration.user_id + search_space_id = configuration.search_space_id # Get user's fast LLM - llm = await get_user_fast_llm(state.db_session, user_id) + llm = await get_user_fast_llm(state.db_session, user_id, search_space_id) if not llm: - error_message = f"No fast LLM configured for user {user_id}" + error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}" print(error_message) raise RuntimeError(error_message) diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 7af9c0661..eb33145cf 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -240,6 +240,17 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="SearchSourceConnector.id", cascade="all, delete-orphan", ) + llm_configs = relationship( + "LLMConfig", + back_populates="search_space", + order_by="LLMConfig.id", + cascade="all, delete-orphan", + ) + user_preferences = relationship( + "UserSearchSpacePreference", + back_populates="search_space", + cascade="all, delete-orphan", + ) class SearchSourceConnector(BaseModel, TimestampMixin): @@ -288,10 +299,54 @@ class LLMConfig(BaseModel, TimestampMixin): # For any other parameters that litellm supports litellm_params = Column(JSON, nullable=True, default={}) + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + search_space = relationship("SearchSpace", back_populates="llm_configs") + + +class UserSearchSpacePreference(BaseModel, TimestampMixin): + __tablename__ = "user_search_space_preferences" + __table_args__ = ( + UniqueConstraint( + "user_id", + "search_space_id", + name="uq_user_searchspace", + ), + ) + user_id = Column( UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) - user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id]) + search_space_id = Column( + Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False + ) + + # User-specific LLM preferences for this search space + long_context_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) + fast_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) + strategic_llm_id = Column( + Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True + ) + + # Future RBAC fields can be added here + # role = Column(String(50), nullable=True) # e.g., 'owner', 'editor', 'viewer' + # permissions = Column(JSON, nullable=True) + + user = relationship("User", back_populates="search_space_preferences") + search_space = relationship("SearchSpace", back_populates="user_preferences") + + long_context_llm = relationship( + "LLMConfig", foreign_keys=[long_context_llm_id], post_update=True + ) + fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True) + strategic_llm = relationship( + "LLMConfig", foreign_keys=[strategic_llm_id], post_update=True + ) class Log(BaseModel, TimestampMixin): @@ -321,64 +376,22 @@ if config.AUTH_TYPE == "GOOGLE": "OAuthAccount", lazy="joined" ) search_spaces = relationship("SearchSpace", back_populates="user") - llm_configs = relationship( - "LLMConfig", + search_space_preferences = relationship( + "UserSearchSpacePreference", back_populates="user", - foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan", ) - long_context_llm_id = Column( - Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True - ) - fast_llm_id = Column( - Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True - ) - strategic_llm_id = Column( - Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True - ) - - long_context_llm = relationship( - "LLMConfig", foreign_keys=[long_context_llm_id], post_update=True - ) - fast_llm = relationship( - "LLMConfig", foreign_keys=[fast_llm_id], post_update=True - ) - strategic_llm = relationship( - "LLMConfig", foreign_keys=[strategic_llm_id], post_update=True - ) - else: class User(SQLAlchemyBaseUserTableUUID, Base): search_spaces = relationship("SearchSpace", back_populates="user") - llm_configs = relationship( - "LLMConfig", + search_space_preferences = relationship( + "UserSearchSpacePreference", back_populates="user", - foreign_keys="LLMConfig.user_id", cascade="all, delete-orphan", ) - long_context_llm_id = Column( - Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True - ) - fast_llm_id = Column( - Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True - ) - strategic_llm_id = Column( - Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True - ) - - long_context_llm = relationship( - "LLMConfig", foreign_keys=[long_context_llm_id], post_update=True - ) - fast_llm = relationship( - "LLMConfig", foreign_keys=[fast_llm_id], post_update=True - ) - strategic_llm = relationship( - "LLMConfig", foreign_keys=[strategic_llm_id], post_update=True - ) - engine = create_async_engine(DATABASE_URL) async_session_maker = async_sessionmaker(engine, expire_on_commit=False) diff --git a/surfsense_backend/app/routes/chats_routes.py b/surfsense_backend/app/routes/chats_routes.py index 1dcda505c..e4d02686f 100644 --- a/surfsense_backend/app/routes/chats_routes.py +++ b/surfsense_backend/app/routes/chats_routes.py @@ -17,19 +17,17 @@ from app.tasks.stream_connector_search_results import stream_connector_search_re from app.users import current_active_user from app.utils.check_ownership import check_ownership from app.utils.validators import ( - validate_search_space_id, - validate_document_ids, validate_connectors, + validate_document_ids, + validate_messages, validate_research_mode, validate_search_mode, - validate_messages, + validate_search_space_id, ) router = APIRouter() - - @router.post("/chat") async def handle_chat_data( request: AISDKChatRequest, @@ -38,20 +36,22 @@ async def handle_chat_data( ): # Validate and sanitize all input data messages = validate_messages(request.messages) - + if messages[-1]["role"] != "user": raise HTTPException( status_code=400, detail="Last message must be a user message" ) user_query = messages[-1]["content"] - + # Extract and validate data from request request_data = request.data or {} search_space_id = validate_search_space_id(request_data.get("search_space_id")) research_mode = validate_research_mode(request_data.get("research_mode")) selected_connectors = validate_connectors(request_data.get("selected_connectors")) - document_ids_to_add_in_context = validate_document_ids(request_data.get("document_ids_to_add_in_context")) + document_ids_to_add_in_context = validate_document_ids( + request_data.get("document_ids_to_add_in_context") + ) search_mode_str = validate_search_mode(request_data.get("search_mode")) # Check if the search space belongs to the current user @@ -132,21 +132,16 @@ async def read_chats( # Validate pagination parameters if skip < 0: raise HTTPException( - status_code=400, - detail="skip must be a non-negative integer" + status_code=400, detail="skip must be a non-negative integer" ) - + if limit <= 0 or limit > 1000: # Reasonable upper limit - raise HTTPException( - status_code=400, - detail="limit must be between 1 and 1000" - ) - + raise HTTPException(status_code=400, detail="limit must be between 1 and 1000") + # Validate search_space_id if provided if search_space_id is not None and search_space_id <= 0: raise HTTPException( - status_code=400, - detail="search_space_id must be a positive integer" + status_code=400, detail="search_space_id must be a positive integer" ) try: # Select specific fields excluding messages diff --git a/surfsense_backend/app/routes/llm_config_routes.py b/surfsense_backend/app/routes/llm_config_routes.py index ce76dc9bc..63d540d2c 100644 --- a/surfsense_backend/app/routes/llm_config_routes.py +++ b/surfsense_backend/app/routes/llm_config_routes.py @@ -2,15 +2,72 @@ from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select +from sqlalchemy.orm import selectinload -from app.db import LLMConfig, User, get_async_session +from app.db import ( + LLMConfig, + SearchSpace, + User, + UserSearchSpacePreference, + get_async_session, +) from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate from app.users import current_active_user -from app.utils.check_ownership import check_ownership router = APIRouter() +# Helper function to check search space access +async def check_search_space_access( + session: AsyncSession, search_space_id: int, user: User +) -> SearchSpace: + """Verify that the user has access to the search space""" + result = await session.execute( + select(SearchSpace).filter( + SearchSpace.id == search_space_id, SearchSpace.user_id == user.id + ) + ) + search_space = result.scalars().first() + if not search_space: + raise HTTPException( + status_code=404, + detail="Search space not found or you don't have permission to access it", + ) + return search_space + + +# Helper function to get or create user search space preference +async def get_or_create_user_preference( + session: AsyncSession, user_id, search_space_id: int +) -> UserSearchSpacePreference: + """Get or create user preference for a search space""" + result = await session.execute( + select(UserSearchSpacePreference) + .filter( + UserSearchSpacePreference.user_id == user_id, + UserSearchSpacePreference.search_space_id == search_space_id, + ) + .options( + selectinload(UserSearchSpacePreference.long_context_llm), + selectinload(UserSearchSpacePreference.fast_llm), + selectinload(UserSearchSpacePreference.strategic_llm), + ) + ) + preference = result.scalars().first() + + if not preference: + # Create new preference entry + preference = UserSearchSpacePreference( + user_id=user_id, + search_space_id=search_space_id, + ) + session.add(preference) + await session.commit() + await session.refresh(preference) + + return preference + + class LLMPreferencesUpdate(BaseModel): """Schema for updating user LLM preferences""" @@ -36,9 +93,12 @@ async def create_llm_config( session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Create a new LLM configuration for the authenticated user""" + """Create a new LLM configuration for a search space""" try: - db_llm_config = LLMConfig(**llm_config.model_dump(), user_id=user.id) + # Verify user has access to the search space + await check_search_space_access(session, llm_config.search_space_id, user) + + db_llm_config = LLMConfig(**llm_config.model_dump()) session.add(db_llm_config) await session.commit() await session.refresh(db_llm_config) @@ -54,20 +114,26 @@ async def create_llm_config( @router.get("/llm-configs/", response_model=list[LLMConfigRead]) async def read_llm_configs( + search_space_id: int, skip: int = 0, limit: int = 200, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get all LLM configurations for the authenticated user""" + """Get all LLM configurations for a search space""" try: + # Verify user has access to the search space + await check_search_space_access(session, search_space_id, user) + result = await session.execute( select(LLMConfig) - .filter(LLMConfig.user_id == user.id) + .filter(LLMConfig.search_space_id == search_space_id) .offset(skip) .limit(limit) ) return result.scalars().all() + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}" @@ -82,7 +148,18 @@ async def read_llm_config( ): """Get a specific LLM configuration by ID""" try: - llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) + # Get the LLM config + result = await session.execute( + select(LLMConfig).filter(LLMConfig.id == llm_config_id) + ) + llm_config = result.scalars().first() + + if not llm_config: + raise HTTPException(status_code=404, detail="LLM configuration not found") + + # Verify user has access to the search space + await check_search_space_access(session, llm_config.search_space_id, user) + return llm_config except HTTPException: raise @@ -101,7 +178,18 @@ async def update_llm_config( ): """Update an existing LLM configuration""" try: - db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) + # Get the LLM config + result = await session.execute( + select(LLMConfig).filter(LLMConfig.id == llm_config_id) + ) + db_llm_config = result.scalars().first() + + if not db_llm_config: + raise HTTPException(status_code=404, detail="LLM configuration not found") + + # Verify user has access to the search space + await check_search_space_access(session, db_llm_config.search_space_id, user) + update_data = llm_config_update.model_dump(exclude_unset=True) for key, value in update_data.items(): @@ -127,7 +215,18 @@ async def delete_llm_config( ): """Delete an LLM configuration""" try: - db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) + # Get the LLM config + result = await session.execute( + select(LLMConfig).filter(LLMConfig.id == llm_config_id) + ) + db_llm_config = result.scalars().first() + + if not db_llm_config: + raise HTTPException(status_code=404, detail="LLM configuration not found") + + # Verify user has access to the search space + await check_search_space_access(session, db_llm_config.search_space_id, user) + await session.delete(db_llm_config) await session.commit() return {"message": "LLM configuration deleted successfully"} @@ -143,99 +242,101 @@ async def delete_llm_config( # User LLM Preferences endpoints -@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead) +@router.get( + "/search-spaces/{search_space_id}/llm-preferences", + response_model=LLMPreferencesRead, +) async def get_user_llm_preferences( + search_space_id: int, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Get the current user's LLM preferences""" + """Get the current user's LLM preferences for a specific search space""" try: - # Refresh user to get latest relationships - await session.refresh(user) + # Verify user has access to the search space + await check_search_space_access(session, search_space_id, user) - result = { - "long_context_llm_id": user.long_context_llm_id, - "fast_llm_id": user.fast_llm_id, - "strategic_llm_id": user.strategic_llm_id, - "long_context_llm": None, - "fast_llm": None, - "strategic_llm": None, + # Get or create user preference for this search space + preference = await get_or_create_user_preference( + session, user.id, search_space_id + ) + + return { + "long_context_llm_id": preference.long_context_llm_id, + "fast_llm_id": preference.fast_llm_id, + "strategic_llm_id": preference.strategic_llm_id, + "long_context_llm": preference.long_context_llm, + "fast_llm": preference.fast_llm, + "strategic_llm": preference.strategic_llm, } - - # Fetch the actual LLM configs if they exist - if user.long_context_llm_id: - long_context_llm = await session.execute( - select(LLMConfig).filter( - LLMConfig.id == user.long_context_llm_id, - LLMConfig.user_id == user.id, - ) - ) - llm_config = long_context_llm.scalars().first() - if llm_config: - result["long_context_llm"] = llm_config - - if user.fast_llm_id: - fast_llm = await session.execute( - select(LLMConfig).filter( - LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id - ) - ) - llm_config = fast_llm.scalars().first() - if llm_config: - result["fast_llm"] = llm_config - - if user.strategic_llm_id: - strategic_llm = await session.execute( - select(LLMConfig).filter( - LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id - ) - ) - llm_config = strategic_llm.scalars().first() - if llm_config: - result["strategic_llm"] = llm_config - - return result + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}" ) from e -@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead) +@router.put( + "/search-spaces/{search_space_id}/llm-preferences", + response_model=LLMPreferencesRead, +) async def update_user_llm_preferences( + search_space_id: int, preferences: LLMPreferencesUpdate, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): - """Update the current user's LLM preferences""" + """Update the current user's LLM preferences for a specific search space""" try: - # Validate that all provided LLM config IDs belong to the user + # Verify user has access to the search space + await check_search_space_access(session, search_space_id, user) + + # Get or create user preference for this search space + preference = await get_or_create_user_preference( + session, user.id, search_space_id + ) + + # Validate that all provided LLM config IDs belong to the search space update_data = preferences.model_dump(exclude_unset=True) for _key, llm_config_id in update_data.items(): if llm_config_id is not None: - # Verify ownership of the LLM config + # Verify the LLM config belongs to the search space result = await session.execute( select(LLMConfig).filter( - LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id + LLMConfig.id == llm_config_id, + LLMConfig.search_space_id == search_space_id, ) ) llm_config = result.scalars().first() if not llm_config: raise HTTPException( status_code=404, - detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it", + detail=f"LLM configuration {llm_config_id} not found in this search space", ) # Update user preferences for key, value in update_data.items(): - setattr(user, key, value) + setattr(preference, key, value) await session.commit() - await session.refresh(user) + await session.refresh(preference) + + # Reload relationships + await session.refresh( + preference, ["long_context_llm", "fast_llm", "strategic_llm"] + ) # Return updated preferences - return await get_user_llm_preferences(session, user) + return { + "long_context_llm_id": preference.long_context_llm_id, + "fast_llm_id": preference.fast_llm_id, + "strategic_llm_id": preference.strategic_llm_id, + "long_context_llm": preference.long_context_llm, + "fast_llm": preference.fast_llm, + "strategic_llm": preference.strategic_llm, + } except HTTPException: raise except Exception as e: diff --git a/surfsense_backend/app/schemas/llm_config.py b/surfsense_backend/app/schemas/llm_config.py index c3c003397..8beb65347 100644 --- a/surfsense_backend/app/schemas/llm_config.py +++ b/surfsense_backend/app/schemas/llm_config.py @@ -1,4 +1,3 @@ -import uuid from datetime import datetime from typing import Any @@ -30,7 +29,9 @@ class LLMConfigBase(BaseModel): class LLMConfigCreate(LLMConfigBase): - pass + search_space_id: int = Field( + ..., description="Search space ID to associate the LLM config with" + ) class LLMConfigUpdate(BaseModel): @@ -56,6 +57,6 @@ class LLMConfigUpdate(BaseModel): class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel): id: int created_at: datetime - user_id: uuid.UUID + search_space_id: int model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/search_space.py b/surfsense_backend/app/schemas/search_space.py index 49b5cd094..00bfdc0f6 100644 --- a/surfsense_backend/app/schemas/search_space.py +++ b/surfsense_backend/app/schemas/search_space.py @@ -24,4 +24,4 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel): created_at: datetime user_id: uuid.UUID - model_config = ConfigDict(from_attributes=True) \ No newline at end of file + model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/services/llm_service.py b/surfsense_backend/app/services/llm_service.py index 9135e49dc..d9299549c 100644 --- a/surfsense_backend/app/services/llm_service.py +++ b/surfsense_backend/app/services/llm_service.py @@ -1,10 +1,14 @@ import logging +import litellm from langchain_litellm import ChatLiteLLM from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from app.db import LLMConfig, User +from app.db import LLMConfig, UserSearchSpacePreference + +# Configure litellm to automatically drop unsupported parameters +litellm.drop_params = True logger = logging.getLogger(__name__) @@ -16,54 +20,67 @@ class LLMRole: async def get_user_llm_instance( - session: AsyncSession, user_id: str, role: str + session: AsyncSession, user_id: str, search_space_id: int, role: str ) -> ChatLiteLLM | None: """ - Get a ChatLiteLLM instance for a specific user and role. + Get a ChatLiteLLM instance for a specific user, search space, and role. Args: session: Database session user_id: User ID + search_space_id: Search Space ID role: LLM role ('long_context', 'fast', or 'strategic') Returns: ChatLiteLLM instance or None if not found """ try: - # Get user with their LLM preferences - result = await session.execute(select(User).where(User.id == user_id)) - user = result.scalars().first() + # Get user's LLM preferences for this search space + result = await session.execute( + select(UserSearchSpacePreference).where( + UserSearchSpacePreference.user_id == user_id, + UserSearchSpacePreference.search_space_id == search_space_id, + ) + ) + preference = result.scalars().first() - if not user: - logger.error(f"User {user_id} not found") + if not preference: + logger.error( + f"No LLM preferences found for user {user_id} in search space {search_space_id}" + ) return None # Get the appropriate LLM config ID based on role llm_config_id = None if role == LLMRole.LONG_CONTEXT: - llm_config_id = user.long_context_llm_id + llm_config_id = preference.long_context_llm_id elif role == LLMRole.FAST: - llm_config_id = user.fast_llm_id + llm_config_id = preference.fast_llm_id elif role == LLMRole.STRATEGIC: - llm_config_id = user.strategic_llm_id + llm_config_id = preference.strategic_llm_id else: logger.error(f"Invalid LLM role: {role}") return None if not llm_config_id: - logger.error(f"No {role} LLM configured for user {user_id}") + logger.error( + f"No {role} LLM configured for user {user_id} in search space {search_space_id}" + ) return None # Get the LLM configuration result = await session.execute( select(LLMConfig).where( - LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id + LLMConfig.id == llm_config_id, + LLMConfig.search_space_id == search_space_id, ) ) llm_config = result.scalars().first() if not llm_config: - logger.error(f"LLM config {llm_config_id} not found for user {user_id}") + logger.error( + f"LLM config {llm_config_id} not found in search space {search_space_id}" + ) return None # Build the model string for litellm @@ -113,19 +130,25 @@ async def get_user_llm_instance( async def get_user_long_context_llm( - session: AsyncSession, user_id: str + session: AsyncSession, user_id: str, search_space_id: int ) -> ChatLiteLLM | None: - """Get user's long context LLM instance.""" - return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT) + """Get user's long context LLM instance for a specific search space.""" + return await get_user_llm_instance( + session, user_id, search_space_id, LLMRole.LONG_CONTEXT + ) -async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None: - """Get user's fast LLM instance.""" - return await get_user_llm_instance(session, user_id, LLMRole.FAST) +async def get_user_fast_llm( + session: AsyncSession, user_id: str, search_space_id: int +) -> ChatLiteLLM | None: + """Get user's fast LLM instance for a specific search space.""" + return await get_user_llm_instance(session, user_id, search_space_id, LLMRole.FAST) async def get_user_strategic_llm( - session: AsyncSession, user_id: str + session: AsyncSession, user_id: str, search_space_id: int ) -> ChatLiteLLM | None: - """Get user's strategic LLM instance.""" - return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC) + """Get user's strategic LLM instance for a specific search space.""" + return await get_user_llm_instance( + session, user_id, search_space_id, LLMRole.STRATEGIC + ) diff --git a/surfsense_backend/app/services/query_service.py b/surfsense_backend/app/services/query_service.py index 4a4bc59be..d2759ab27 100644 --- a/surfsense_backend/app/services/query_service.py +++ b/surfsense_backend/app/services/query_service.py @@ -17,6 +17,7 @@ class QueryService: user_query: str, session: AsyncSession, user_id: str, + search_space_id: int, chat_history_str: str | None = None, ) -> str: """ @@ -27,6 +28,7 @@ class QueryService: user_query: The original user query session: Database session for accessing user LLM configs user_id: User ID to get their specific LLM configuration + search_space_id: Search Space ID to get user's LLM preferences chat_history_str: Optional chat history string Returns: @@ -37,10 +39,10 @@ class QueryService: try: # Get the user's strategic LLM instance - llm = await get_user_strategic_llm(session, user_id) + llm = await get_user_strategic_llm(session, user_id, search_space_id) if not llm: print( - f"Warning: No strategic LLM configured for user {user_id}. Using original query." + f"Warning: No strategic LLM configured for user {user_id} in search space {search_space_id}. Using original query." ) return user_query diff --git a/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py b/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py index 7ba2f2d44..0cc21bb47 100644 --- a/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py @@ -260,7 +260,9 @@ async def index_airtable_records( continue # Generate document summary - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) if user_llm: document_metadata = { diff --git a/surfsense_backend/app/tasks/connector_indexers/clickup_indexer.py b/surfsense_backend/app/tasks/connector_indexers/clickup_indexer.py index 3120fcbc7..5ee7342fa 100644 --- a/surfsense_backend/app/tasks/connector_indexers/clickup_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/clickup_indexer.py @@ -222,7 +222,9 @@ async def index_clickup_tasks( continue # Generate summary with metadata - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) if user_llm: document_metadata = { diff --git a/surfsense_backend/app/tasks/connector_indexers/confluence_indexer.py b/surfsense_backend/app/tasks/connector_indexers/confluence_indexer.py index 625992922..28cb3b1f4 100644 --- a/surfsense_backend/app/tasks/connector_indexers/confluence_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/confluence_indexer.py @@ -233,7 +233,9 @@ async def index_confluence_pages( continue # Generate summary with metadata - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) comment_count = len(comments) if user_llm: diff --git a/surfsense_backend/app/tasks/connector_indexers/discord_indexer.py b/surfsense_backend/app/tasks/connector_indexers/discord_indexer.py index c538f12f7..08c995f64 100644 --- a/surfsense_backend/app/tasks/connector_indexers/discord_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/discord_indexer.py @@ -325,7 +325,9 @@ async def index_discord_messages( continue # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) if not user_llm: logger.error( f"No long context LLM configured for user {user_id}" diff --git a/surfsense_backend/app/tasks/connector_indexers/github_indexer.py b/surfsense_backend/app/tasks/connector_indexers/github_indexer.py index ba01e3979..9cc0c0993 100644 --- a/surfsense_backend/app/tasks/connector_indexers/github_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/github_indexer.py @@ -213,7 +213,9 @@ async def index_github_repos( continue # Generate summary with metadata - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) if user_llm: # Extract file extension from file path file_extension = ( diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index 4d4794284..be5169612 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -266,7 +266,9 @@ async def index_google_calendar_events( continue # Generate summary with metadata - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) if user_llm: document_metadata = { diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index 38fec29a9..872e19d03 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -210,7 +210,9 @@ async def index_google_gmail_messages( continue # Generate summary with metadata - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) if user_llm: document_metadata = { diff --git a/surfsense_backend/app/tasks/connector_indexers/jira_indexer.py b/surfsense_backend/app/tasks/connector_indexers/jira_indexer.py index c199faed0..e9d556954 100644 --- a/surfsense_backend/app/tasks/connector_indexers/jira_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/jira_indexer.py @@ -216,7 +216,9 @@ async def index_jira_issues( continue # Generate summary with metadata - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) comment_count = len(formatted_issue.get("comments", [])) if user_llm: diff --git a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py index 6ca145357..aca1e2040 100644 --- a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py @@ -228,7 +228,9 @@ async def index_linear_issues( continue # Generate summary with metadata - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) state = formatted_issue.get("state", "Unknown") description = formatted_issue.get("description", "") comment_count = len(formatted_issue.get("comments", [])) diff --git a/surfsense_backend/app/tasks/connector_indexers/luma_indexer.py b/surfsense_backend/app/tasks/connector_indexers/luma_indexer.py index d7b6d3058..3d8970654 100644 --- a/surfsense_backend/app/tasks/connector_indexers/luma_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/luma_indexer.py @@ -270,7 +270,9 @@ async def index_luma_events( continue # Generate summary with metadata - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) if user_llm: document_metadata = { diff --git a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py index a6c8853a3..b290f86da 100644 --- a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py @@ -299,7 +299,9 @@ async def index_notion_pages( continue # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm( + session, user_id, search_space_id + ) if not user_llm: logger.error(f"No long context LLM configured for user {user_id}") skipped_pages.append(f"{page_title} (no LLM configured)") diff --git a/surfsense_backend/app/tasks/document_processors/extension_processor.py b/surfsense_backend/app/tasks/document_processors/extension_processor.py index 8f8433148..ed25b8fbd 100644 --- a/surfsense_backend/app/tasks/document_processors/extension_processor.py +++ b/surfsense_backend/app/tasks/document_processors/extension_processor.py @@ -104,9 +104,11 @@ async def add_extension_received_document( return existing_document # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm(session, user_id, search_space_id) if not user_llm: - raise RuntimeError(f"No long context LLM configured for user {user_id}") + raise RuntimeError( + f"No long context LLM configured for user {user_id} in search space {search_space_id}" + ) # Generate summary with metadata document_metadata = { diff --git a/surfsense_backend/app/tasks/document_processors/file_processors.py b/surfsense_backend/app/tasks/document_processors/file_processors.py index 3803f4b2b..573b2c28c 100644 --- a/surfsense_backend/app/tasks/document_processors/file_processors.py +++ b/surfsense_backend/app/tasks/document_processors/file_processors.py @@ -60,9 +60,11 @@ async def add_received_file_document_using_unstructured( # TODO: Check if file_markdown exceeds token limit of embedding model # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm(session, user_id, search_space_id) if not user_llm: - raise RuntimeError(f"No long context LLM configured for user {user_id}") + raise RuntimeError( + f"No long context LLM configured for user {user_id} in search space {search_space_id}" + ) # Generate summary with metadata document_metadata = { @@ -140,9 +142,11 @@ async def add_received_file_document_using_llamacloud( return existing_document # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm(session, user_id, search_space_id) if not user_llm: - raise RuntimeError(f"No long context LLM configured for user {user_id}") + raise RuntimeError( + f"No long context LLM configured for user {user_id} in search space {search_space_id}" + ) # Generate summary with metadata document_metadata = { @@ -221,9 +225,11 @@ async def add_received_file_document_using_docling( return existing_document # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm(session, user_id, search_space_id) if not user_llm: - raise RuntimeError(f"No long context LLM configured for user {user_id}") + raise RuntimeError( + f"No long context LLM configured for user {user_id} in search space {search_space_id}" + ) # Generate summary using chunked processing for large documents from app.services.docling_service import create_docling_service diff --git a/surfsense_backend/app/tasks/document_processors/markdown_processor.py b/surfsense_backend/app/tasks/document_processors/markdown_processor.py index 493b046af..fa3c79d81 100644 --- a/surfsense_backend/app/tasks/document_processors/markdown_processor.py +++ b/surfsense_backend/app/tasks/document_processors/markdown_processor.py @@ -75,9 +75,11 @@ async def add_received_markdown_file_document( return existing_document # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm(session, user_id, search_space_id) if not user_llm: - raise RuntimeError(f"No long context LLM configured for user {user_id}") + raise RuntimeError( + f"No long context LLM configured for user {user_id} in search space {search_space_id}" + ) # Generate summary with metadata document_metadata = { diff --git a/surfsense_backend/app/tasks/document_processors/url_crawler.py b/surfsense_backend/app/tasks/document_processors/url_crawler.py index eddcda388..682086112 100644 --- a/surfsense_backend/app/tasks/document_processors/url_crawler.py +++ b/surfsense_backend/app/tasks/document_processors/url_crawler.py @@ -161,9 +161,11 @@ async def add_crawled_url_document( ) # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm(session, user_id, search_space_id) if not user_llm: - raise RuntimeError(f"No long context LLM configured for user {user_id}") + raise RuntimeError( + f"No long context LLM configured for user {user_id} in search space {search_space_id}" + ) # Generate summary await task_logger.log_task_progress( diff --git a/surfsense_backend/app/tasks/document_processors/youtube_processor.py b/surfsense_backend/app/tasks/document_processors/youtube_processor.py index 37981f6ae..a28a7f186 100644 --- a/surfsense_backend/app/tasks/document_processors/youtube_processor.py +++ b/surfsense_backend/app/tasks/document_processors/youtube_processor.py @@ -234,9 +234,11 @@ async def add_youtube_video_document( ) # Get user's long context LLM - user_llm = await get_user_long_context_llm(session, user_id) + user_llm = await get_user_long_context_llm(session, user_id, search_space_id) if not user_llm: - raise RuntimeError(f"No long context LLM configured for user {user_id}") + raise RuntimeError( + f"No long context LLM configured for user {user_id} in search space {search_space_id}" + ) # Generate summary await task_logger.log_task_progress( diff --git a/surfsense_backend/app/tasks/podcast_tasks.py b/surfsense_backend/app/tasks/podcast_tasks.py index 312ae5bc3..e5f828ef2 100644 --- a/surfsense_backend/app/tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/podcast_tasks.py @@ -98,6 +98,7 @@ async def generate_chat_podcast( "configurable": { "podcast_title": "SurfSense", "user_id": str(user_id), + "search_space_id": search_space_id, } } # Initialize state with database session and streaming service diff --git a/surfsense_backend/app/utils/validators.py b/surfsense_backend/app/utils/validators.py index 7677d54f2..437d23b55 100644 --- a/surfsense_backend/app/utils/validators.py +++ b/surfsense_backend/app/utils/validators.py @@ -16,89 +16,80 @@ from fastapi import HTTPException def validate_search_space_id(search_space_id: Any) -> int: """ Validate and convert search_space_id to integer. - + Args: search_space_id: The search space ID to validate - + Returns: int: Validated search space ID - + Raises: HTTPException: If validation fails """ if search_space_id is None: - raise HTTPException( - status_code=400, - detail="search_space_id is required" - ) - + raise HTTPException(status_code=400, detail="search_space_id is required") + if isinstance(search_space_id, bool): raise HTTPException( - status_code=400, - detail="search_space_id must be an integer, not a boolean" + status_code=400, detail="search_space_id must be an integer, not a boolean" ) - + if isinstance(search_space_id, int): if search_space_id <= 0: raise HTTPException( - status_code=400, - detail="search_space_id must be a positive integer" - + status_code=400, detail="search_space_id must be a positive integer" ) return search_space_id - + if isinstance(search_space_id, str): # Check if it's a valid integer string if not search_space_id.strip(): raise HTTPException( - status_code=400, - detail="search_space_id cannot be empty" + status_code=400, detail="search_space_id cannot be empty" ) - + # Check for valid integer format (no leading zeros, no decimal points) - if not re.match(r'^[1-9]\d*$', search_space_id.strip()): + if not re.match(r"^[1-9]\d*$", search_space_id.strip()): raise HTTPException( status_code=400, - detail="search_space_id must be a valid positive integer" + detail="search_space_id must be a valid positive integer", ) - + value = int(search_space_id.strip()) # Regex already guarantees value > 0, but check retained for clarity if value <= 0: raise HTTPException( - status_code=400, - detail="search_space_id must be a positive integer" + status_code=400, detail="search_space_id must be a positive integer" ) return value - + raise HTTPException( status_code=400, - detail="search_space_id must be an integer or string representation of an integer" + detail="search_space_id must be an integer or string representation of an integer", ) def validate_document_ids(document_ids: Any) -> list[int]: """ Validate and convert document_ids to list of integers. - + Args: document_ids: The document IDs to validate - + Returns: List[int]: Validated list of document IDs - + Raises: HTTPException: If validation fails """ if document_ids is None: return [] - + if not isinstance(document_ids, list): raise HTTPException( - status_code=400, - detail="document_ids_to_add_in_context must be a list" + status_code=400, detail="document_ids_to_add_in_context must be a list" ) - + validated_ids = [] for i, doc_id in enumerate(document_ids): if isinstance(doc_id, bool): @@ -111,119 +102,110 @@ def validate_document_ids(document_ids: Any) -> list[int]: if doc_id <= 0: raise HTTPException( status_code=400, - detail=f"document_ids_to_add_in_context[{i}] must be a positive integer" + detail=f"document_ids_to_add_in_context[{i}] must be a positive integer", ) validated_ids.append(doc_id) elif isinstance(doc_id, str): if not doc_id.strip(): raise HTTPException( status_code=400, - detail=f"document_ids_to_add_in_context[{i}] cannot be empty" + detail=f"document_ids_to_add_in_context[{i}] cannot be empty", ) - - if not re.match(r'^[1-9]\d*$', doc_id.strip()): + + if not re.match(r"^[1-9]\d*$", doc_id.strip()): raise HTTPException( status_code=400, - detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer" + detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer", ) - + value = int(doc_id.strip()) # Regex already guarantees value > 0 if value <= 0: raise HTTPException( status_code=400, - detail=f"document_ids_to_add_in_context[{i}] must be a positive integer" + detail=f"document_ids_to_add_in_context[{i}] must be a positive integer", ) validated_ids.append(value) else: raise HTTPException( status_code=400, - detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer" + detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer", ) - + return validated_ids def validate_connectors(connectors: Any) -> list[str]: """ Validate selected_connectors list. - + Args: connectors: The connectors to validate - + Returns: List[str]: Validated list of connector names - + Raises: HTTPException: If validation fails """ if connectors is None: return [] - + if not isinstance(connectors, list): raise HTTPException( - status_code=400, - detail="selected_connectors must be a list" + status_code=400, detail="selected_connectors must be a list" ) - + validated_connectors = [] for i, connector in enumerate(connectors): if not isinstance(connector, str): raise HTTPException( - status_code=400, - detail=f"selected_connectors[{i}] must be a string" + status_code=400, detail=f"selected_connectors[{i}] must be a string" ) - + if not connector.strip(): raise HTTPException( - status_code=400, - detail=f"selected_connectors[{i}] cannot be empty" + status_code=400, detail=f"selected_connectors[{i}] cannot be empty" ) - + trimmed = connector.strip() - if not re.fullmatch(r'[\w\-_]+', trimmed): + if not re.fullmatch(r"[\w\-_]+", trimmed): raise HTTPException( status_code=400, - detail=f"selected_connectors[{i}] contains invalid characters" + detail=f"selected_connectors[{i}] contains invalid characters", ) validated_connectors.append(trimmed) - + return validated_connectors def validate_research_mode(research_mode: Any) -> str: """ Validate research_mode parameter. - + Args: research_mode: The research mode to validate - + Returns: str: Validated research mode - + Raises: HTTPException: If validation fails """ if research_mode is None: return "QNA" # Default value - + if not isinstance(research_mode, str): - raise HTTPException( - status_code=400, - detail="research_mode must be a string" - ) + raise HTTPException(status_code=400, detail="research_mode must be a string") normalized_mode = research_mode.strip().upper() if not normalized_mode: - raise HTTPException( - status_code=400, - detail="research_mode cannot be empty" - ) + raise HTTPException(status_code=400, detail="research_mode cannot be empty") valid_modes = ["REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER", "QNA"] if normalized_mode not in valid_modes: raise HTTPException( status_code=400, - detail=f"research_mode must be one of: {', '.join(valid_modes)}" + detail=f"research_mode must be one of: {', '.join(valid_modes)}", ) return normalized_mode @@ -231,36 +213,30 @@ def validate_research_mode(research_mode: Any) -> str: def validate_search_mode(search_mode: Any) -> str: """ Validate search_mode parameter. - + Args: search_mode: The search mode to validate - + Returns: str: Validated search mode - + Raises: HTTPException: If validation fails """ if search_mode is None: return "CHUNKS" # Default value - + if not isinstance(search_mode, str): - raise HTTPException( - status_code=400, - detail="search_mode must be a string" - ) + raise HTTPException(status_code=400, detail="search_mode must be a string") normalized_mode = search_mode.strip().upper() if not normalized_mode: - raise HTTPException( - status_code=400, - detail="search_mode cannot be empty" - ) + raise HTTPException(status_code=400, detail="search_mode cannot be empty") valid_modes = ["CHUNKS", "DOCUMENTS"] if normalized_mode not in valid_modes: raise HTTPException( status_code=400, - detail=f"search_mode must be one of: {', '.join(valid_modes)}" + detail=f"search_mode must be one of: {', '.join(valid_modes)}", ) return normalized_mode @@ -268,185 +244,155 @@ def validate_search_mode(search_mode: Any) -> str: def validate_messages(messages: Any) -> list[dict]: """ Validate messages structure. - + Args: messages: The messages to validate - + Returns: List[dict]: Validated messages - + Raises: HTTPException: If validation fails """ if not isinstance(messages, list): - raise HTTPException( - status_code=400, - detail="messages must be a list" - ) - + raise HTTPException(status_code=400, detail="messages must be a list") + if not messages: - raise HTTPException( - status_code=400, - detail="messages cannot be empty" - ) - + raise HTTPException(status_code=400, detail="messages cannot be empty") + validated_messages = [] for i, message in enumerate(messages): if not isinstance(message, dict): raise HTTPException( - status_code=400, - detail=f"messages[{i}] must be a dictionary" + status_code=400, detail=f"messages[{i}] must be a dictionary" ) - + if "role" not in message: raise HTTPException( - status_code=400, - detail=f"messages[{i}] must have a 'role' field" + status_code=400, detail=f"messages[{i}] must have a 'role' field" ) - + if "content" not in message: raise HTTPException( - status_code=400, - detail=f"messages[{i}] must have a 'content' field" + status_code=400, detail=f"messages[{i}] must have a 'content' field" ) - + role = message["role"] if not isinstance(role, str) or role not in ["user", "assistant", "system"]: raise HTTPException( status_code=400, - detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'" + detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'", ) - + content = message["content"] if not isinstance(content, str): raise HTTPException( - status_code=400, - detail=f"messages[{i}].content must be a string" + status_code=400, detail=f"messages[{i}].content must be a string" ) - + if not content.strip(): raise HTTPException( - status_code=400, - detail=f"messages[{i}].content cannot be empty" + status_code=400, detail=f"messages[{i}].content cannot be empty" ) - + # Trim content and enforce max length (10,000 chars) sanitized_content = content.strip() if len(sanitized_content) > 10000: # Reasonable limit raise HTTPException( status_code=400, - detail=f"messages[{i}].content is too long (max 10000 characters)" + detail=f"messages[{i}].content is too long (max 10000 characters)", ) - - validated_messages.append({ - "role": role, - "content": sanitized_content - }) - + + validated_messages.append({"role": role, "content": sanitized_content}) + return validated_messages def validate_email(email: str) -> str: """ Validate email address using pyvalidators library. - + Args: email: The email address to validate - + Returns: str: Validated email address - + Raises: HTTPException: If validation fails """ if not email or not email.strip(): - raise HTTPException( - status_code=400, - detail="Email address is required" - ) - + raise HTTPException(status_code=400, detail="Email address is required") + email = email.strip() - + if not validators.email(email): - raise HTTPException( - status_code=400, - detail="Invalid email address format" - ) - + raise HTTPException(status_code=400, detail="Invalid email address format") + return email def validate_url(url: str) -> str: """ Validate URL using pyvalidators library. - + Args: url: The URL to validate - + Returns: str: Validated URL - + Raises: HTTPException: If validation fails """ if not url or not url.strip(): - raise HTTPException( - status_code=400, - detail="URL is required" - ) - + raise HTTPException(status_code=400, detail="URL is required") + url = url.strip() - + if not validators.url(url): - raise HTTPException( - status_code=400, - detail="Invalid URL format" - ) - + raise HTTPException(status_code=400, detail="Invalid URL format") + return url def validate_uuid(uuid_string: str) -> str: """ Validate UUID using pyvalidators library. - + Args: uuid_string: The UUID string to validate - + Returns: str: Validated UUID string - + Raises: HTTPException: If validation fails """ if not uuid_string or not uuid_string.strip(): - raise HTTPException( - status_code=400, - detail="UUID is required" - ) - + raise HTTPException(status_code=400, detail="UUID is required") + uuid_string = uuid_string.strip() - + if not validators.uuid(uuid_string): - raise HTTPException( - status_code=400, - detail="Invalid UUID format" - ) - + raise HTTPException(status_code=400, detail="Invalid UUID format") + return uuid_string -def validate_connector_config(connector_type: str | Any, config: dict[str, Any]) -> dict[str, Any]: +def validate_connector_config( + connector_type: str | Any, config: dict[str, Any] +) -> dict[str, Any]: """ Validate connector configuration based on connector type. - + Args: connector_type: The type of connector (string or enum) config: The configuration dictionary to validate - + Returns: dict: Validated configuration - + Raises: ValueError: If validation fails """ @@ -454,76 +400,69 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any]) raise ValueError("config must be a dictionary of connector settings") # Convert enum to string if needed - connector_type_str = str(connector_type).split('.')[-1] if hasattr(connector_type, 'value') else str(connector_type) - + connector_type_str = ( + str(connector_type).split(".")[-1] + if hasattr(connector_type, "value") + else str(connector_type) + ) + # Validation function helpers def validate_email_field(key: str, connector_name: str) -> None: if not validators.email(config.get(key, "")): raise ValueError(f"Invalid email format for {connector_name} connector") - + def validate_url_field(key: str, connector_name: str) -> None: if not validators.url(config.get(key, "")): raise ValueError(f"Invalid base URL format for {connector_name} connector") - + def validate_list_field(key: str, field_name: str) -> None: value = config.get(key) if not isinstance(value, list) or not value: raise ValueError(f"{field_name} must be a non-empty list of strings") - + # Lookup table for connector validation rules connector_rules = { - "SERPER_API": { - "required": ["SERPER_API_KEY"], - "validators": {} - }, - "TAVILY_API": { - "required": ["TAVILY_API_KEY"], - "validators": {} - }, - "LINKUP_API": { - "required": ["LINKUP_API_KEY"], - "validators": {} - }, - "SLACK_CONNECTOR": { - "required": ["SLACK_BOT_TOKEN"], - "validators": {} - }, + "SERPER_API": {"required": ["SERPER_API_KEY"], "validators": {}}, + "TAVILY_API": {"required": ["TAVILY_API_KEY"], "validators": {}}, + "LINKUP_API": {"required": ["LINKUP_API_KEY"], "validators": {}}, + "SLACK_CONNECTOR": {"required": ["SLACK_BOT_TOKEN"], "validators": {}}, "NOTION_CONNECTOR": { "required": ["NOTION_INTEGRATION_TOKEN"], - "validators": {} + "validators": {}, }, "GITHUB_CONNECTOR": { "required": ["GITHUB_PAT", "repo_full_names"], "validators": { - "repo_full_names": lambda: validate_list_field("repo_full_names", "repo_full_names") - } - }, - "LINEAR_CONNECTOR": { - "required": ["LINEAR_API_KEY"], - "validators": {} - }, - "DISCORD_CONNECTOR": { - "required": ["DISCORD_BOT_TOKEN"], - "validators": {} + "repo_full_names": lambda: validate_list_field( + "repo_full_names", "repo_full_names" + ) + }, }, + "LINEAR_CONNECTOR": {"required": ["LINEAR_API_KEY"], "validators": {}}, + "DISCORD_CONNECTOR": {"required": ["DISCORD_BOT_TOKEN"], "validators": {}}, "JIRA_CONNECTOR": { "required": ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"], "validators": { "JIRA_EMAIL": lambda: validate_email_field("JIRA_EMAIL", "JIRA"), - "JIRA_BASE_URL": lambda: validate_url_field("JIRA_BASE_URL", "JIRA") - } + "JIRA_BASE_URL": lambda: validate_url_field("JIRA_BASE_URL", "JIRA"), + }, }, "CONFLUENCE_CONNECTOR": { - "required": ["CONFLUENCE_BASE_URL", "CONFLUENCE_EMAIL", "CONFLUENCE_API_TOKEN"], + "required": [ + "CONFLUENCE_BASE_URL", + "CONFLUENCE_EMAIL", + "CONFLUENCE_API_TOKEN", + ], "validators": { - "CONFLUENCE_EMAIL": lambda: validate_email_field("CONFLUENCE_EMAIL", "Confluence"), - "CONFLUENCE_BASE_URL": lambda: validate_url_field("CONFLUENCE_BASE_URL", "Confluence") - } - }, - "CLICKUP_CONNECTOR": { - "required": ["CLICKUP_API_TOKEN"], - "validators": {} + "CONFLUENCE_EMAIL": lambda: validate_email_field( + "CONFLUENCE_EMAIL", "Confluence" + ), + "CONFLUENCE_BASE_URL": lambda: validate_url_field( + "CONFLUENCE_BASE_URL", "Confluence" + ), + }, }, + "CLICKUP_CONNECTOR": {"required": ["CLICKUP_API_TOKEN"], "validators": {}}, # "GOOGLE_CALENDAR_CONNECTOR": { # "required": ["token", "refresh_token", "token_uri", "client_id", "expiry", "scopes", "client_secret"], # "validators": {}, @@ -538,26 +477,23 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any]) # "required": ["AIRTABLE_API_KEY", "AIRTABLE_BASE_ID"], # "validators": {} # }, - "LUMA_CONNECTOR": { - "required": ["LUMA_API_KEY"], - "validators": {} - } + "LUMA_CONNECTOR": {"required": ["LUMA_API_KEY"], "validators": {}}, } - + rules = connector_rules.get(connector_type_str) if not rules: return config # Unknown connector type, pass through - + # Validate required keys match exactly if set(config.keys()) != set(rules["required"]): raise ValueError( f"For {connector_type_str} connector type, config must only contain these keys: {rules['required']}" ) - + # Apply custom validators first (these check format before emptiness) for validator_func in rules["validators"].values(): validator_func() - + # Validate each field is not empty for key in rules["required"]: # Special handling for Google connectors that don't allow None or empty strings @@ -568,5 +504,5 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any]) # Standard check: field must have a truthy value if not config.get(key): raise ValueError(f"{key} cannot be empty") - + return config diff --git a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx index 6610cb046..f3c1531a8 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/client-layout.tsx @@ -1,12 +1,16 @@ "use client"; +import { Loader2 } from "lucide-react"; +import { usePathname, useRouter } from "next/navigation"; import type React from "react"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import { DashboardBreadcrumb } from "@/components/dashboard-breadcrumb"; import { AppSidebarProvider } from "@/components/sidebar/AppSidebarProvider"; import { ThemeTogglerComponent } from "@/components/theme/theme-toggle"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Separator } from "@/components/ui/separator"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; +import { useLLMPreferences } from "@/hooks/use-llm-configs"; export function DashboardClientLayout({ children, @@ -19,6 +23,16 @@ export function DashboardClientLayout({ navSecondary: any[]; navMain: any[]; }) { + const router = useRouter(); + const pathname = usePathname(); + const searchSpaceIdNum = Number(searchSpaceId); + + const { loading, error, isOnboardingComplete } = useLLMPreferences(searchSpaceIdNum); + const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false); + + // Skip onboarding check if we're already on the onboarding page + const isOnboardingPage = pathname?.includes("/onboard"); + const [open, setOpen] = useState(() => { try { const match = document.cookie.match(/(?:^|; )sidebar_state=([^;]+)/); @@ -29,6 +43,68 @@ export function DashboardClientLayout({ return true; }); + useEffect(() => { + // Skip check if already on onboarding page + if (isOnboardingPage) { + setHasCheckedOnboarding(true); + return; + } + + // Only check once after preferences have loaded + if (!loading && !hasCheckedOnboarding) { + const onboardingComplete = isOnboardingComplete(); + + if (!onboardingComplete) { + router.push(`/dashboard/${searchSpaceId}/onboard`); + } + + setHasCheckedOnboarding(true); + } + }, [ + loading, + isOnboardingComplete, + isOnboardingPage, + router, + searchSpaceId, + hasCheckedOnboarding, + ]); + + // Show loading screen while checking onboarding status (only on first load) + if (!hasCheckedOnboarding && loading && !isOnboardingPage) { + return ( +
+ + + Loading Configuration + Checking your LLM preferences... + + + + + +
+ ); + } + + // Show error screen if there's an error loading preferences (but not on onboarding page) + if (error && !hasCheckedOnboarding && !isOnboardingPage) { + return ( +
+ + + + Configuration Error + + Failed to load your LLM configuration + + +

{error}

+
+
+
+ ); + } + return ( {/* Use AppSidebarProvider which fetches user, search space, and recent chats */} diff --git a/surfsense_web/app/dashboard/[search_space_id]/layout.tsx b/surfsense_web/app/dashboard/[search_space_id]/layout.tsx index d0e04fe68..b012484aa 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/layout.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/layout.tsx @@ -33,6 +33,12 @@ export default function DashboardLayout({ icon: "SquareTerminal", items: [], }, + { + title: "Manage LLMs", + url: `/dashboard/${search_space_id}/settings`, + icon: "Settings2", + items: [], + }, { title: "Documents", diff --git a/surfsense_web/app/onboard/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx similarity index 87% rename from surfsense_web/app/onboard/page.tsx rename to surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx index 387bf736a..0ff0fb205 100644 --- a/surfsense_web/app/onboard/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/onboard/page.tsx @@ -2,7 +2,7 @@ import { ArrowLeft, ArrowRight, Bot, CheckCircle, Sparkles } from "lucide-react"; import { AnimatePresence, motion } from "motion/react"; -import { useRouter } from "next/navigation"; +import { useParams, useRouter } from "next/navigation"; import { useEffect, useState } from "react"; import { Logo } from "@/components/Logo"; import { AddProviderStep } from "@/components/onboard/add-provider-step"; @@ -17,13 +17,16 @@ const TOTAL_STEPS = 3; const OnboardPage = () => { const router = useRouter(); - const { llmConfigs, loading: configsLoading, refreshConfigs } = useLLMConfigs(); + const params = useParams(); + const searchSpaceId = Number(params.search_space_id); + + const { llmConfigs, loading: configsLoading, refreshConfigs } = useLLMConfigs(searchSpaceId); const { preferences, loading: preferencesLoading, isOnboardingComplete, refreshPreferences, - } = useLLMPreferences(); + } = useLLMPreferences(searchSpaceId); const [currentStep, setCurrentStep] = useState(1); const [hasUserProgressed, setHasUserProgressed] = useState(false); @@ -44,11 +47,23 @@ const OnboardPage = () => { }, [currentStep]); // Redirect to dashboard if onboarding is already complete and user hasn't progressed (fresh page load) + // But only check once to avoid redirect loops useEffect(() => { - if (!preferencesLoading && isOnboardingComplete() && !hasUserProgressed) { - router.push("/dashboard"); + if (!preferencesLoading && !configsLoading && isOnboardingComplete() && !hasUserProgressed) { + // Small delay to ensure the check is stable + const timer = setTimeout(() => { + router.push(`/dashboard/${searchSpaceId}`); + }, 100); + return () => clearTimeout(timer); } - }, [preferencesLoading, isOnboardingComplete, hasUserProgressed, router]); + }, [ + preferencesLoading, + configsLoading, + isOnboardingComplete, + hasUserProgressed, + router, + searchSpaceId, + ]); const progress = (currentStep / TOTAL_STEPS) * 100; @@ -80,7 +95,7 @@ const OnboardPage = () => { }; const handleComplete = () => { - router.push("/dashboard"); + router.push(`/dashboard/${searchSpaceId}/documents`); }; if (configsLoading || preferencesLoading) { @@ -184,12 +199,18 @@ const OnboardPage = () => { > {currentStep === 1 && ( )} - {currentStep === 2 && } - {currentStep === 3 && } + {currentStep === 2 && ( + + )} + {currentStep === 3 && } diff --git a/surfsense_web/app/settings/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx similarity index 82% rename from surfsense_web/app/settings/page.tsx rename to surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx index 18278be94..9eba74617 100644 --- a/surfsense_web/app/settings/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/settings/page.tsx @@ -1,14 +1,16 @@ "use client"; -import { ArrowLeft, Bot, Brain, Settings } from "lucide-react"; // Import ArrowLeft icon -import { useRouter } from "next/navigation"; // Add this import +import { ArrowLeft, Bot, Brain, Settings } from "lucide-react"; +import { useParams, useRouter } from "next/navigation"; import { LLMRoleManager } from "@/components/settings/llm-role-manager"; import { ModelConfigManager } from "@/components/settings/model-config-manager"; import { Separator } from "@/components/ui/separator"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; export default function SettingsPage() { - const router = useRouter(); // Initialize router + const router = useRouter(); + const params = useParams(); + const searchSpaceId = Number(params.search_space_id); return (
@@ -19,7 +21,7 @@ export default function SettingsPage() {
{/* Back Button */}
@@ -57,11 +59,11 @@ export default function SettingsPage() {
- + - + diff --git a/surfsense_web/app/dashboard/layout.tsx b/surfsense_web/app/dashboard/layout.tsx index 41a269d41..55482d299 100644 --- a/surfsense_web/app/dashboard/layout.tsx +++ b/surfsense_web/app/dashboard/layout.tsx @@ -4,7 +4,6 @@ import { Loader2 } from "lucide-react"; import { useRouter } from "next/navigation"; import { useEffect, useState } from "react"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; -import { useLLMPreferences } from "@/hooks/use-llm-configs"; interface DashboardLayoutProps { children: React.ReactNode; @@ -12,7 +11,6 @@ interface DashboardLayoutProps { export default function DashboardLayout({ children }: DashboardLayoutProps) { const router = useRouter(); - const { loading, error, isOnboardingComplete } = useLLMPreferences(); const [isCheckingAuth, setIsCheckingAuth] = useState(true); useEffect(() => { @@ -25,23 +23,14 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) { setIsCheckingAuth(false); }, [router]); - useEffect(() => { - // Wait for preferences to load, then check if onboarding is complete - if (!loading && !error && !isCheckingAuth) { - if (!isOnboardingComplete()) { - router.push("/onboard"); - } - } - }, [loading, error, isCheckingAuth, isOnboardingComplete, router]); - - // Show loading screen while checking authentication or loading preferences - if (isCheckingAuth || loading) { + // Show loading screen while checking authentication + if (isCheckingAuth) { return (
Loading Dashboard - Checking your configuration... + Checking authentication... @@ -51,42 +40,5 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) { ); } - // Show error screen if there's an error loading preferences - if (error) { - return ( -
- - - - Configuration Error - - Failed to load your LLM configuration - - -

{error}

-
-
-
- ); - } - - // Only render children if onboarding is complete - if (isOnboardingComplete()) { - return <>{children}; - } - - // This should not be reached due to redirect, but just in case - return ( -
- - - Redirecting... - Taking you to complete your setup - - - - - -
- ); + return <>{children}; } diff --git a/surfsense_web/components/UserDropdown.tsx b/surfsense_web/components/UserDropdown.tsx index 5d15b152b..230bf0554 100644 --- a/surfsense_web/components/UserDropdown.tsx +++ b/surfsense_web/components/UserDropdown.tsx @@ -66,10 +66,6 @@ export function UserDropdown({ - router.push(`/settings`)}> - - Settings - Log out diff --git a/surfsense_web/components/chat/ChatInputGroup.tsx b/surfsense_web/components/chat/ChatInputGroup.tsx index f57691e4d..c3877c108 100644 --- a/surfsense_web/components/chat/ChatInputGroup.tsx +++ b/surfsense_web/components/chat/ChatInputGroup.tsx @@ -332,8 +332,11 @@ const ResearchModeSelector = React.memo( ResearchModeSelector.displayName = "ResearchModeSelector"; const LLMSelector = React.memo(() => { - const { llmConfigs, loading: llmLoading, error } = useLLMConfigs(); - const { preferences, updatePreferences, loading: preferencesLoading } = useLLMPreferences(); + const { search_space_id } = useParams(); + const searchSpaceId = Number(search_space_id); + + const { llmConfigs, loading: llmLoading, error } = useLLMConfigs(searchSpaceId); + const { preferences, updatePreferences, loading: preferencesLoading } = useLLMPreferences(searchSpaceId); const isLoading = llmLoading || preferencesLoading; diff --git a/surfsense_web/components/onboard/add-provider-step.tsx b/surfsense_web/components/onboard/add-provider-step.tsx index f582000a3..9b70c8d7f 100644 --- a/surfsense_web/components/onboard/add-provider-step.tsx +++ b/surfsense_web/components/onboard/add-provider-step.tsx @@ -23,12 +23,17 @@ import { type CreateLLMConfig, useLLMConfigs } from "@/hooks/use-llm-configs"; import InferenceParamsEditor from "../inference-params-editor"; interface AddProviderStepProps { + searchSpaceId: number; onConfigCreated?: () => void; onConfigDeleted?: () => void; } -export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProviderStepProps) { - const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs(); +export function AddProviderStep({ + searchSpaceId, + onConfigCreated, + onConfigDeleted, +}: AddProviderStepProps) { + const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs(searchSpaceId); const [isAddingNew, setIsAddingNew] = useState(false); const [formData, setFormData] = useState({ name: "", @@ -38,6 +43,7 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide api_key: "", api_base: "", litellm_params: {}, + search_space_id: searchSpaceId, }); const [isSubmitting, setIsSubmitting] = useState(false); @@ -65,6 +71,7 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide api_key: "", api_base: "", litellm_params: {}, + search_space_id: searchSpaceId, }); setIsAddingNew(false); // Notify parent component that a config was created @@ -253,7 +260,6 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide />
-
@@ -578,6 +580,7 @@ export function ModelConfigManager() { api_key: "", api_base: "", litellm_params: {}, + search_space_id: searchSpaceId, }); }} disabled={isSubmitting} diff --git a/surfsense_web/hooks/use-llm-configs.ts b/surfsense_web/hooks/use-llm-configs.ts index ccb3a5bc9..adf49d634 100644 --- a/surfsense_web/hooks/use-llm-configs.ts +++ b/surfsense_web/hooks/use-llm-configs.ts @@ -12,7 +12,7 @@ export interface LLMConfig { api_base?: string; litellm_params?: Record; created_at: string; - user_id: string; + search_space_id: number; } export interface LLMPreferences { @@ -32,6 +32,7 @@ export interface CreateLLMConfig { api_key: string; api_base?: string; litellm_params?: Record; + search_space_id: number; } export interface UpdateLLMConfig { @@ -44,16 +45,21 @@ export interface UpdateLLMConfig { litellm_params?: Record; } -export function useLLMConfigs() { +export function useLLMConfigs(searchSpaceId: number | null) { const [llmConfigs, setLlmConfigs] = useState([]); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const fetchLLMConfigs = async () => { + if (!searchSpaceId) { + setLoading(false); + return; + } + try { setLoading(true); const response = await fetch( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`, + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/?search_space_id=${searchSpaceId}`, { headers: { Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`, @@ -79,7 +85,7 @@ export function useLLMConfigs() { useEffect(() => { fetchLLMConfigs(); - }, []); + }, [searchSpaceId]); const createLLMConfig = async (config: CreateLLMConfig): Promise => { try { @@ -181,16 +187,21 @@ export function useLLMConfigs() { }; } -export function useLLMPreferences() { +export function useLLMPreferences(searchSpaceId: number | null) { const [preferences, setPreferences] = useState({}); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const fetchPreferences = async () => { + if (!searchSpaceId) { + setLoading(false); + return; + } + try { setLoading(true); const response = await fetch( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/llm-preferences`, { headers: { Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`, @@ -216,12 +227,17 @@ export function useLLMPreferences() { useEffect(() => { fetchPreferences(); - }, []); + }, [searchSpaceId]); const updatePreferences = async (newPreferences: Partial): Promise => { + if (!searchSpaceId) { + toast.error("Search space ID is required"); + return false; + } + try { const response = await fetch( - `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, + `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/llm-preferences`, { method: "PUT", headers: {