Merge pull request #384 from MODSetter/dev

feat: moved LLMConfigs from User to SearchSpaces
This commit is contained in:
Rohan Verma 2025-10-10 00:55:30 -07:00 committed by GitHub
commit 9ee9a683be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 1075 additions and 518 deletions

View file

@ -0,0 +1,352 @@
"""Migrate LLM configs to search spaces and add user preferences
Revision ID: 25
Revises: 24
Create Date: 2025-01-10 14:00:00.000000
Changes:
1. Migrate llm_configs from user association to search_space association
2. Create user_search_space_preferences table for per-user LLM preferences
3. Migrate existing user LLM preferences to user_search_space_preferences
4. Remove LLM preference columns from user table
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "25"
down_revision: str | None = "24"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """
    Upgrade schema to support collaborative search spaces with per-user preferences.

    Migration steps:
    1. Add search_space_id to llm_configs
    2. Migrate existing llm_configs to first search space of their user
    3. Replace user_id with search_space_id in llm_configs
    4. Create user_search_space_preferences table
    5. Migrate user LLM preferences to user_search_space_preferences
    6. Remove LLM preference columns from user table

    NOTE: PostgreSQL-only — relies on ``DO $$`` blocks and ``ON CONFLICT``.
    Steps are guarded by live-schema inspection so the migration is
    re-runnable if it previously failed partway through.
    """
    # Local import keeps the migration module import-light; inspect() builds
    # a reflection Inspector bound to the current Alembic connection.
    from sqlalchemy import inspect

    conn = op.get_bind()
    inspector = inspect(conn)

    # Snapshot existing columns so each step can be skipped when it has
    # already been applied (idempotent re-run support).
    llm_config_columns = [col["name"] for col in inspector.get_columns("llm_configs")]
    user_columns = [col["name"] for col in inspector.get_columns("user")]

    # ===== STEP 1: Add search_space_id to llm_configs =====
    # Added nullable first; backfilled in STEP 2, then tightened in STEP 3.
    if "search_space_id" not in llm_config_columns:
        op.add_column(
            "llm_configs",
            sa.Column("search_space_id", sa.Integer(), nullable=True),
        )

    # ===== STEP 2: Populate search_space_id with user's first search space =====
    # This ensures existing LLM configs are assigned to a valid search space.
    # "First" means the user's oldest search space (ORDER BY created_at ASC).
    # NOTE(review): configs whose owner has NO search space keep a NULL
    # search_space_id here, and the NOT NULL alter in STEP 3 will then fail —
    # confirm every user owning an llm_config also owns a search space.
    op.execute(
        """
        UPDATE llm_configs lc
        SET search_space_id = (
            SELECT id
            FROM searchspaces ss
            WHERE ss.user_id = lc.user_id
            ORDER BY ss.created_at ASC
            LIMIT 1
        )
        WHERE search_space_id IS NULL AND user_id IS NOT NULL
        """
    )

    # ===== STEP 3: Make search_space_id NOT NULL and add FK constraint =====
    op.alter_column(
        "llm_configs",
        "search_space_id",
        nullable=False,
    )

    # Add foreign key constraint (guarded so re-runs don't duplicate it).
    foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("llm_configs")]
    if "fk_llm_configs_search_space_id" not in foreign_keys:
        op.create_foreign_key(
            "fk_llm_configs_search_space_id",
            "llm_configs",
            "searchspaces",
            ["search_space_id"],
            ["id"],
            ondelete="CASCADE",
        )

    # Drop old user_id foreign key if it exists (ownership moves to the
    # search space from here on).
    if "fk_llm_configs_user_id_user" in foreign_keys:
        op.drop_constraint(
            "fk_llm_configs_user_id_user",
            "llm_configs",
            type_="foreignkey",
        )

    # Remove user_id column — search_space_id is now the sole association.
    if "user_id" in llm_config_columns:
        op.drop_column("llm_configs", "user_id")

    # ===== STEP 4: Create user_search_space_preferences table =====
    # Raw DDL inside a DO block gives IF-NOT-EXISTS semantics for the whole
    # table definition (SERIAL pk, FK cascades, and the per-(user, search
    # space) uniqueness constraint) in one shot.
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'user_search_space_preferences'
            ) THEN
                CREATE TABLE user_search_space_preferences (
                    id SERIAL PRIMARY KEY,
                    created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                    user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
                    search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
                    long_context_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
                    fast_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
                    strategic_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
                    CONSTRAINT uq_user_searchspace UNIQUE (user_id, search_space_id)
                );
            END IF;
        END$$;
        """
    )

    # Create indexes (same guarded-DDL pattern, via pg_indexes).
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'user_search_space_preferences'
                AND indexname = 'ix_user_search_space_preferences_id'
            ) THEN
                CREATE INDEX ix_user_search_space_preferences_id
                ON user_search_space_preferences(id);
            END IF;
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'user_search_space_preferences'
                AND indexname = 'ix_user_search_space_preferences_created_at'
            ) THEN
                CREATE INDEX ix_user_search_space_preferences_created_at
                ON user_search_space_preferences(created_at);
            END IF;
        END$$;
        """
    )

    # ===== STEP 5: Migrate user LLM preferences to user_search_space_preferences =====
    # For each user, replicate their (single) user-level preference row into
    # every search space they own; ON CONFLICT makes re-runs a no-op.
    # Only runs while the old user columns still exist (skipped on re-run
    # after STEP 6 has dropped them).
    if all(
        col in user_columns
        for col in ["long_context_llm_id", "fast_llm_id", "strategic_llm_id"]
    ):
        op.execute(
            """
            INSERT INTO user_search_space_preferences
                (user_id, search_space_id, long_context_llm_id, fast_llm_id, strategic_llm_id, created_at)
            SELECT
                u.id as user_id,
                ss.id as search_space_id,
                u.long_context_llm_id,
                u.fast_llm_id,
                u.strategic_llm_id,
                NOW() as created_at
            FROM "user" u
            CROSS JOIN searchspaces ss
            WHERE ss.user_id = u.id
            ON CONFLICT (user_id, search_space_id) DO NOTHING
            """
        )

    # ===== STEP 6: Remove LLM preference columns from user table =====
    # Get fresh list of foreign keys after previous operations.
    # NOTE(review): this reuses the Inspector created at the top; SQLAlchemy's
    # Inspector may cache reflection results, so "fresh" holds only if no
    # earlier step altered the "user" table's FKs (none do here) — verify if
    # steps are reordered.
    user_foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("user")]

    # Drop foreign key constraints if they exist.
    if "fk_user_long_context_llm_id_llm_configs" in user_foreign_keys:
        op.drop_constraint(
            "fk_user_long_context_llm_id_llm_configs",
            "user",
            type_="foreignkey",
        )
    if "fk_user_fast_llm_id_llm_configs" in user_foreign_keys:
        op.drop_constraint(
            "fk_user_fast_llm_id_llm_configs",
            "user",
            type_="foreignkey",
        )
    if "fk_user_strategic_llm_id_llm_configs" in user_foreign_keys:
        op.drop_constraint(
            "fk_user_strategic_llm_id_llm_configs",
            "user",
            type_="foreignkey",
        )

    # Drop columns from user table — per-user LLM preferences now live in
    # user_search_space_preferences (populated in STEP 5).
    if "long_context_llm_id" in user_columns:
        op.drop_column("user", "long_context_llm_id")
    if "fast_llm_id" in user_columns:
        op.drop_column("user", "fast_llm_id")
    if "strategic_llm_id" in user_columns:
        op.drop_column("user", "strategic_llm_id")
def downgrade() -> None:
    """
    Downgrade schema back to user-owned LLM configs.

    WARNING: This downgrade will result in data loss:
    - LLM configs will be moved back to user ownership (first occurrence kept)
    - Per-search-space user preferences will be consolidated to user level
      (the earliest-created preference row wins, via DISTINCT ON)
    - Additional LLM configs in search spaces beyond the first will be deleted
    """
    # Same reflection setup as upgrade(): snapshot live schema so steps can
    # be skipped when already applied.
    from sqlalchemy import inspect

    conn = op.get_bind()
    inspector = inspect(conn)

    # Get existing columns and constraints
    llm_config_columns = [col["name"] for col in inspector.get_columns("llm_configs")]
    user_columns = [col["name"] for col in inspector.get_columns("user")]

    # ===== STEP 1: Add LLM preference columns back to user table =====
    # Nullable on purpose — a user may have had no preferences.
    if "long_context_llm_id" not in user_columns:
        op.add_column(
            "user",
            sa.Column("long_context_llm_id", sa.Integer(), nullable=True),
        )
    if "fast_llm_id" not in user_columns:
        op.add_column(
            "user",
            sa.Column("fast_llm_id", sa.Integer(), nullable=True),
        )
    if "strategic_llm_id" not in user_columns:
        op.add_column(
            "user",
            sa.Column("strategic_llm_id", sa.Integer(), nullable=True),
        )

    # ===== STEP 2: Migrate preferences back to user table =====
    # DISTINCT ON (user_id) + ORDER BY created_at ASC keeps exactly one row
    # per user — the earliest preference — discarding the rest (the data
    # loss flagged in the docstring).
    op.execute(
        """
        UPDATE "user" u
        SET
            long_context_llm_id = ussp.long_context_llm_id,
            fast_llm_id = ussp.fast_llm_id,
            strategic_llm_id = ussp.strategic_llm_id
        FROM (
            SELECT DISTINCT ON (user_id)
                user_id,
                long_context_llm_id,
                fast_llm_id,
                strategic_llm_id
            FROM user_search_space_preferences
            ORDER BY user_id, created_at ASC
        ) ussp
        WHERE u.id = ussp.user_id
        """
    )

    # ===== STEP 3: Add foreign key constraints back to user table =====
    # NOTE(review): unlike upgrade(), these create_foreign_key calls are NOT
    # guarded by an existence check — a partially-applied downgrade re-run
    # would fail on a duplicate constraint. Confirm whether re-runnability is
    # required for downgrades.
    op.create_foreign_key(
        "fk_user_long_context_llm_id_llm_configs",
        "user",
        "llm_configs",
        ["long_context_llm_id"],
        ["id"],
        ondelete="SET NULL",
    )
    op.create_foreign_key(
        "fk_user_fast_llm_id_llm_configs",
        "user",
        "llm_configs",
        ["fast_llm_id"],
        ["id"],
        ondelete="SET NULL",
    )
    op.create_foreign_key(
        "fk_user_strategic_llm_id_llm_configs",
        "user",
        "llm_configs",
        ["strategic_llm_id"],
        ["id"],
        ondelete="SET NULL",
    )

    # ===== STEP 4: Drop user_search_space_preferences table =====
    # CASCADE removes any dependent objects along with the table.
    op.execute("DROP TABLE IF EXISTS user_search_space_preferences CASCADE")

    # ===== STEP 5: Add user_id back to llm_configs =====
    if "user_id" not in llm_config_columns:
        op.add_column(
            "llm_configs",
            sa.Column("user_id", postgresql.UUID(), nullable=True),
        )

    # Populate user_id from search_space: each config is handed back to the
    # owner of the search space it lived in.
    op.execute(
        """
        UPDATE llm_configs lc
        SET user_id = ss.user_id
        FROM searchspaces ss
        WHERE lc.search_space_id = ss.id
        """
    )

    # Make user_id NOT NULL.
    # NOTE(review): fails if any llm_config's search_space_id did not match a
    # searchspaces row above (user_id left NULL) — should not happen while
    # the CASCADE FK from upgrade() is in place, but verify.
    op.alter_column(
        "llm_configs",
        "user_id",
        nullable=False,
    )

    # Add foreign key constraint for user_id (restores the pre-upgrade FK).
    op.create_foreign_key(
        "fk_llm_configs_user_id_user",
        "llm_configs",
        "user",
        ["user_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # ===== STEP 6: Remove search_space_id from llm_configs =====
    # Drop foreign key constraint (guarded, since it may already be gone).
    foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("llm_configs")]
    if "fk_llm_configs_search_space_id" in foreign_keys:
        op.drop_constraint(
            "fk_llm_configs_search_space_id",
            "llm_configs",
            type_="foreignkey",
        )

    # Drop search_space_id column — ownership is user-level again.
    if "search_space_id" in llm_config_columns:
        op.drop_column("llm_configs", "search_space_id")

View file

@ -17,6 +17,7 @@ class Configuration:
# and when you invoke the graph # and when you invoke the graph
podcast_title: str podcast_title: str
user_id: str user_id: str
search_space_id: int
@classmethod @classmethod
def from_runnable_config( def from_runnable_config(

View file

@ -28,11 +28,12 @@ async def create_podcast_transcript(
# Get configuration from runnable config # Get configuration from runnable config
configuration = Configuration.from_runnable_config(config) configuration = Configuration.from_runnable_config(config)
user_id = configuration.user_id user_id = configuration.user_id
search_space_id = configuration.search_space_id
# Get user's long context LLM # Get user's long context LLM
llm = await get_user_long_context_llm(state.db_session, user_id) llm = await get_user_long_context_llm(state.db_session, user_id, search_space_id)
if not llm: if not llm:
error_message = f"No long context LLM configured for user {user_id}" error_message = f"No long context LLM configured for user {user_id} in search space {search_space_id}"
print(error_message) print(error_message)
raise RuntimeError(error_message) raise RuntimeError(error_message)

View file

@ -577,6 +577,7 @@ async def write_answer_outline(
user_query = configuration.user_query user_query = configuration.user_query
num_sections = configuration.num_sections num_sections = configuration.num_sections
user_id = configuration.user_id user_id = configuration.user_id
search_space_id = configuration.search_space_id
writer( writer(
{ {
@ -587,9 +588,9 @@ async def write_answer_outline(
) )
# Get user's strategic LLM # Get user's strategic LLM
llm = await get_user_strategic_llm(state.db_session, user_id) llm = await get_user_strategic_llm(state.db_session, user_id, search_space_id)
if not llm: if not llm:
error_message = f"No strategic LLM configured for user {user_id}" error_message = f"No strategic LLM configured for user {user_id} in search space {search_space_id}"
writer({"yield_value": streaming_service.format_error(error_message)}) writer({"yield_value": streaming_service.format_error(error_message)})
raise RuntimeError(error_message) raise RuntimeError(error_message)
@ -1854,6 +1855,7 @@ async def reformulate_user_query(
user_query=user_query, user_query=user_query,
session=state.db_session, session=state.db_session,
user_id=configuration.user_id, user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
chat_history_str=chat_history_str, chat_history_str=chat_history_str,
) )
@ -2093,6 +2095,7 @@ async def generate_further_questions(
configuration = Configuration.from_runnable_config(config) configuration = Configuration.from_runnable_config(config)
chat_history = state.chat_history chat_history = state.chat_history
user_id = configuration.user_id user_id = configuration.user_id
search_space_id = configuration.search_space_id
streaming_service = state.streaming_service streaming_service = state.streaming_service
# Get reranked documents from the state (will be populated by sub-agents) # Get reranked documents from the state (will be populated by sub-agents)
@ -2107,9 +2110,9 @@ async def generate_further_questions(
) )
# Get user's fast LLM # Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id) llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
if not llm: if not llm:
error_message = f"No fast LLM configured for user {user_id}" error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
print(error_message) print(error_message)
writer({"yield_value": streaming_service.format_error(error_message)}) writer({"yield_value": streaming_service.format_error(error_message)})

View file

@ -101,11 +101,12 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
documents = state.reranked_documents documents = state.reranked_documents
user_query = configuration.user_query user_query = configuration.user_query
user_id = configuration.user_id user_id = configuration.user_id
search_space_id = configuration.search_space_id
# Get user's fast LLM # Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id) llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
if not llm: if not llm:
error_message = f"No fast LLM configured for user {user_id}" error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
print(error_message) print(error_message)
raise RuntimeError(error_message) raise RuntimeError(error_message)

View file

@ -107,11 +107,12 @@ async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, A
configuration = Configuration.from_runnable_config(config) configuration = Configuration.from_runnable_config(config)
documents = state.reranked_documents documents = state.reranked_documents
user_id = configuration.user_id user_id = configuration.user_id
search_space_id = configuration.search_space_id
# Get user's fast LLM # Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id) llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
if not llm: if not llm:
error_message = f"No fast LLM configured for user {user_id}" error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
print(error_message) print(error_message)
raise RuntimeError(error_message) raise RuntimeError(error_message)

View file

@ -240,6 +240,17 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="SearchSourceConnector.id", order_by="SearchSourceConnector.id",
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
llm_configs = relationship(
"LLMConfig",
back_populates="search_space",
order_by="LLMConfig.id",
cascade="all, delete-orphan",
)
user_preferences = relationship(
"UserSearchSpacePreference",
back_populates="search_space",
cascade="all, delete-orphan",
)
class SearchSourceConnector(BaseModel, TimestampMixin): class SearchSourceConnector(BaseModel, TimestampMixin):
@ -288,10 +299,54 @@ class LLMConfig(BaseModel, TimestampMixin):
# For any other parameters that litellm supports # For any other parameters that litellm supports
litellm_params = Column(JSON, nullable=True, default={}) litellm_params = Column(JSON, nullable=True, default={})
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="llm_configs")
class UserSearchSpacePreference(BaseModel, TimestampMixin):
__tablename__ = "user_search_space_preferences"
__table_args__ = (
UniqueConstraint(
"user_id",
"search_space_id",
name="uq_user_searchspace",
),
)
user_id = Column( user_id = Column(
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
) )
user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id]) search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
# User-specific LLM preferences for this search space
long_context_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
fast_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
strategic_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
# Future RBAC fields can be added here
# role = Column(String(50), nullable=True) # e.g., 'owner', 'editor', 'viewer'
# permissions = Column(JSON, nullable=True)
user = relationship("User", back_populates="search_space_preferences")
search_space = relationship("SearchSpace", back_populates="user_preferences")
long_context_llm = relationship(
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
)
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
strategic_llm = relationship(
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
)
class Log(BaseModel, TimestampMixin): class Log(BaseModel, TimestampMixin):
@ -321,64 +376,22 @@ if config.AUTH_TYPE == "GOOGLE":
"OAuthAccount", lazy="joined" "OAuthAccount", lazy="joined"
) )
search_spaces = relationship("SearchSpace", back_populates="user") search_spaces = relationship("SearchSpace", back_populates="user")
llm_configs = relationship( search_space_preferences = relationship(
"LLMConfig", "UserSearchSpacePreference",
back_populates="user", back_populates="user",
foreign_keys="LLMConfig.user_id",
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
long_context_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
fast_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
strategic_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
long_context_llm = relationship(
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
)
fast_llm = relationship(
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
)
strategic_llm = relationship(
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
)
else: else:
class User(SQLAlchemyBaseUserTableUUID, Base): class User(SQLAlchemyBaseUserTableUUID, Base):
search_spaces = relationship("SearchSpace", back_populates="user") search_spaces = relationship("SearchSpace", back_populates="user")
llm_configs = relationship( search_space_preferences = relationship(
"LLMConfig", "UserSearchSpacePreference",
back_populates="user", back_populates="user",
foreign_keys="LLMConfig.user_id",
cascade="all, delete-orphan", cascade="all, delete-orphan",
) )
long_context_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
fast_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
strategic_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
long_context_llm = relationship(
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
)
fast_llm = relationship(
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
)
strategic_llm = relationship(
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
)
engine = create_async_engine(DATABASE_URL) engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False) async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -17,19 +17,17 @@ from app.tasks.stream_connector_search_results import stream_connector_search_re
from app.users import current_active_user from app.users import current_active_user
from app.utils.check_ownership import check_ownership from app.utils.check_ownership import check_ownership
from app.utils.validators import ( from app.utils.validators import (
validate_search_space_id,
validate_document_ids,
validate_connectors, validate_connectors,
validate_document_ids,
validate_messages,
validate_research_mode, validate_research_mode,
validate_search_mode, validate_search_mode,
validate_messages, validate_search_space_id,
) )
router = APIRouter() router = APIRouter()
@router.post("/chat") @router.post("/chat")
async def handle_chat_data( async def handle_chat_data(
request: AISDKChatRequest, request: AISDKChatRequest,
@ -51,7 +49,9 @@ async def handle_chat_data(
search_space_id = validate_search_space_id(request_data.get("search_space_id")) search_space_id = validate_search_space_id(request_data.get("search_space_id"))
research_mode = validate_research_mode(request_data.get("research_mode")) research_mode = validate_research_mode(request_data.get("research_mode"))
selected_connectors = validate_connectors(request_data.get("selected_connectors")) selected_connectors = validate_connectors(request_data.get("selected_connectors"))
document_ids_to_add_in_context = validate_document_ids(request_data.get("document_ids_to_add_in_context")) document_ids_to_add_in_context = validate_document_ids(
request_data.get("document_ids_to_add_in_context")
)
search_mode_str = validate_search_mode(request_data.get("search_mode")) search_mode_str = validate_search_mode(request_data.get("search_mode"))
# Check if the search space belongs to the current user # Check if the search space belongs to the current user
@ -132,21 +132,16 @@ async def read_chats(
# Validate pagination parameters # Validate pagination parameters
if skip < 0: if skip < 0:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="skip must be a non-negative integer"
detail="skip must be a non-negative integer"
) )
if limit <= 0 or limit > 1000: # Reasonable upper limit if limit <= 0 or limit > 1000: # Reasonable upper limit
raise HTTPException( raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
status_code=400,
detail="limit must be between 1 and 1000"
)
# Validate search_space_id if provided # Validate search_space_id if provided
if search_space_id is not None and search_space_id <= 0: if search_space_id is not None and search_space_id <= 0:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="search_space_id must be a positive integer"
detail="search_space_id must be a positive integer"
) )
try: try:
# Select specific fields excluding messages # Select specific fields excluding messages

View file

@ -2,15 +2,72 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.db import LLMConfig, User, get_async_session from app.db import (
LLMConfig,
SearchSpace,
User,
UserSearchSpacePreference,
get_async_session,
)
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from app.users import current_active_user from app.users import current_active_user
from app.utils.check_ownership import check_ownership
router = APIRouter() router = APIRouter()
# Helper function to check search space access
async def check_search_space_access(
session: AsyncSession, search_space_id: int, user: User
) -> SearchSpace:
"""Verify that the user has access to the search space"""
result = await session.execute(
select(SearchSpace).filter(
SearchSpace.id == search_space_id, SearchSpace.user_id == user.id
)
)
search_space = result.scalars().first()
if not search_space:
raise HTTPException(
status_code=404,
detail="Search space not found or you don't have permission to access it",
)
return search_space
# Helper function to get or create user search space preference
async def get_or_create_user_preference(
session: AsyncSession, user_id, search_space_id: int
) -> UserSearchSpacePreference:
"""Get or create user preference for a search space"""
result = await session.execute(
select(UserSearchSpacePreference)
.filter(
UserSearchSpacePreference.user_id == user_id,
UserSearchSpacePreference.search_space_id == search_space_id,
)
.options(
selectinload(UserSearchSpacePreference.long_context_llm),
selectinload(UserSearchSpacePreference.fast_llm),
selectinload(UserSearchSpacePreference.strategic_llm),
)
)
preference = result.scalars().first()
if not preference:
# Create new preference entry
preference = UserSearchSpacePreference(
user_id=user_id,
search_space_id=search_space_id,
)
session.add(preference)
await session.commit()
await session.refresh(preference)
return preference
class LLMPreferencesUpdate(BaseModel): class LLMPreferencesUpdate(BaseModel):
"""Schema for updating user LLM preferences""" """Schema for updating user LLM preferences"""
@ -36,9 +93,12 @@ async def create_llm_config(
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
"""Create a new LLM configuration for the authenticated user""" """Create a new LLM configuration for a search space"""
try: try:
db_llm_config = LLMConfig(**llm_config.model_dump(), user_id=user.id) # Verify user has access to the search space
await check_search_space_access(session, llm_config.search_space_id, user)
db_llm_config = LLMConfig(**llm_config.model_dump())
session.add(db_llm_config) session.add(db_llm_config)
await session.commit() await session.commit()
await session.refresh(db_llm_config) await session.refresh(db_llm_config)
@ -54,20 +114,26 @@ async def create_llm_config(
@router.get("/llm-configs/", response_model=list[LLMConfigRead]) @router.get("/llm-configs/", response_model=list[LLMConfigRead])
async def read_llm_configs( async def read_llm_configs(
search_space_id: int,
skip: int = 0, skip: int = 0,
limit: int = 200, limit: int = 200,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
"""Get all LLM configurations for the authenticated user""" """Get all LLM configurations for a search space"""
try: try:
# Verify user has access to the search space
await check_search_space_access(session, search_space_id, user)
result = await session.execute( result = await session.execute(
select(LLMConfig) select(LLMConfig)
.filter(LLMConfig.user_id == user.id) .filter(LLMConfig.search_space_id == search_space_id)
.offset(skip) .offset(skip)
.limit(limit) .limit(limit)
) )
return result.scalars().all() return result.scalars().all()
except HTTPException:
raise
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}" status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
@ -82,7 +148,18 @@ async def read_llm_config(
): ):
"""Get a specific LLM configuration by ID""" """Get a specific LLM configuration by ID"""
try: try:
llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) # Get the LLM config
result = await session.execute(
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
)
llm_config = result.scalars().first()
if not llm_config:
raise HTTPException(status_code=404, detail="LLM configuration not found")
# Verify user has access to the search space
await check_search_space_access(session, llm_config.search_space_id, user)
return llm_config return llm_config
except HTTPException: except HTTPException:
raise raise
@ -101,7 +178,18 @@ async def update_llm_config(
): ):
"""Update an existing LLM configuration""" """Update an existing LLM configuration"""
try: try:
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) # Get the LLM config
result = await session.execute(
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
)
db_llm_config = result.scalars().first()
if not db_llm_config:
raise HTTPException(status_code=404, detail="LLM configuration not found")
# Verify user has access to the search space
await check_search_space_access(session, db_llm_config.search_space_id, user)
update_data = llm_config_update.model_dump(exclude_unset=True) update_data = llm_config_update.model_dump(exclude_unset=True)
for key, value in update_data.items(): for key, value in update_data.items():
@ -127,7 +215,18 @@ async def delete_llm_config(
): ):
"""Delete an LLM configuration""" """Delete an LLM configuration"""
try: try:
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user) # Get the LLM config
result = await session.execute(
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
)
db_llm_config = result.scalars().first()
if not db_llm_config:
raise HTTPException(status_code=404, detail="LLM configuration not found")
# Verify user has access to the search space
await check_search_space_access(session, db_llm_config.search_space_id, user)
await session.delete(db_llm_config) await session.delete(db_llm_config)
await session.commit() await session.commit()
return {"message": "LLM configuration deleted successfully"} return {"message": "LLM configuration deleted successfully"}
@ -143,99 +242,101 @@ async def delete_llm_config(
# User LLM Preferences endpoints # User LLM Preferences endpoints
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead) @router.get(
"/search-spaces/{search_space_id}/llm-preferences",
response_model=LLMPreferencesRead,
)
async def get_user_llm_preferences( async def get_user_llm_preferences(
search_space_id: int,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
"""Get the current user's LLM preferences""" """Get the current user's LLM preferences for a specific search space"""
try: try:
# Refresh user to get latest relationships # Verify user has access to the search space
await session.refresh(user) await check_search_space_access(session, search_space_id, user)
result = { # Get or create user preference for this search space
"long_context_llm_id": user.long_context_llm_id, preference = await get_or_create_user_preference(
"fast_llm_id": user.fast_llm_id, session, user.id, search_space_id
"strategic_llm_id": user.strategic_llm_id, )
"long_context_llm": None,
"fast_llm": None, return {
"strategic_llm": None, "long_context_llm_id": preference.long_context_llm_id,
"fast_llm_id": preference.fast_llm_id,
"strategic_llm_id": preference.strategic_llm_id,
"long_context_llm": preference.long_context_llm,
"fast_llm": preference.fast_llm,
"strategic_llm": preference.strategic_llm,
} }
except HTTPException:
# Fetch the actual LLM configs if they exist raise
if user.long_context_llm_id:
long_context_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.long_context_llm_id,
LLMConfig.user_id == user.id,
)
)
llm_config = long_context_llm.scalars().first()
if llm_config:
result["long_context_llm"] = llm_config
if user.fast_llm_id:
fast_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id
)
)
llm_config = fast_llm.scalars().first()
if llm_config:
result["fast_llm"] = llm_config
if user.strategic_llm_id:
strategic_llm = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id
)
)
llm_config = strategic_llm.scalars().first()
if llm_config:
result["strategic_llm"] = llm_config
return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}" status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
) from e ) from e
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead) @router.put(
"/search-spaces/{search_space_id}/llm-preferences",
response_model=LLMPreferencesRead,
)
async def update_user_llm_preferences( async def update_user_llm_preferences(
search_space_id: int,
preferences: LLMPreferencesUpdate, preferences: LLMPreferencesUpdate,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
"""Update the current user's LLM preferences""" """Update the current user's LLM preferences for a specific search space"""
try: try:
# Validate that all provided LLM config IDs belong to the user # Verify user has access to the search space
await check_search_space_access(session, search_space_id, user)
# Get or create user preference for this search space
preference = await get_or_create_user_preference(
session, user.id, search_space_id
)
# Validate that all provided LLM config IDs belong to the search space
update_data = preferences.model_dump(exclude_unset=True) update_data = preferences.model_dump(exclude_unset=True)
for _key, llm_config_id in update_data.items(): for _key, llm_config_id in update_data.items():
if llm_config_id is not None: if llm_config_id is not None:
# Verify ownership of the LLM config # Verify the LLM config belongs to the search space
result = await session.execute( result = await session.execute(
select(LLMConfig).filter( select(LLMConfig).filter(
LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id LLMConfig.id == llm_config_id,
LLMConfig.search_space_id == search_space_id,
) )
) )
llm_config = result.scalars().first() llm_config = result.scalars().first()
if not llm_config: if not llm_config:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it", detail=f"LLM configuration {llm_config_id} not found in this search space",
) )
# Update user preferences # Update user preferences
for key, value in update_data.items(): for key, value in update_data.items():
setattr(user, key, value) setattr(preference, key, value)
await session.commit() await session.commit()
await session.refresh(user) await session.refresh(preference)
# Reload relationships
await session.refresh(
preference, ["long_context_llm", "fast_llm", "strategic_llm"]
)
# Return updated preferences # Return updated preferences
return await get_user_llm_preferences(session, user) return {
"long_context_llm_id": preference.long_context_llm_id,
"fast_llm_id": preference.fast_llm_id,
"strategic_llm_id": preference.strategic_llm_id,
"long_context_llm": preference.long_context_llm,
"fast_llm": preference.fast_llm,
"strategic_llm": preference.strategic_llm,
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:

View file

@ -1,4 +1,3 @@
import uuid
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
@ -30,7 +29,9 @@ class LLMConfigBase(BaseModel):
class LLMConfigCreate(LLMConfigBase): class LLMConfigCreate(LLMConfigBase):
pass search_space_id: int = Field(
..., description="Search space ID to associate the LLM config with"
)
class LLMConfigUpdate(BaseModel): class LLMConfigUpdate(BaseModel):
@ -56,6 +57,6 @@ class LLMConfigUpdate(BaseModel):
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel): class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
id: int id: int
created_at: datetime created_at: datetime
user_id: uuid.UUID search_space_id: int
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)

View file

@ -1,10 +1,14 @@
import logging import logging
import litellm
from langchain_litellm import ChatLiteLLM from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from app.db import LLMConfig, User from app.db import LLMConfig, UserSearchSpacePreference
# Configure litellm to automatically drop unsupported parameters
litellm.drop_params = True
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,54 +20,67 @@ class LLMRole:
async def get_user_llm_instance( async def get_user_llm_instance(
session: AsyncSession, user_id: str, role: str session: AsyncSession, user_id: str, search_space_id: int, role: str
) -> ChatLiteLLM | None: ) -> ChatLiteLLM | None:
""" """
Get a ChatLiteLLM instance for a specific user and role. Get a ChatLiteLLM instance for a specific user, search space, and role.
Args: Args:
session: Database session session: Database session
user_id: User ID user_id: User ID
search_space_id: Search Space ID
role: LLM role ('long_context', 'fast', or 'strategic') role: LLM role ('long_context', 'fast', or 'strategic')
Returns: Returns:
ChatLiteLLM instance or None if not found ChatLiteLLM instance or None if not found
""" """
try: try:
# Get user with their LLM preferences # Get user's LLM preferences for this search space
result = await session.execute(select(User).where(User.id == user_id)) result = await session.execute(
user = result.scalars().first() select(UserSearchSpacePreference).where(
UserSearchSpacePreference.user_id == user_id,
UserSearchSpacePreference.search_space_id == search_space_id,
)
)
preference = result.scalars().first()
if not user: if not preference:
logger.error(f"User {user_id} not found") logger.error(
f"No LLM preferences found for user {user_id} in search space {search_space_id}"
)
return None return None
# Get the appropriate LLM config ID based on role # Get the appropriate LLM config ID based on role
llm_config_id = None llm_config_id = None
if role == LLMRole.LONG_CONTEXT: if role == LLMRole.LONG_CONTEXT:
llm_config_id = user.long_context_llm_id llm_config_id = preference.long_context_llm_id
elif role == LLMRole.FAST: elif role == LLMRole.FAST:
llm_config_id = user.fast_llm_id llm_config_id = preference.fast_llm_id
elif role == LLMRole.STRATEGIC: elif role == LLMRole.STRATEGIC:
llm_config_id = user.strategic_llm_id llm_config_id = preference.strategic_llm_id
else: else:
logger.error(f"Invalid LLM role: {role}") logger.error(f"Invalid LLM role: {role}")
return None return None
if not llm_config_id: if not llm_config_id:
logger.error(f"No {role} LLM configured for user {user_id}") logger.error(
f"No {role} LLM configured for user {user_id} in search space {search_space_id}"
)
return None return None
# Get the LLM configuration # Get the LLM configuration
result = await session.execute( result = await session.execute(
select(LLMConfig).where( select(LLMConfig).where(
LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id LLMConfig.id == llm_config_id,
LLMConfig.search_space_id == search_space_id,
) )
) )
llm_config = result.scalars().first() llm_config = result.scalars().first()
if not llm_config: if not llm_config:
logger.error(f"LLM config {llm_config_id} not found for user {user_id}") logger.error(
f"LLM config {llm_config_id} not found in search space {search_space_id}"
)
return None return None
# Build the model string for litellm # Build the model string for litellm
@ -113,19 +130,25 @@ async def get_user_llm_instance(
async def get_user_long_context_llm( async def get_user_long_context_llm(
session: AsyncSession, user_id: str session: AsyncSession, user_id: str, search_space_id: int
) -> ChatLiteLLM | None: ) -> ChatLiteLLM | None:
"""Get user's long context LLM instance.""" """Get user's long context LLM instance for a specific search space."""
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT) return await get_user_llm_instance(
session, user_id, search_space_id, LLMRole.LONG_CONTEXT
)
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None: async def get_user_fast_llm(
"""Get user's fast LLM instance.""" session: AsyncSession, user_id: str, search_space_id: int
return await get_user_llm_instance(session, user_id, LLMRole.FAST) ) -> ChatLiteLLM | None:
"""Get user's fast LLM instance for a specific search space."""
return await get_user_llm_instance(session, user_id, search_space_id, LLMRole.FAST)
async def get_user_strategic_llm( async def get_user_strategic_llm(
session: AsyncSession, user_id: str session: AsyncSession, user_id: str, search_space_id: int
) -> ChatLiteLLM | None: ) -> ChatLiteLLM | None:
"""Get user's strategic LLM instance.""" """Get user's strategic LLM instance for a specific search space."""
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC) return await get_user_llm_instance(
session, user_id, search_space_id, LLMRole.STRATEGIC
)

View file

@ -17,6 +17,7 @@ class QueryService:
user_query: str, user_query: str,
session: AsyncSession, session: AsyncSession,
user_id: str, user_id: str,
search_space_id: int,
chat_history_str: str | None = None, chat_history_str: str | None = None,
) -> str: ) -> str:
""" """
@ -27,6 +28,7 @@ class QueryService:
user_query: The original user query user_query: The original user query
session: Database session for accessing user LLM configs session: Database session for accessing user LLM configs
user_id: User ID to get their specific LLM configuration user_id: User ID to get their specific LLM configuration
search_space_id: Search Space ID to get user's LLM preferences
chat_history_str: Optional chat history string chat_history_str: Optional chat history string
Returns: Returns:
@ -37,10 +39,10 @@ class QueryService:
try: try:
# Get the user's strategic LLM instance # Get the user's strategic LLM instance
llm = await get_user_strategic_llm(session, user_id) llm = await get_user_strategic_llm(session, user_id, search_space_id)
if not llm: if not llm:
print( print(
f"Warning: No strategic LLM configured for user {user_id}. Using original query." f"Warning: No strategic LLM configured for user {user_id} in search space {search_space_id}. Using original query."
) )
return user_query return user_query

View file

@ -260,7 +260,9 @@ async def index_airtable_records(
continue continue
# Generate document summary # Generate document summary
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm: if user_llm:
document_metadata = { document_metadata = {

View file

@ -222,7 +222,9 @@ async def index_clickup_tasks(
continue continue
# Generate summary with metadata # Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm: if user_llm:
document_metadata = { document_metadata = {

View file

@ -233,7 +233,9 @@ async def index_confluence_pages(
continue continue
# Generate summary with metadata # Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
comment_count = len(comments) comment_count = len(comments)
if user_llm: if user_llm:

View file

@ -325,7 +325,9 @@ async def index_discord_messages(
continue continue
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if not user_llm: if not user_llm:
logger.error( logger.error(
f"No long context LLM configured for user {user_id}" f"No long context LLM configured for user {user_id}"

View file

@ -213,7 +213,9 @@ async def index_github_repos(
continue continue
# Generate summary with metadata # Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm: if user_llm:
# Extract file extension from file path # Extract file extension from file path
file_extension = ( file_extension = (

View file

@ -266,7 +266,9 @@ async def index_google_calendar_events(
continue continue
# Generate summary with metadata # Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm: if user_llm:
document_metadata = { document_metadata = {

View file

@ -210,7 +210,9 @@ async def index_google_gmail_messages(
continue continue
# Generate summary with metadata # Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm: if user_llm:
document_metadata = { document_metadata = {

View file

@ -216,7 +216,9 @@ async def index_jira_issues(
continue continue
# Generate summary with metadata # Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
comment_count = len(formatted_issue.get("comments", [])) comment_count = len(formatted_issue.get("comments", []))
if user_llm: if user_llm:

View file

@ -228,7 +228,9 @@ async def index_linear_issues(
continue continue
# Generate summary with metadata # Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
state = formatted_issue.get("state", "Unknown") state = formatted_issue.get("state", "Unknown")
description = formatted_issue.get("description", "") description = formatted_issue.get("description", "")
comment_count = len(formatted_issue.get("comments", [])) comment_count = len(formatted_issue.get("comments", []))

View file

@ -270,7 +270,9 @@ async def index_luma_events(
continue continue
# Generate summary with metadata # Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm: if user_llm:
document_metadata = { document_metadata = {

View file

@ -299,7 +299,9 @@ async def index_notion_pages(
continue continue
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if not user_llm: if not user_llm:
logger.error(f"No long context LLM configured for user {user_id}") logger.error(f"No long context LLM configured for user {user_id}")
skipped_pages.append(f"{page_title} (no LLM configured)") skipped_pages.append(f"{page_title} (no LLM configured)")

View file

@ -104,9 +104,11 @@ async def add_extension_received_document(
return existing_document return existing_document
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm: if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}") raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary with metadata # Generate summary with metadata
document_metadata = { document_metadata = {

View file

@ -60,9 +60,11 @@ async def add_received_file_document_using_unstructured(
# TODO: Check if file_markdown exceeds token limit of embedding model # TODO: Check if file_markdown exceeds token limit of embedding model
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm: if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}") raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary with metadata # Generate summary with metadata
document_metadata = { document_metadata = {
@ -140,9 +142,11 @@ async def add_received_file_document_using_llamacloud(
return existing_document return existing_document
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm: if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}") raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary with metadata # Generate summary with metadata
document_metadata = { document_metadata = {
@ -221,9 +225,11 @@ async def add_received_file_document_using_docling(
return existing_document return existing_document
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm: if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}") raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary using chunked processing for large documents # Generate summary using chunked processing for large documents
from app.services.docling_service import create_docling_service from app.services.docling_service import create_docling_service

View file

@ -75,9 +75,11 @@ async def add_received_markdown_file_document(
return existing_document return existing_document
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm: if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}") raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary with metadata # Generate summary with metadata
document_metadata = { document_metadata = {

View file

@ -161,9 +161,11 @@ async def add_crawled_url_document(
) )
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm: if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}") raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary # Generate summary
await task_logger.log_task_progress( await task_logger.log_task_progress(

View file

@ -234,9 +234,11 @@ async def add_youtube_video_document(
) )
# Get user's long context LLM # Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id) user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm: if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}") raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary # Generate summary
await task_logger.log_task_progress( await task_logger.log_task_progress(

View file

@ -98,6 +98,7 @@ async def generate_chat_podcast(
"configurable": { "configurable": {
"podcast_title": "SurfSense", "podcast_title": "SurfSense",
"user_id": str(user_id), "user_id": str(user_id),
"search_space_id": search_space_id,
} }
} }
# Initialize state with database session and streaming service # Initialize state with database session and streaming service

View file

@ -27,23 +27,17 @@ def validate_search_space_id(search_space_id: Any) -> int:
HTTPException: If validation fails HTTPException: If validation fails
""" """
if search_space_id is None: if search_space_id is None:
raise HTTPException( raise HTTPException(status_code=400, detail="search_space_id is required")
status_code=400,
detail="search_space_id is required"
)
if isinstance(search_space_id, bool): if isinstance(search_space_id, bool):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="search_space_id must be an integer, not a boolean"
detail="search_space_id must be an integer, not a boolean"
) )
if isinstance(search_space_id, int): if isinstance(search_space_id, int):
if search_space_id <= 0: if search_space_id <= 0:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="search_space_id must be a positive integer"
detail="search_space_id must be a positive integer"
) )
return search_space_id return search_space_id
@ -51,29 +45,27 @@ def validate_search_space_id(search_space_id: Any) -> int:
# Check if it's a valid integer string # Check if it's a valid integer string
if not search_space_id.strip(): if not search_space_id.strip():
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="search_space_id cannot be empty"
detail="search_space_id cannot be empty"
) )
# Check for valid integer format (no leading zeros, no decimal points) # Check for valid integer format (no leading zeros, no decimal points)
if not re.match(r'^[1-9]\d*$', search_space_id.strip()): if not re.match(r"^[1-9]\d*$", search_space_id.strip()):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="search_space_id must be a valid positive integer" detail="search_space_id must be a valid positive integer",
) )
value = int(search_space_id.strip()) value = int(search_space_id.strip())
# Regex already guarantees value > 0, but check retained for clarity # Regex already guarantees value > 0, but check retained for clarity
if value <= 0: if value <= 0:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="search_space_id must be a positive integer"
detail="search_space_id must be a positive integer"
) )
return value return value
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="search_space_id must be an integer or string representation of an integer" detail="search_space_id must be an integer or string representation of an integer",
) )
@ -95,8 +87,7 @@ def validate_document_ids(document_ids: Any) -> list[int]:
if not isinstance(document_ids, list): if not isinstance(document_ids, list):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="document_ids_to_add_in_context must be a list"
detail="document_ids_to_add_in_context must be a list"
) )
validated_ids = [] validated_ids = []
@ -111,20 +102,20 @@ def validate_document_ids(document_ids: Any) -> list[int]:
if doc_id <= 0: if doc_id <= 0:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer" detail=f"document_ids_to_add_in_context[{i}] must be a positive integer",
) )
validated_ids.append(doc_id) validated_ids.append(doc_id)
elif isinstance(doc_id, str): elif isinstance(doc_id, str):
if not doc_id.strip(): if not doc_id.strip():
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"document_ids_to_add_in_context[{i}] cannot be empty" detail=f"document_ids_to_add_in_context[{i}] cannot be empty",
) )
if not re.match(r'^[1-9]\d*$', doc_id.strip()): if not re.match(r"^[1-9]\d*$", doc_id.strip()):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer" detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer",
) )
value = int(doc_id.strip()) value = int(doc_id.strip())
@ -132,13 +123,13 @@ def validate_document_ids(document_ids: Any) -> list[int]:
if value <= 0: if value <= 0:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer" detail=f"document_ids_to_add_in_context[{i}] must be a positive integer",
) )
validated_ids.append(value) validated_ids.append(value)
else: else:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer" detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer",
) )
return validated_ids return validated_ids
@ -162,29 +153,26 @@ def validate_connectors(connectors: Any) -> list[str]:
if not isinstance(connectors, list): if not isinstance(connectors, list):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail="selected_connectors must be a list"
detail="selected_connectors must be a list"
) )
validated_connectors = [] validated_connectors = []
for i, connector in enumerate(connectors): for i, connector in enumerate(connectors):
if not isinstance(connector, str): if not isinstance(connector, str):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"selected_connectors[{i}] must be a string"
detail=f"selected_connectors[{i}] must be a string"
) )
if not connector.strip(): if not connector.strip():
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"selected_connectors[{i}] cannot be empty"
detail=f"selected_connectors[{i}] cannot be empty"
) )
trimmed = connector.strip() trimmed = connector.strip()
if not re.fullmatch(r'[\w\-_]+', trimmed): if not re.fullmatch(r"[\w\-_]+", trimmed):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"selected_connectors[{i}] contains invalid characters" detail=f"selected_connectors[{i}] contains invalid characters",
) )
validated_connectors.append(trimmed) validated_connectors.append(trimmed)
@ -208,22 +196,16 @@ def validate_research_mode(research_mode: Any) -> str:
return "QNA" # Default value return "QNA" # Default value
if not isinstance(research_mode, str): if not isinstance(research_mode, str):
raise HTTPException( raise HTTPException(status_code=400, detail="research_mode must be a string")
status_code=400,
detail="research_mode must be a string"
)
normalized_mode = research_mode.strip().upper() normalized_mode = research_mode.strip().upper()
if not normalized_mode: if not normalized_mode:
raise HTTPException( raise HTTPException(status_code=400, detail="research_mode cannot be empty")
status_code=400,
detail="research_mode cannot be empty"
)
valid_modes = ["REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER", "QNA"] valid_modes = ["REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER", "QNA"]
if normalized_mode not in valid_modes: if normalized_mode not in valid_modes:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"research_mode must be one of: {', '.join(valid_modes)}" detail=f"research_mode must be one of: {', '.join(valid_modes)}",
) )
return normalized_mode return normalized_mode
@ -245,22 +227,16 @@ def validate_search_mode(search_mode: Any) -> str:
return "CHUNKS" # Default value return "CHUNKS" # Default value
if not isinstance(search_mode, str): if not isinstance(search_mode, str):
raise HTTPException( raise HTTPException(status_code=400, detail="search_mode must be a string")
status_code=400,
detail="search_mode must be a string"
)
normalized_mode = search_mode.strip().upper() normalized_mode = search_mode.strip().upper()
if not normalized_mode: if not normalized_mode:
raise HTTPException( raise HTTPException(status_code=400, detail="search_mode cannot be empty")
status_code=400,
detail="search_mode cannot be empty"
)
valid_modes = ["CHUNKS", "DOCUMENTS"] valid_modes = ["CHUNKS", "DOCUMENTS"]
if normalized_mode not in valid_modes: if normalized_mode not in valid_modes:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"search_mode must be one of: {', '.join(valid_modes)}" detail=f"search_mode must be one of: {', '.join(valid_modes)}",
) )
return normalized_mode return normalized_mode
@ -279,55 +255,44 @@ def validate_messages(messages: Any) -> list[dict]:
HTTPException: If validation fails HTTPException: If validation fails
""" """
if not isinstance(messages, list): if not isinstance(messages, list):
raise HTTPException( raise HTTPException(status_code=400, detail="messages must be a list")
status_code=400,
detail="messages must be a list"
)
if not messages: if not messages:
raise HTTPException( raise HTTPException(status_code=400, detail="messages cannot be empty")
status_code=400,
detail="messages cannot be empty"
)
validated_messages = [] validated_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
if not isinstance(message, dict): if not isinstance(message, dict):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"messages[{i}] must be a dictionary"
detail=f"messages[{i}] must be a dictionary"
) )
if "role" not in message: if "role" not in message:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"messages[{i}] must have a 'role' field"
detail=f"messages[{i}] must have a 'role' field"
) )
if "content" not in message: if "content" not in message:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"messages[{i}] must have a 'content' field"
detail=f"messages[{i}] must have a 'content' field"
) )
role = message["role"] role = message["role"]
if not isinstance(role, str) or role not in ["user", "assistant", "system"]: if not isinstance(role, str) or role not in ["user", "assistant", "system"]:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'" detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'",
) )
content = message["content"] content = message["content"]
if not isinstance(content, str): if not isinstance(content, str):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"messages[{i}].content must be a string"
detail=f"messages[{i}].content must be a string"
) )
if not content.strip(): if not content.strip():
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail=f"messages[{i}].content cannot be empty"
detail=f"messages[{i}].content cannot be empty"
) )
# Trim content and enforce max length (10,000 chars) # Trim content and enforce max length (10,000 chars)
@ -335,13 +300,10 @@ def validate_messages(messages: Any) -> list[dict]:
if len(sanitized_content) > 10000: # Reasonable limit if len(sanitized_content) > 10000: # Reasonable limit
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"messages[{i}].content is too long (max 10000 characters)" detail=f"messages[{i}].content is too long (max 10000 characters)",
) )
validated_messages.append({ validated_messages.append({"role": role, "content": sanitized_content})
"role": role,
"content": sanitized_content
})
return validated_messages return validated_messages
@ -360,18 +322,12 @@ def validate_email(email: str) -> str:
HTTPException: If validation fails HTTPException: If validation fails
""" """
if not email or not email.strip(): if not email or not email.strip():
raise HTTPException( raise HTTPException(status_code=400, detail="Email address is required")
status_code=400,
detail="Email address is required"
)
email = email.strip() email = email.strip()
if not validators.email(email): if not validators.email(email):
raise HTTPException( raise HTTPException(status_code=400, detail="Invalid email address format")
status_code=400,
detail="Invalid email address format"
)
return email return email
@ -390,18 +346,12 @@ def validate_url(url: str) -> str:
HTTPException: If validation fails HTTPException: If validation fails
""" """
if not url or not url.strip(): if not url or not url.strip():
raise HTTPException( raise HTTPException(status_code=400, detail="URL is required")
status_code=400,
detail="URL is required"
)
url = url.strip() url = url.strip()
if not validators.url(url): if not validators.url(url):
raise HTTPException( raise HTTPException(status_code=400, detail="Invalid URL format")
status_code=400,
detail="Invalid URL format"
)
return url return url
@ -420,23 +370,19 @@ def validate_uuid(uuid_string: str) -> str:
HTTPException: If validation fails HTTPException: If validation fails
""" """
if not uuid_string or not uuid_string.strip(): if not uuid_string or not uuid_string.strip():
raise HTTPException( raise HTTPException(status_code=400, detail="UUID is required")
status_code=400,
detail="UUID is required"
)
uuid_string = uuid_string.strip() uuid_string = uuid_string.strip()
if not validators.uuid(uuid_string): if not validators.uuid(uuid_string):
raise HTTPException( raise HTTPException(status_code=400, detail="Invalid UUID format")
status_code=400,
detail="Invalid UUID format"
)
return uuid_string return uuid_string
def validate_connector_config(connector_type: str | Any, config: dict[str, Any]) -> dict[str, Any]: def validate_connector_config(
connector_type: str | Any, config: dict[str, Any]
) -> dict[str, Any]:
""" """
Validate connector configuration based on connector type. Validate connector configuration based on connector type.
@ -454,7 +400,11 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
raise ValueError("config must be a dictionary of connector settings") raise ValueError("config must be a dictionary of connector settings")
# Convert enum to string if needed # Convert enum to string if needed
connector_type_str = str(connector_type).split('.')[-1] if hasattr(connector_type, 'value') else str(connector_type) connector_type_str = (
str(connector_type).split(".")[-1]
if hasattr(connector_type, "value")
else str(connector_type)
)
# Validation function helpers # Validation function helpers
def validate_email_field(key: str, connector_name: str) -> None: def validate_email_field(key: str, connector_name: str) -> None:
@ -472,58 +422,47 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
# Lookup table for connector validation rules # Lookup table for connector validation rules
connector_rules = { connector_rules = {
"SERPER_API": { "SERPER_API": {"required": ["SERPER_API_KEY"], "validators": {}},
"required": ["SERPER_API_KEY"], "TAVILY_API": {"required": ["TAVILY_API_KEY"], "validators": {}},
"validators": {} "LINKUP_API": {"required": ["LINKUP_API_KEY"], "validators": {}},
}, "SLACK_CONNECTOR": {"required": ["SLACK_BOT_TOKEN"], "validators": {}},
"TAVILY_API": {
"required": ["TAVILY_API_KEY"],
"validators": {}
},
"LINKUP_API": {
"required": ["LINKUP_API_KEY"],
"validators": {}
},
"SLACK_CONNECTOR": {
"required": ["SLACK_BOT_TOKEN"],
"validators": {}
},
"NOTION_CONNECTOR": { "NOTION_CONNECTOR": {
"required": ["NOTION_INTEGRATION_TOKEN"], "required": ["NOTION_INTEGRATION_TOKEN"],
"validators": {} "validators": {},
}, },
"GITHUB_CONNECTOR": { "GITHUB_CONNECTOR": {
"required": ["GITHUB_PAT", "repo_full_names"], "required": ["GITHUB_PAT", "repo_full_names"],
"validators": { "validators": {
"repo_full_names": lambda: validate_list_field("repo_full_names", "repo_full_names") "repo_full_names": lambda: validate_list_field(
} "repo_full_names", "repo_full_names"
)
}, },
"LINEAR_CONNECTOR": {
"required": ["LINEAR_API_KEY"],
"validators": {}
},
"DISCORD_CONNECTOR": {
"required": ["DISCORD_BOT_TOKEN"],
"validators": {}
}, },
"LINEAR_CONNECTOR": {"required": ["LINEAR_API_KEY"], "validators": {}},
"DISCORD_CONNECTOR": {"required": ["DISCORD_BOT_TOKEN"], "validators": {}},
"JIRA_CONNECTOR": { "JIRA_CONNECTOR": {
"required": ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"], "required": ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"],
"validators": { "validators": {
"JIRA_EMAIL": lambda: validate_email_field("JIRA_EMAIL", "JIRA"), "JIRA_EMAIL": lambda: validate_email_field("JIRA_EMAIL", "JIRA"),
"JIRA_BASE_URL": lambda: validate_url_field("JIRA_BASE_URL", "JIRA") "JIRA_BASE_URL": lambda: validate_url_field("JIRA_BASE_URL", "JIRA"),
} },
}, },
"CONFLUENCE_CONNECTOR": { "CONFLUENCE_CONNECTOR": {
"required": ["CONFLUENCE_BASE_URL", "CONFLUENCE_EMAIL", "CONFLUENCE_API_TOKEN"], "required": [
"CONFLUENCE_BASE_URL",
"CONFLUENCE_EMAIL",
"CONFLUENCE_API_TOKEN",
],
"validators": { "validators": {
"CONFLUENCE_EMAIL": lambda: validate_email_field("CONFLUENCE_EMAIL", "Confluence"), "CONFLUENCE_EMAIL": lambda: validate_email_field(
"CONFLUENCE_BASE_URL": lambda: validate_url_field("CONFLUENCE_BASE_URL", "Confluence") "CONFLUENCE_EMAIL", "Confluence"
} ),
"CONFLUENCE_BASE_URL": lambda: validate_url_field(
"CONFLUENCE_BASE_URL", "Confluence"
),
}, },
"CLICKUP_CONNECTOR": {
"required": ["CLICKUP_API_TOKEN"],
"validators": {}
}, },
"CLICKUP_CONNECTOR": {"required": ["CLICKUP_API_TOKEN"], "validators": {}},
# "GOOGLE_CALENDAR_CONNECTOR": { # "GOOGLE_CALENDAR_CONNECTOR": {
# "required": ["token", "refresh_token", "token_uri", "client_id", "expiry", "scopes", "client_secret"], # "required": ["token", "refresh_token", "token_uri", "client_id", "expiry", "scopes", "client_secret"],
# "validators": {}, # "validators": {},
@ -538,10 +477,7 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
# "required": ["AIRTABLE_API_KEY", "AIRTABLE_BASE_ID"], # "required": ["AIRTABLE_API_KEY", "AIRTABLE_BASE_ID"],
# "validators": {} # "validators": {}
# }, # },
"LUMA_CONNECTOR": { "LUMA_CONNECTOR": {"required": ["LUMA_API_KEY"], "validators": {}},
"required": ["LUMA_API_KEY"],
"validators": {}
}
} }
rules = connector_rules.get(connector_type_str) rules = connector_rules.get(connector_type_str)

View file

@ -1,12 +1,16 @@
"use client"; "use client";
import { Loader2 } from "lucide-react";
import { usePathname, useRouter } from "next/navigation";
import type React from "react"; import type React from "react";
import { useState } from "react"; import { useEffect, useState } from "react";
import { DashboardBreadcrumb } from "@/components/dashboard-breadcrumb"; import { DashboardBreadcrumb } from "@/components/dashboard-breadcrumb";
import { AppSidebarProvider } from "@/components/sidebar/AppSidebarProvider"; import { AppSidebarProvider } from "@/components/sidebar/AppSidebarProvider";
import { ThemeTogglerComponent } from "@/components/theme/theme-toggle"; import { ThemeTogglerComponent } from "@/components/theme/theme-toggle";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { useLLMPreferences } from "@/hooks/use-llm-configs";
export function DashboardClientLayout({ export function DashboardClientLayout({
children, children,
@ -19,6 +23,16 @@ export function DashboardClientLayout({
navSecondary: any[]; navSecondary: any[];
navMain: any[]; navMain: any[];
}) { }) {
const router = useRouter();
const pathname = usePathname();
const searchSpaceIdNum = Number(searchSpaceId);
const { loading, error, isOnboardingComplete } = useLLMPreferences(searchSpaceIdNum);
const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false);
// Skip onboarding check if we're already on the onboarding page
const isOnboardingPage = pathname?.includes("/onboard");
const [open, setOpen] = useState<boolean>(() => { const [open, setOpen] = useState<boolean>(() => {
try { try {
const match = document.cookie.match(/(?:^|; )sidebar_state=([^;]+)/); const match = document.cookie.match(/(?:^|; )sidebar_state=([^;]+)/);
@ -29,6 +43,68 @@ export function DashboardClientLayout({
return true; return true;
}); });
useEffect(() => {
// Skip check if already on onboarding page
if (isOnboardingPage) {
setHasCheckedOnboarding(true);
return;
}
// Only check once after preferences have loaded
if (!loading && !hasCheckedOnboarding) {
const onboardingComplete = isOnboardingComplete();
if (!onboardingComplete) {
router.push(`/dashboard/${searchSpaceId}/onboard`);
}
setHasCheckedOnboarding(true);
}
}, [
loading,
isOnboardingComplete,
isOnboardingPage,
router,
searchSpaceId,
hasCheckedOnboarding,
]);
// Show loading screen while checking onboarding status (only on first load)
if (!hasCheckedOnboarding && loading && !isOnboardingPage) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">Loading Configuration</CardTitle>
<CardDescription>Checking your LLM preferences...</CardDescription>
</CardHeader>
<CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" />
</CardContent>
</Card>
</div>
);
}
// Show error screen if there's an error loading preferences (but not on onboarding page)
if (error && !hasCheckedOnboarding && !isOnboardingPage) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[400px] bg-background/60 backdrop-blur-sm border-destructive/20">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium text-destructive">
Configuration Error
</CardTitle>
<CardDescription>Failed to load your LLM configuration</CardDescription>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">{error}</p>
</CardContent>
</Card>
</div>
);
}
return ( return (
<SidebarProvider open={open} onOpenChange={setOpen}> <SidebarProvider open={open} onOpenChange={setOpen}>
{/* Use AppSidebarProvider which fetches user, search space, and recent chats */} {/* Use AppSidebarProvider which fetches user, search space, and recent chats */}

View file

@ -33,6 +33,12 @@ export default function DashboardLayout({
icon: "SquareTerminal", icon: "SquareTerminal",
items: [], items: [],
}, },
{
title: "Manage LLMs",
url: `/dashboard/${search_space_id}/settings`,
icon: "Settings2",
items: [],
},
{ {
title: "Documents", title: "Documents",

View file

@ -2,7 +2,7 @@
import { ArrowLeft, ArrowRight, Bot, CheckCircle, Sparkles } from "lucide-react"; import { ArrowLeft, ArrowRight, Bot, CheckCircle, Sparkles } from "lucide-react";
import { AnimatePresence, motion } from "motion/react"; import { AnimatePresence, motion } from "motion/react";
import { useRouter } from "next/navigation"; import { useParams, useRouter } from "next/navigation";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { Logo } from "@/components/Logo"; import { Logo } from "@/components/Logo";
import { AddProviderStep } from "@/components/onboard/add-provider-step"; import { AddProviderStep } from "@/components/onboard/add-provider-step";
@ -17,13 +17,16 @@ const TOTAL_STEPS = 3;
const OnboardPage = () => { const OnboardPage = () => {
const router = useRouter(); const router = useRouter();
const { llmConfigs, loading: configsLoading, refreshConfigs } = useLLMConfigs(); const params = useParams();
const searchSpaceId = Number(params.search_space_id);
const { llmConfigs, loading: configsLoading, refreshConfigs } = useLLMConfigs(searchSpaceId);
const { const {
preferences, preferences,
loading: preferencesLoading, loading: preferencesLoading,
isOnboardingComplete, isOnboardingComplete,
refreshPreferences, refreshPreferences,
} = useLLMPreferences(); } = useLLMPreferences(searchSpaceId);
const [currentStep, setCurrentStep] = useState(1); const [currentStep, setCurrentStep] = useState(1);
const [hasUserProgressed, setHasUserProgressed] = useState(false); const [hasUserProgressed, setHasUserProgressed] = useState(false);
@ -44,11 +47,23 @@ const OnboardPage = () => {
}, [currentStep]); }, [currentStep]);
// Redirect to dashboard if onboarding is already complete and user hasn't progressed (fresh page load) // Redirect to dashboard if onboarding is already complete and user hasn't progressed (fresh page load)
// But only check once to avoid redirect loops
useEffect(() => { useEffect(() => {
if (!preferencesLoading && isOnboardingComplete() && !hasUserProgressed) { if (!preferencesLoading && !configsLoading && isOnboardingComplete() && !hasUserProgressed) {
router.push("/dashboard"); // Small delay to ensure the check is stable
const timer = setTimeout(() => {
router.push(`/dashboard/${searchSpaceId}`);
}, 100);
return () => clearTimeout(timer);
} }
}, [preferencesLoading, isOnboardingComplete, hasUserProgressed, router]); }, [
preferencesLoading,
configsLoading,
isOnboardingComplete,
hasUserProgressed,
router,
searchSpaceId,
]);
const progress = (currentStep / TOTAL_STEPS) * 100; const progress = (currentStep / TOTAL_STEPS) * 100;
@ -80,7 +95,7 @@ const OnboardPage = () => {
}; };
const handleComplete = () => { const handleComplete = () => {
router.push("/dashboard"); router.push(`/dashboard/${searchSpaceId}/documents`);
}; };
if (configsLoading || preferencesLoading) { if (configsLoading || preferencesLoading) {
@ -184,12 +199,18 @@ const OnboardPage = () => {
> >
{currentStep === 1 && ( {currentStep === 1 && (
<AddProviderStep <AddProviderStep
searchSpaceId={searchSpaceId}
onConfigCreated={refreshConfigs} onConfigCreated={refreshConfigs}
onConfigDeleted={refreshConfigs} onConfigDeleted={refreshConfigs}
/> />
)} )}
{currentStep === 2 && <AssignRolesStep onPreferencesUpdated={refreshPreferences} />} {currentStep === 2 && (
{currentStep === 3 && <CompletionStep />} <AssignRolesStep
searchSpaceId={searchSpaceId}
onPreferencesUpdated={refreshPreferences}
/>
)}
{currentStep === 3 && <CompletionStep searchSpaceId={searchSpaceId} />}
</motion.div> </motion.div>
</AnimatePresence> </AnimatePresence>
</CardContent> </CardContent>

View file

@ -1,14 +1,16 @@
"use client"; "use client";
import { ArrowLeft, Bot, Brain, Settings } from "lucide-react"; // Import ArrowLeft icon import { ArrowLeft, Bot, Brain, Settings } from "lucide-react";
import { useRouter } from "next/navigation"; // Add this import import { useParams, useRouter } from "next/navigation";
import { LLMRoleManager } from "@/components/settings/llm-role-manager"; import { LLMRoleManager } from "@/components/settings/llm-role-manager";
import { ModelConfigManager } from "@/components/settings/model-config-manager"; import { ModelConfigManager } from "@/components/settings/model-config-manager";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
export default function SettingsPage() { export default function SettingsPage() {
const router = useRouter(); // Initialize router const router = useRouter();
const params = useParams();
const searchSpaceId = Number(params.search_space_id);
return ( return (
<div className="min-h-screen bg-background"> <div className="min-h-screen bg-background">
@ -19,7 +21,7 @@ export default function SettingsPage() {
<div className="flex items-center space-x-4"> <div className="flex items-center space-x-4">
{/* Back Button */} {/* Back Button */}
<button <button
onClick={() => router.push("/dashboard")} onClick={() => router.push(`/dashboard/${searchSpaceId}`)}
className="flex items-center justify-center h-10 w-10 rounded-lg bg-primary/10 hover:bg-primary/20 transition-colors" className="flex items-center justify-center h-10 w-10 rounded-lg bg-primary/10 hover:bg-primary/20 transition-colors"
aria-label="Back to Dashboard" aria-label="Back to Dashboard"
type="button" type="button"
@ -32,7 +34,7 @@ export default function SettingsPage() {
<div className="space-y-1"> <div className="space-y-1">
<h1 className="text-3xl font-bold tracking-tight">Settings</h1> <h1 className="text-3xl font-bold tracking-tight">Settings</h1>
<p className="text-lg text-muted-foreground"> <p className="text-lg text-muted-foreground">
Manage your LLM configurations and role assignments. Manage your LLM configurations and role assignments for this search space.
</p> </p>
</div> </div>
</div> </div>
@ -57,11 +59,11 @@ export default function SettingsPage() {
</div> </div>
<TabsContent value="models" className="space-y-6"> <TabsContent value="models" className="space-y-6">
<ModelConfigManager /> <ModelConfigManager searchSpaceId={searchSpaceId} />
</TabsContent> </TabsContent>
<TabsContent value="roles" className="space-y-6"> <TabsContent value="roles" className="space-y-6">
<LLMRoleManager /> <LLMRoleManager searchSpaceId={searchSpaceId} />
</TabsContent> </TabsContent>
</Tabs> </Tabs>
</div> </div>

View file

@ -4,7 +4,6 @@ import { Loader2 } from "lucide-react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { useLLMPreferences } from "@/hooks/use-llm-configs";
interface DashboardLayoutProps { interface DashboardLayoutProps {
children: React.ReactNode; children: React.ReactNode;
@ -12,7 +11,6 @@ interface DashboardLayoutProps {
export default function DashboardLayout({ children }: DashboardLayoutProps) { export default function DashboardLayout({ children }: DashboardLayoutProps) {
const router = useRouter(); const router = useRouter();
const { loading, error, isOnboardingComplete } = useLLMPreferences();
const [isCheckingAuth, setIsCheckingAuth] = useState(true); const [isCheckingAuth, setIsCheckingAuth] = useState(true);
useEffect(() => { useEffect(() => {
@ -25,23 +23,14 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
setIsCheckingAuth(false); setIsCheckingAuth(false);
}, [router]); }, [router]);
useEffect(() => { // Show loading screen while checking authentication
// Wait for preferences to load, then check if onboarding is complete if (isCheckingAuth) {
if (!loading && !error && !isCheckingAuth) {
if (!isOnboardingComplete()) {
router.push("/onboard");
}
}
}, [loading, error, isCheckingAuth, isOnboardingComplete, router]);
// Show loading screen while checking authentication or loading preferences
if (isCheckingAuth || loading) {
return ( return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4"> <div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm"> <Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2"> <CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">Loading Dashboard</CardTitle> <CardTitle className="text-xl font-medium">Loading Dashboard</CardTitle>
<CardDescription>Checking your configuration...</CardDescription> <CardDescription>Checking authentication...</CardDescription>
</CardHeader> </CardHeader>
<CardContent className="flex justify-center py-6"> <CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" /> <Loader2 className="h-12 w-12 text-primary animate-spin" />
@ -51,42 +40,5 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
); );
} }
// Show error screen if there's an error loading preferences
if (error) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[400px] bg-background/60 backdrop-blur-sm border-destructive/20">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium text-destructive">
Configuration Error
</CardTitle>
<CardDescription>Failed to load your LLM configuration</CardDescription>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">{error}</p>
</CardContent>
</Card>
</div>
);
}
// Only render children if onboarding is complete
if (isOnboardingComplete()) {
return <>{children}</>; return <>{children}</>;
}
// This should not be reached due to redirect, but just in case
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">Redirecting...</CardTitle>
<CardDescription>Taking you to complete your setup</CardDescription>
</CardHeader>
<CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" />
</CardContent>
</Card>
</div>
);
} }

View file

@ -66,10 +66,6 @@ export function UserDropdown({
</DropdownMenuItem> </DropdownMenuItem>
</DropdownMenuGroup> </DropdownMenuGroup>
<DropdownMenuSeparator /> <DropdownMenuSeparator />
<DropdownMenuItem onClick={() => router.push(`/settings`)}>
<Settings className="mr-2 h-4 w-4" />
Settings
</DropdownMenuItem>
<DropdownMenuItem onClick={handleLogout}> <DropdownMenuItem onClick={handleLogout}>
<LogOut className="mr-2 h-4 w-4" /> <LogOut className="mr-2 h-4 w-4" />
Log out Log out

View file

@ -332,8 +332,11 @@ const ResearchModeSelector = React.memo(
ResearchModeSelector.displayName = "ResearchModeSelector"; ResearchModeSelector.displayName = "ResearchModeSelector";
const LLMSelector = React.memo(() => { const LLMSelector = React.memo(() => {
const { llmConfigs, loading: llmLoading, error } = useLLMConfigs(); const { search_space_id } = useParams();
const { preferences, updatePreferences, loading: preferencesLoading } = useLLMPreferences(); const searchSpaceId = Number(search_space_id);
const { llmConfigs, loading: llmLoading, error } = useLLMConfigs(searchSpaceId);
const { preferences, updatePreferences, loading: preferencesLoading } = useLLMPreferences(searchSpaceId);
const isLoading = llmLoading || preferencesLoading; const isLoading = llmLoading || preferencesLoading;

View file

@ -23,12 +23,17 @@ import { type CreateLLMConfig, useLLMConfigs } from "@/hooks/use-llm-configs";
import InferenceParamsEditor from "../inference-params-editor"; import InferenceParamsEditor from "../inference-params-editor";
interface AddProviderStepProps { interface AddProviderStepProps {
searchSpaceId: number;
onConfigCreated?: () => void; onConfigCreated?: () => void;
onConfigDeleted?: () => void; onConfigDeleted?: () => void;
} }
export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProviderStepProps) { export function AddProviderStep({
const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs(); searchSpaceId,
onConfigCreated,
onConfigDeleted,
}: AddProviderStepProps) {
const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs(searchSpaceId);
const [isAddingNew, setIsAddingNew] = useState(false); const [isAddingNew, setIsAddingNew] = useState(false);
const [formData, setFormData] = useState<CreateLLMConfig>({ const [formData, setFormData] = useState<CreateLLMConfig>({
name: "", name: "",
@ -38,6 +43,7 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
api_key: "", api_key: "",
api_base: "", api_base: "",
litellm_params: {}, litellm_params: {},
search_space_id: searchSpaceId,
}); });
const [isSubmitting, setIsSubmitting] = useState(false); const [isSubmitting, setIsSubmitting] = useState(false);
@ -65,6 +71,7 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
api_key: "", api_key: "",
api_base: "", api_base: "",
litellm_params: {}, litellm_params: {},
search_space_id: searchSpaceId,
}); });
setIsAddingNew(false); setIsAddingNew(false);
// Notify parent component that a config was created // Notify parent component that a config was created
@ -253,7 +260,6 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
/> />
</div> </div>
<div className="flex gap-2 pt-4"> <div className="flex gap-2 pt-4">
<Button type="submit" disabled={isSubmitting}> <Button type="submit" disabled={isSubmitting}>
{isSubmitting ? "Adding..." : "Add Provider"} {isSubmitting ? "Adding..." : "Add Provider"}

View file

@ -41,12 +41,13 @@ const ROLE_DESCRIPTIONS = {
}; };
interface AssignRolesStepProps { interface AssignRolesStepProps {
searchSpaceId: number;
onPreferencesUpdated?: () => Promise<void>; onPreferencesUpdated?: () => Promise<void>;
} }
export function AssignRolesStep({ onPreferencesUpdated }: AssignRolesStepProps) { export function AssignRolesStep({ searchSpaceId, onPreferencesUpdated }: AssignRolesStepProps) {
const { llmConfigs } = useLLMConfigs(); const { llmConfigs } = useLLMConfigs(searchSpaceId);
const { preferences, updatePreferences } = useLLMPreferences(); const { preferences, updatePreferences } = useLLMPreferences(searchSpaceId);
const [assignments, setAssignments] = useState({ const [assignments, setAssignments] = useState({
long_context_llm_id: preferences.long_context_llm_id || "", long_context_llm_id: preferences.long_context_llm_id || "",

View file

@ -12,9 +12,13 @@ const ROLE_ICONS = {
strategic: Bot, strategic: Bot,
}; };
export function CompletionStep() { interface CompletionStepProps {
const { llmConfigs } = useLLMConfigs(); searchSpaceId: number;
const { preferences } = useLLMPreferences(); }
export function CompletionStep({ searchSpaceId }: CompletionStepProps) {
const { llmConfigs } = useLLMConfigs(searchSpaceId);
const { preferences } = useLLMPreferences(searchSpaceId);
const assignedConfigs = { const assignedConfigs = {
long_context: llmConfigs.find((c) => c.id === preferences.long_context_llm_id), long_context: llmConfigs.find((c) => c.id === preferences.long_context_llm_id),

View file

@ -56,20 +56,24 @@ const ROLE_DESCRIPTIONS = {
}, },
}; };
export function LLMRoleManager() { interface LLMRoleManagerProps {
searchSpaceId: number;
}
export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
const { const {
llmConfigs, llmConfigs,
loading: configsLoading, loading: configsLoading,
error: configsError, error: configsError,
refreshConfigs, refreshConfigs,
} = useLLMConfigs(); } = useLLMConfigs(searchSpaceId);
const { const {
preferences, preferences,
loading: preferencesLoading, loading: preferencesLoading,
error: preferencesError, error: preferencesError,
updatePreferences, updatePreferences,
refreshPreferences, refreshPreferences,
} = useLLMPreferences(); } = useLLMPreferences(searchSpaceId);
const [assignments, setAssignments] = useState({ const [assignments, setAssignments] = useState({
long_context_llm_id: preferences.long_context_llm_id || "", long_context_llm_id: preferences.long_context_llm_id || "",

View file

@ -41,7 +41,11 @@ import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers";
import { type CreateLLMConfig, type LLMConfig, useLLMConfigs } from "@/hooks/use-llm-configs"; import { type CreateLLMConfig, type LLMConfig, useLLMConfigs } from "@/hooks/use-llm-configs";
import InferenceParamsEditor from "../inference-params-editor"; import InferenceParamsEditor from "../inference-params-editor";
export function ModelConfigManager() { interface ModelConfigManagerProps {
searchSpaceId: number;
}
export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
const { const {
llmConfigs, llmConfigs,
loading, loading,
@ -50,7 +54,7 @@ export function ModelConfigManager() {
updateLLMConfig, updateLLMConfig,
deleteLLMConfig, deleteLLMConfig,
refreshConfigs, refreshConfigs,
} = useLLMConfigs(); } = useLLMConfigs(searchSpaceId);
const [isAddingNew, setIsAddingNew] = useState(false); const [isAddingNew, setIsAddingNew] = useState(false);
const [editingConfig, setEditingConfig] = useState<LLMConfig | null>(null); const [editingConfig, setEditingConfig] = useState<LLMConfig | null>(null);
const [showApiKey, setShowApiKey] = useState<Record<number, boolean>>({}); const [showApiKey, setShowApiKey] = useState<Record<number, boolean>>({});
@ -62,6 +66,7 @@ export function ModelConfigManager() {
api_key: "", api_key: "",
api_base: "", api_base: "",
litellm_params: {}, litellm_params: {},
search_space_id: searchSpaceId,
}); });
const [isSubmitting, setIsSubmitting] = useState(false); const [isSubmitting, setIsSubmitting] = useState(false);
@ -76,9 +81,10 @@ export function ModelConfigManager() {
api_key: editingConfig.api_key, api_key: editingConfig.api_key,
api_base: editingConfig.api_base || "", api_base: editingConfig.api_base || "",
litellm_params: editingConfig.litellm_params || {}, litellm_params: editingConfig.litellm_params || {},
search_space_id: searchSpaceId,
}); });
} }
}, [editingConfig]); }, [editingConfig, searchSpaceId]);
const handleInputChange = (field: keyof CreateLLMConfig, value: string) => { const handleInputChange = (field: keyof CreateLLMConfig, value: string) => {
setFormData((prev) => ({ ...prev, [field]: value })); setFormData((prev) => ({ ...prev, [field]: value }));
@ -113,6 +119,7 @@ export function ModelConfigManager() {
api_key: "", api_key: "",
api_base: "", api_base: "",
litellm_params: {}, litellm_params: {},
search_space_id: searchSpaceId,
}); });
setIsAddingNew(false); setIsAddingNew(false);
setEditingConfig(null); setEditingConfig(null);
@ -426,6 +433,7 @@ export function ModelConfigManager() {
api_key: "", api_key: "",
api_base: "", api_base: "",
litellm_params: {}, litellm_params: {},
search_space_id: searchSpaceId,
}); });
} }
}} }}
@ -462,18 +470,12 @@ export function ModelConfigManager() {
value={formData.provider} value={formData.provider}
onValueChange={(value) => handleInputChange("provider", value)} onValueChange={(value) => handleInputChange("provider", value)}
> >
<SelectTrigger className="h-auto min-h-[2.5rem] py-2"> <SelectTrigger>
<SelectValue placeholder="Select a provider"> <SelectValue placeholder="Select a provider">
{formData.provider && ( {formData.provider && (
<div className="flex items-center space-x-2 py-1"> <span className="font-medium">
<div className="font-medium">
{LLM_PROVIDERS.find((p) => p.value === formData.provider)?.label} {LLM_PROVIDERS.find((p) => p.value === formData.provider)?.label}
</div> </span>
<div className="text-xs text-muted-foreground"></div>
<div className="text-xs text-muted-foreground">
{LLM_PROVIDERS.find((p) => p.value === formData.provider)?.description}
</div>
</div>
)} )}
</SelectValue> </SelectValue>
</SelectTrigger> </SelectTrigger>
@ -578,6 +580,7 @@ export function ModelConfigManager() {
api_key: "", api_key: "",
api_base: "", api_base: "",
litellm_params: {}, litellm_params: {},
search_space_id: searchSpaceId,
}); });
}} }}
disabled={isSubmitting} disabled={isSubmitting}

View file

@ -12,7 +12,7 @@ export interface LLMConfig {
api_base?: string; api_base?: string;
litellm_params?: Record<string, any>; litellm_params?: Record<string, any>;
created_at: string; created_at: string;
user_id: string; search_space_id: number;
} }
export interface LLMPreferences { export interface LLMPreferences {
@ -32,6 +32,7 @@ export interface CreateLLMConfig {
api_key: string; api_key: string;
api_base?: string; api_base?: string;
litellm_params?: Record<string, any>; litellm_params?: Record<string, any>;
search_space_id: number;
} }
export interface UpdateLLMConfig { export interface UpdateLLMConfig {
@ -44,16 +45,21 @@ export interface UpdateLLMConfig {
litellm_params?: Record<string, any>; litellm_params?: Record<string, any>;
} }
export function useLLMConfigs() { export function useLLMConfigs(searchSpaceId: number | null) {
const [llmConfigs, setLlmConfigs] = useState<LLMConfig[]>([]); const [llmConfigs, setLlmConfigs] = useState<LLMConfig[]>([]);
const [loading, setLoading] = useState(true); const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const fetchLLMConfigs = async () => { const fetchLLMConfigs = async () => {
if (!searchSpaceId) {
setLoading(false);
return;
}
try { try {
setLoading(true); setLoading(true);
const response = await fetch( const response = await fetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`, `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/?search_space_id=${searchSpaceId}`,
{ {
headers: { headers: {
Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`, Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`,
@ -79,7 +85,7 @@ export function useLLMConfigs() {
useEffect(() => { useEffect(() => {
fetchLLMConfigs(); fetchLLMConfigs();
}, []); }, [searchSpaceId]);
const createLLMConfig = async (config: CreateLLMConfig): Promise<LLMConfig | null> => { const createLLMConfig = async (config: CreateLLMConfig): Promise<LLMConfig | null> => {
try { try {
@ -181,16 +187,21 @@ export function useLLMConfigs() {
}; };
} }
export function useLLMPreferences() { export function useLLMPreferences(searchSpaceId: number | null) {
const [preferences, setPreferences] = useState<LLMPreferences>({}); const [preferences, setPreferences] = useState<LLMPreferences>({});
const [loading, setLoading] = useState(true); const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const fetchPreferences = async () => { const fetchPreferences = async () => {
if (!searchSpaceId) {
setLoading(false);
return;
}
try { try {
setLoading(true); setLoading(true);
const response = await fetch( const response = await fetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/llm-preferences`,
{ {
headers: { headers: {
Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`, Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`,
@ -216,12 +227,17 @@ export function useLLMPreferences() {
useEffect(() => { useEffect(() => {
fetchPreferences(); fetchPreferences();
}, []); }, [searchSpaceId]);
const updatePreferences = async (newPreferences: Partial<LLMPreferences>): Promise<boolean> => { const updatePreferences = async (newPreferences: Partial<LLMPreferences>): Promise<boolean> => {
if (!searchSpaceId) {
toast.error("Search space ID is required");
return false;
}
try { try {
const response = await fetch( const response = await fetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`, `${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/llm-preferences`,
{ {
method: "PUT", method: "PUT",
headers: { headers: {