feat: move LLM configs from User to SearchSpaces

- Lays groundwork for future role-based access control (RBAC) on search spaces.
- Updated various services and routes to handle search space-specific LLM preferences.
- Modified frontend components to pass search space ID for LLM configuration management.
- Removed onboarding page and settings page as part of the refactor.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-10-10 00:50:29 -07:00
parent a1b1db3895
commit 633ea3ac0f
44 changed files with 1075 additions and 518 deletions

View file

@ -0,0 +1,352 @@
"""Migrate LLM configs to search spaces and add user preferences
Revision ID: 25
Revises: 24
Create Date: 2025-01-10 14:00:00.000000
Changes:
1. Migrate llm_configs from user association to search_space association
2. Create user_search_space_preferences table for per-user LLM preferences
3. Migrate existing user LLM preferences to user_search_space_preferences
4. Remove LLM preference columns from user table
"""
from collections.abc import Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "25"
down_revision: str | None = "24"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """
    Upgrade schema to support collaborative search spaces with per-user preferences.

    Migration steps:
    1. Add search_space_id to llm_configs
    2. Migrate existing llm_configs to first search space of their user
    3. Replace user_id with search_space_id in llm_configs
    4. Create user_search_space_preferences table
    5. Migrate user LLM preferences to user_search_space_preferences
    6. Remove LLM preference columns from user table
    """
    from sqlalchemy import inspect

    conn = op.get_bind()
    inspector = inspect(conn)

    # Get existing columns
    # Snapshot the pre-migration schema so each DDL step below can be skipped
    # if it already ran (makes the migration re-runnable after a partial failure).
    llm_config_columns = [col["name"] for col in inspector.get_columns("llm_configs")]
    user_columns = [col["name"] for col in inspector.get_columns("user")]

    # ===== STEP 1: Add search_space_id to llm_configs =====
    # Added nullable first; backfilled in STEP 2, tightened to NOT NULL in STEP 3.
    if "search_space_id" not in llm_config_columns:
        op.add_column(
            "llm_configs",
            sa.Column("search_space_id", sa.Integer(), nullable=True),
        )

    # ===== STEP 2: Populate search_space_id with user's first search space =====
    # This ensures existing LLM configs are assigned to a valid search space.
    # "First" = the user's oldest search space (ORDER BY created_at ASC).
    op.execute(
        """
        UPDATE llm_configs lc
        SET search_space_id = (
            SELECT id
            FROM searchspaces ss
            WHERE ss.user_id = lc.user_id
            ORDER BY ss.created_at ASC
            LIMIT 1
        )
        WHERE search_space_id IS NULL AND user_id IS NOT NULL
        """
    )

    # ===== STEP 3: Make search_space_id NOT NULL and add FK constraint =====
    # NOTE(review): this ALTER fails if any llm_configs row still has a NULL
    # search_space_id (owner has no search spaces, or user_id was NULL).
    # Confirm no such orphaned rows exist before deploying.
    op.alter_column(
        "llm_configs",
        "search_space_id",
        nullable=False,
    )
    # Add foreign key constraint
    foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("llm_configs")]
    if "fk_llm_configs_search_space_id" not in foreign_keys:
        op.create_foreign_key(
            "fk_llm_configs_search_space_id",
            "llm_configs",
            "searchspaces",
            ["search_space_id"],
            ["id"],
            ondelete="CASCADE",
        )
    # Drop old user_id foreign key if it exists
    if "fk_llm_configs_user_id_user" in foreign_keys:
        op.drop_constraint(
            "fk_llm_configs_user_id_user",
            "llm_configs",
            type_="foreignkey",
        )
    # Remove user_id column
    if "user_id" in llm_config_columns:
        op.drop_column("llm_configs", "user_id")

    # ===== STEP 4: Create user_search_space_preferences table =====
    # Raw DO-block DDL keeps this step idempotent (CREATE only if missing).
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT FROM information_schema.tables
                WHERE table_name = 'user_search_space_preferences'
            ) THEN
                CREATE TABLE user_search_space_preferences (
                    id SERIAL PRIMARY KEY,
                    created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
                    user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
                    search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
                    long_context_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
                    fast_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
                    strategic_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
                    CONSTRAINT uq_user_searchspace UNIQUE (user_id, search_space_id)
                );
            END IF;
        END$$;
        """
    )
    # Create indexes (guarded the same way for idempotency)
    op.execute(
        """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'user_search_space_preferences'
                AND indexname = 'ix_user_search_space_preferences_id'
            ) THEN
                CREATE INDEX ix_user_search_space_preferences_id
                ON user_search_space_preferences(id);
            END IF;
            IF NOT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE tablename = 'user_search_space_preferences'
                AND indexname = 'ix_user_search_space_preferences_created_at'
            ) THEN
                CREATE INDEX ix_user_search_space_preferences_created_at
                ON user_search_space_preferences(created_at);
            END IF;
        END$$;
        """
    )

    # ===== STEP 5: Migrate user LLM preferences to user_search_space_preferences =====
    # For each user, create preferences for each of their search spaces.
    # Guarded on the old user columns so a re-run after STEP 6 (columns
    # already dropped) is a no-op; ON CONFLICT makes the copy itself re-runnable.
    if all(
        col in user_columns
        for col in ["long_context_llm_id", "fast_llm_id", "strategic_llm_id"]
    ):
        op.execute(
            """
            INSERT INTO user_search_space_preferences
                (user_id, search_space_id, long_context_llm_id, fast_llm_id, strategic_llm_id, created_at)
            SELECT
                u.id as user_id,
                ss.id as search_space_id,
                u.long_context_llm_id,
                u.fast_llm_id,
                u.strategic_llm_id,
                NOW() as created_at
            FROM "user" u
            CROSS JOIN searchspaces ss
            WHERE ss.user_id = u.id
            ON CONFLICT (user_id, search_space_id) DO NOTHING
            """
        )

    # ===== STEP 6: Remove LLM preference columns from user table =====
    # Get fresh list of foreign keys after previous operations
    # NOTE(review): SQLAlchemy's Inspector caches reflection results per
    # instance; this is only "fresh" because the "user" FKs were not queried
    # earlier in this function — re-create the inspector if that ever changes.
    user_foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("user")]
    # Drop foreign key constraints if they exist
    if "fk_user_long_context_llm_id_llm_configs" in user_foreign_keys:
        op.drop_constraint(
            "fk_user_long_context_llm_id_llm_configs",
            "user",
            type_="foreignkey",
        )
    if "fk_user_fast_llm_id_llm_configs" in user_foreign_keys:
        op.drop_constraint(
            "fk_user_fast_llm_id_llm_configs",
            "user",
            type_="foreignkey",
        )
    if "fk_user_strategic_llm_id_llm_configs" in user_foreign_keys:
        op.drop_constraint(
            "fk_user_strategic_llm_id_llm_configs",
            "user",
            type_="foreignkey",
        )
    # Drop columns from user table
    if "long_context_llm_id" in user_columns:
        op.drop_column("user", "long_context_llm_id")
    if "fast_llm_id" in user_columns:
        op.drop_column("user", "fast_llm_id")
    if "strategic_llm_id" in user_columns:
        op.drop_column("user", "strategic_llm_id")
def downgrade() -> None:
    """
    Downgrade schema back to user-owned LLM configs.

    WARNING: This downgrade will result in data loss:
    - Per-search-space user preferences are consolidated to user level; only
      the oldest preference row per user (by created_at) survives.
    - Every LLM config is reassigned to the owner of the search space it
      lived in, so per-search-space scoping is lost.
    """
    from sqlalchemy import inspect

    conn = op.get_bind()
    inspector = inspect(conn)

    # Get existing columns and constraints (snapshot of the current schema so
    # the add/drop steps below can be skipped if already applied).
    llm_config_columns = [col["name"] for col in inspector.get_columns("llm_configs")]
    user_columns = [col["name"] for col in inspector.get_columns("user")]

    # ===== STEP 1: Add LLM preference columns back to user table =====
    if "long_context_llm_id" not in user_columns:
        op.add_column(
            "user",
            sa.Column("long_context_llm_id", sa.Integer(), nullable=True),
        )
    if "fast_llm_id" not in user_columns:
        op.add_column(
            "user",
            sa.Column("fast_llm_id", sa.Integer(), nullable=True),
        )
    if "strategic_llm_id" not in user_columns:
        op.add_column(
            "user",
            sa.Column("strategic_llm_id", sa.Integer(), nullable=True),
        )

    # ===== STEP 2: Migrate preferences back to user table =====
    # Take the first preference for each user (oldest by created_at, via
    # DISTINCT ON); the user's other per-search-space preferences are discarded.
    op.execute(
        """
        UPDATE "user" u
        SET
            long_context_llm_id = ussp.long_context_llm_id,
            fast_llm_id = ussp.fast_llm_id,
            strategic_llm_id = ussp.strategic_llm_id
        FROM (
            SELECT DISTINCT ON (user_id)
                user_id,
                long_context_llm_id,
                fast_llm_id,
                strategic_llm_id
            FROM user_search_space_preferences
            ORDER BY user_id, created_at ASC
        ) ussp
        WHERE u.id = ussp.user_id
        """
    )

    # ===== STEP 3: Add foreign key constraints back to user table =====
    # NOTE(review): unlike the guarded steps above, these creates are
    # unconditional — re-running a partially applied downgrade will fail here
    # if the constraints already exist.
    op.create_foreign_key(
        "fk_user_long_context_llm_id_llm_configs",
        "user",
        "llm_configs",
        ["long_context_llm_id"],
        ["id"],
        ondelete="SET NULL",
    )
    op.create_foreign_key(
        "fk_user_fast_llm_id_llm_configs",
        "user",
        "llm_configs",
        ["fast_llm_id"],
        ["id"],
        ondelete="SET NULL",
    )
    op.create_foreign_key(
        "fk_user_strategic_llm_id_llm_configs",
        "user",
        "llm_configs",
        ["strategic_llm_id"],
        ["id"],
        ondelete="SET NULL",
    )

    # ===== STEP 4: Drop user_search_space_preferences table =====
    op.execute("DROP TABLE IF EXISTS user_search_space_preferences CASCADE")

    # ===== STEP 5: Add user_id back to llm_configs =====
    if "user_id" not in llm_config_columns:
        op.add_column(
            "llm_configs",
            sa.Column("user_id", postgresql.UUID(), nullable=True),
        )
    # Populate user_id from search_space: each config is handed back to the
    # owner of the search space it lived in.
    op.execute(
        """
        UPDATE llm_configs lc
        SET user_id = ss.user_id
        FROM searchspaces ss
        WHERE lc.search_space_id = ss.id
        """
    )
    # Make user_id NOT NULL
    # NOTE(review): fails if any config's search_space_id has no matching
    # searchspaces row — confirm referential integrity before downgrading.
    op.alter_column(
        "llm_configs",
        "user_id",
        nullable=False,
    )
    # Add foreign key constraint for user_id
    op.create_foreign_key(
        "fk_llm_configs_user_id_user",
        "llm_configs",
        "user",
        ["user_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # ===== STEP 6: Remove search_space_id from llm_configs =====
    # Drop foreign key constraint
    foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("llm_configs")]
    if "fk_llm_configs_search_space_id" in foreign_keys:
        op.drop_constraint(
            "fk_llm_configs_search_space_id",
            "llm_configs",
            type_="foreignkey",
        )
    # Drop search_space_id column
    if "search_space_id" in llm_config_columns:
        op.drop_column("llm_configs", "search_space_id")

View file

@ -17,6 +17,7 @@ class Configuration:
# and when you invoke the graph
podcast_title: str
user_id: str
search_space_id: int
@classmethod
def from_runnable_config(

View file

@ -28,11 +28,12 @@ async def create_podcast_transcript(
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
user_id = configuration.user_id
search_space_id = configuration.search_space_id
# Get user's long context LLM
llm = await get_user_long_context_llm(state.db_session, user_id)
llm = await get_user_long_context_llm(state.db_session, user_id, search_space_id)
if not llm:
error_message = f"No long context LLM configured for user {user_id}"
error_message = f"No long context LLM configured for user {user_id} in search space {search_space_id}"
print(error_message)
raise RuntimeError(error_message)

View file

@ -577,6 +577,7 @@ async def write_answer_outline(
user_query = configuration.user_query
num_sections = configuration.num_sections
user_id = configuration.user_id
search_space_id = configuration.search_space_id
writer(
{
@ -587,9 +588,9 @@ async def write_answer_outline(
)
# Get user's strategic LLM
llm = await get_user_strategic_llm(state.db_session, user_id)
llm = await get_user_strategic_llm(state.db_session, user_id, search_space_id)
if not llm:
error_message = f"No strategic LLM configured for user {user_id}"
error_message = f"No strategic LLM configured for user {user_id} in search space {search_space_id}"
writer({"yield_value": streaming_service.format_error(error_message)})
raise RuntimeError(error_message)
@ -1854,6 +1855,7 @@ async def reformulate_user_query(
user_query=user_query,
session=state.db_session,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
chat_history_str=chat_history_str,
)
@ -2093,6 +2095,7 @@ async def generate_further_questions(
configuration = Configuration.from_runnable_config(config)
chat_history = state.chat_history
user_id = configuration.user_id
search_space_id = configuration.search_space_id
streaming_service = state.streaming_service
# Get reranked documents from the state (will be populated by sub-agents)
@ -2107,9 +2110,9 @@ async def generate_further_questions(
)
# Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id)
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
if not llm:
error_message = f"No fast LLM configured for user {user_id}"
error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
print(error_message)
writer({"yield_value": streaming_service.format_error(error_message)})

View file

@ -101,11 +101,12 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
documents = state.reranked_documents
user_query = configuration.user_query
user_id = configuration.user_id
search_space_id = configuration.search_space_id
# Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id)
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
if not llm:
error_message = f"No fast LLM configured for user {user_id}"
error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
print(error_message)
raise RuntimeError(error_message)

View file

@ -107,11 +107,12 @@ async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, A
configuration = Configuration.from_runnable_config(config)
documents = state.reranked_documents
user_id = configuration.user_id
search_space_id = configuration.search_space_id
# Get user's fast LLM
llm = await get_user_fast_llm(state.db_session, user_id)
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
if not llm:
error_message = f"No fast LLM configured for user {user_id}"
error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
print(error_message)
raise RuntimeError(error_message)

View file

@ -240,6 +240,17 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="SearchSourceConnector.id",
cascade="all, delete-orphan",
)
llm_configs = relationship(
"LLMConfig",
back_populates="search_space",
order_by="LLMConfig.id",
cascade="all, delete-orphan",
)
user_preferences = relationship(
"UserSearchSpacePreference",
back_populates="search_space",
cascade="all, delete-orphan",
)
class SearchSourceConnector(BaseModel, TimestampMixin):
@ -288,10 +299,54 @@ class LLMConfig(BaseModel, TimestampMixin):
# For any other parameters that litellm supports
litellm_params = Column(JSON, nullable=True, default={})
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="llm_configs")
class UserSearchSpacePreference(BaseModel, TimestampMixin):
    """Per-user LLM role preferences, scoped to a single search space.

    Each (user, search space) pair has at most one row (enforced by
    uq_user_searchspace) recording which LLMConfig that user wants for the
    long-context, fast, and strategic roles. All three preference ids are
    nullable: an unset role simply has no configured LLM.
    """

    __tablename__ = "user_search_space_preferences"
    __table_args__ = (
        UniqueConstraint(
            "user_id",
            "search_space_id",
            name="uq_user_searchspace",
        ),
    )

    user_id = Column(
        UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
    )
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )

    # User-specific LLM preferences for this search space.
    # ondelete="SET NULL": deleting an LLMConfig clears the preference rather
    # than deleting this row.
    long_context_llm_id = Column(
        Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
    )
    fast_llm_id = Column(
        Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
    )
    strategic_llm_id = Column(
        Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
    )

    # Future RBAC fields can be added here
    # role = Column(String(50), nullable=True)  # e.g., 'owner', 'editor', 'viewer'
    # permissions = Column(JSON, nullable=True)

    # FIX: removed a duplicate `user` relationship declared earlier with
    # back_populates="llm_configs" — User no longer defines `llm_configs`
    # (it was replaced by `search_space_preferences` in this refactor), and a
    # second assignment to the same class attribute silently discarded it.
    user = relationship("User", back_populates="search_space_preferences")
    search_space = relationship("SearchSpace", back_populates="user_preferences")
    # post_update=True defers these FK writes to a second UPDATE so flushes
    # don't deadlock on the dependency cycle with llm_configs.
    long_context_llm = relationship(
        "LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
    )
    fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
    strategic_llm = relationship(
        "LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
    )
class Log(BaseModel, TimestampMixin):
@ -321,64 +376,22 @@ if config.AUTH_TYPE == "GOOGLE":
"OAuthAccount", lazy="joined"
)
search_spaces = relationship("SearchSpace", back_populates="user")
llm_configs = relationship(
"LLMConfig",
search_space_preferences = relationship(
"UserSearchSpacePreference",
back_populates="user",
foreign_keys="LLMConfig.user_id",
cascade="all, delete-orphan",
)
long_context_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
fast_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
strategic_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
long_context_llm = relationship(
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
)
fast_llm = relationship(
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
)
strategic_llm = relationship(
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
)
else:
class User(SQLAlchemyBaseUserTableUUID, Base):
search_spaces = relationship("SearchSpace", back_populates="user")
llm_configs = relationship(
"LLMConfig",
search_space_preferences = relationship(
"UserSearchSpacePreference",
back_populates="user",
foreign_keys="LLMConfig.user_id",
cascade="all, delete-orphan",
)
long_context_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
fast_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
strategic_llm_id = Column(
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
)
long_context_llm = relationship(
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
)
fast_llm = relationship(
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
)
strategic_llm = relationship(
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
)
engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

View file

@ -17,19 +17,17 @@ from app.tasks.stream_connector_search_results import stream_connector_search_re
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from app.utils.validators import (
validate_search_space_id,
validate_document_ids,
validate_connectors,
validate_document_ids,
validate_messages,
validate_research_mode,
validate_search_mode,
validate_messages,
validate_search_space_id,
)
router = APIRouter()
@router.post("/chat")
async def handle_chat_data(
request: AISDKChatRequest,
@ -38,20 +36,22 @@ async def handle_chat_data(
):
# Validate and sanitize all input data
messages = validate_messages(request.messages)
if messages[-1]["role"] != "user":
raise HTTPException(
status_code=400, detail="Last message must be a user message"
)
user_query = messages[-1]["content"]
# Extract and validate data from request
request_data = request.data or {}
search_space_id = validate_search_space_id(request_data.get("search_space_id"))
research_mode = validate_research_mode(request_data.get("research_mode"))
selected_connectors = validate_connectors(request_data.get("selected_connectors"))
document_ids_to_add_in_context = validate_document_ids(request_data.get("document_ids_to_add_in_context"))
document_ids_to_add_in_context = validate_document_ids(
request_data.get("document_ids_to_add_in_context")
)
search_mode_str = validate_search_mode(request_data.get("search_mode"))
# Check if the search space belongs to the current user
@ -132,21 +132,16 @@ async def read_chats(
# Validate pagination parameters
if skip < 0:
raise HTTPException(
status_code=400,
detail="skip must be a non-negative integer"
status_code=400, detail="skip must be a non-negative integer"
)
if limit <= 0 or limit > 1000: # Reasonable upper limit
raise HTTPException(
status_code=400,
detail="limit must be between 1 and 1000"
)
raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
# Validate search_space_id if provided
if search_space_id is not None and search_space_id <= 0:
raise HTTPException(
status_code=400,
detail="search_space_id must be a positive integer"
status_code=400, detail="search_space_id must be a positive integer"
)
try:
# Select specific fields excluding messages

View file

@ -2,15 +2,72 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.db import LLMConfig, User, get_async_session
from app.db import (
LLMConfig,
SearchSpace,
User,
UserSearchSpacePreference,
get_async_session,
)
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
router = APIRouter()
# Helper function to check search space access
async def check_search_space_access(
    session: AsyncSession, search_space_id: int, user: User
) -> SearchSpace:
    """Return the search space if ``user`` may access it, else raise 404.

    Ownership is the only access rule: the row must both exist and belong to
    the requesting user. A single 404 covers "missing" and "not yours", so
    callers cannot probe for the existence of other users' search spaces.
    """
    stmt = select(SearchSpace).where(
        SearchSpace.id == search_space_id,
        SearchSpace.user_id == user.id,
    )
    found = (await session.execute(stmt)).scalars().first()
    if found is None:
        raise HTTPException(
            status_code=404,
            detail="Search space not found or you don't have permission to access it",
        )
    return found
# Helper function to get or create user search space preference
async def get_or_create_user_preference(
    session: AsyncSession, user_id, search_space_id: int
) -> UserSearchSpacePreference:
    """Get or create the user's preference row for a search space.

    Returns the existing ``UserSearchSpacePreference`` for
    (user_id, search_space_id) with its three LLM relationships eagerly
    loaded, creating an empty row (all LLM ids NULL) if none exists yet.

    FIX: the original create path raced — two concurrent requests could both
    observe "no row" and both INSERT, so the loser hit the uq_user_searchspace
    unique constraint and surfaced as a 500. The IntegrityError is now caught
    and the winner's row is returned instead.
    """
    # Local import: only needed on the rare insert-race path.
    from sqlalchemy.exc import IntegrityError

    # Eager-load the LLM relationships so callers can serialize them without
    # triggering lazy loads on the async session.
    query = (
        select(UserSearchSpacePreference)
        .filter(
            UserSearchSpacePreference.user_id == user_id,
            UserSearchSpacePreference.search_space_id == search_space_id,
        )
        .options(
            selectinload(UserSearchSpacePreference.long_context_llm),
            selectinload(UserSearchSpacePreference.fast_llm),
            selectinload(UserSearchSpacePreference.strategic_llm),
        )
    )
    result = await session.execute(query)
    preference = result.scalars().first()
    if preference:
        return preference

    # No row yet: create an empty preference entry.
    preference = UserSearchSpacePreference(
        user_id=user_id,
        search_space_id=search_space_id,
    )
    session.add(preference)
    try:
        await session.commit()
    except IntegrityError:
        # A concurrent request inserted the same (user, search space) row
        # first; roll back our failed insert and fetch theirs.
        await session.rollback()
        result = await session.execute(query)
        existing = result.scalars().first()
        if existing is None:
            # Not the duplicate-key race after all — propagate the real error.
            raise
        return existing
    await session.refresh(preference)
    return preference
class LLMPreferencesUpdate(BaseModel):
"""Schema for updating user LLM preferences"""
@ -36,9 +93,12 @@ async def create_llm_config(
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Create a new LLM configuration for the authenticated user"""
"""Create a new LLM configuration for a search space"""
try:
db_llm_config = LLMConfig(**llm_config.model_dump(), user_id=user.id)
# Verify user has access to the search space
await check_search_space_access(session, llm_config.search_space_id, user)
db_llm_config = LLMConfig(**llm_config.model_dump())
session.add(db_llm_config)
await session.commit()
await session.refresh(db_llm_config)
@ -54,20 +114,26 @@ async def create_llm_config(
@router.get("/llm-configs/", response_model=list[LLMConfigRead])
async def read_llm_configs(
search_space_id: int,
skip: int = 0,
limit: int = 200,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Get all LLM configurations for the authenticated user"""
"""Get all LLM configurations for a search space"""
try:
# Verify user has access to the search space
await check_search_space_access(session, search_space_id, user)
result = await session.execute(
select(LLMConfig)
.filter(LLMConfig.user_id == user.id)
.filter(LLMConfig.search_space_id == search_space_id)
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
@ -82,7 +148,18 @@ async def read_llm_config(
):
"""Get a specific LLM configuration by ID"""
try:
llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
# Get the LLM config
result = await session.execute(
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
)
llm_config = result.scalars().first()
if not llm_config:
raise HTTPException(status_code=404, detail="LLM configuration not found")
# Verify user has access to the search space
await check_search_space_access(session, llm_config.search_space_id, user)
return llm_config
except HTTPException:
raise
@ -101,7 +178,18 @@ async def update_llm_config(
):
"""Update an existing LLM configuration"""
try:
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
# Get the LLM config
result = await session.execute(
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
)
db_llm_config = result.scalars().first()
if not db_llm_config:
raise HTTPException(status_code=404, detail="LLM configuration not found")
# Verify user has access to the search space
await check_search_space_access(session, db_llm_config.search_space_id, user)
update_data = llm_config_update.model_dump(exclude_unset=True)
for key, value in update_data.items():
@ -127,7 +215,18 @@ async def delete_llm_config(
):
"""Delete an LLM configuration"""
try:
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
# Get the LLM config
result = await session.execute(
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
)
db_llm_config = result.scalars().first()
if not db_llm_config:
raise HTTPException(status_code=404, detail="LLM configuration not found")
# Verify user has access to the search space
await check_search_space_access(session, db_llm_config.search_space_id, user)
await session.delete(db_llm_config)
await session.commit()
return {"message": "LLM configuration deleted successfully"}
@ -143,99 +242,101 @@ async def delete_llm_config(
# User LLM Preferences endpoints
@router.get(
    "/search-spaces/{search_space_id}/llm-preferences",
    response_model=LLMPreferencesRead,
)
async def get_user_llm_preferences(
    search_space_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Get the current user's LLM preferences for a specific search space.

    FIX: removed pre-refactor residue left in the handler — a second,
    stale decorator still registering the removed /users/me/llm-preferences
    route, a session.refresh(user) call, and an unreachable block after the
    return that read user.long_context_llm_id / user.fast_llm_id /
    user.strategic_llm_id — columns this very commit drops from the user
    table, so that dead code could never run correctly again.
    """
    try:
        # Verify user has access to the search space (404 otherwise).
        await check_search_space_access(session, search_space_id, user)

        # Get or create user preference for this search space; relationships
        # are eagerly loaded by the helper.
        preference = await get_or_create_user_preference(
            session, user.id, search_space_id
        )

        return {
            "long_context_llm_id": preference.long_context_llm_id,
            "fast_llm_id": preference.fast_llm_id,
            "strategic_llm_id": preference.strategic_llm_id,
            "long_context_llm": preference.long_context_llm,
            "fast_llm": preference.fast_llm,
            "strategic_llm": preference.strategic_llm,
        }
    except HTTPException:
        # Re-raise access/validation errors untouched.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
        ) from e
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
@router.put(
"/search-spaces/{search_space_id}/llm-preferences",
response_model=LLMPreferencesRead,
)
async def update_user_llm_preferences(
search_space_id: int,
preferences: LLMPreferencesUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""Update the current user's LLM preferences"""
"""Update the current user's LLM preferences for a specific search space"""
try:
# Validate that all provided LLM config IDs belong to the user
# Verify user has access to the search space
await check_search_space_access(session, search_space_id, user)
# Get or create user preference for this search space
preference = await get_or_create_user_preference(
session, user.id, search_space_id
)
# Validate that all provided LLM config IDs belong to the search space
update_data = preferences.model_dump(exclude_unset=True)
for _key, llm_config_id in update_data.items():
if llm_config_id is not None:
# Verify ownership of the LLM config
# Verify the LLM config belongs to the search space
result = await session.execute(
select(LLMConfig).filter(
LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id
LLMConfig.id == llm_config_id,
LLMConfig.search_space_id == search_space_id,
)
)
llm_config = result.scalars().first()
if not llm_config:
raise HTTPException(
status_code=404,
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it",
detail=f"LLM configuration {llm_config_id} not found in this search space",
)
# Update user preferences
for key, value in update_data.items():
setattr(user, key, value)
setattr(preference, key, value)
await session.commit()
await session.refresh(user)
await session.refresh(preference)
# Reload relationships
await session.refresh(
preference, ["long_context_llm", "fast_llm", "strategic_llm"]
)
# Return updated preferences
return await get_user_llm_preferences(session, user)
return {
"long_context_llm_id": preference.long_context_llm_id,
"fast_llm_id": preference.fast_llm_id,
"strategic_llm_id": preference.strategic_llm_id,
"long_context_llm": preference.long_context_llm,
"fast_llm": preference.fast_llm,
"strategic_llm": preference.strategic_llm,
}
except HTTPException:
raise
except Exception as e:

View file

@ -1,4 +1,3 @@
import uuid
from datetime import datetime
from typing import Any
@ -30,7 +29,9 @@ class LLMConfigBase(BaseModel):
class LLMConfigCreate(LLMConfigBase):
    """Payload for creating an LLM configuration inside a search space."""

    # FIX: dropped the leftover `pass` from when this class had an empty body;
    # it was a dead statement once the field below was added.
    search_space_id: int = Field(
        ..., description="Search space ID to associate the LLM config with"
    )
class LLMConfigUpdate(BaseModel):
@ -56,6 +57,6 @@ class LLMConfigUpdate(BaseModel):
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
    """Serialized LLM configuration returned by the API.

    FIX: removed the stale ``user_id: uuid.UUID`` field — llm_configs are now
    owned by a search space (user_id is dropped by migration 25), and the
    ``uuid`` import was removed from this module in the same commit, so the
    leftover field would raise NameError at import time.
    """

    id: int
    created_at: datetime
    search_space_id: int

    model_config = ConfigDict(from_attributes=True)

View file

@ -24,4 +24,4 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
created_at: datetime
user_id: uuid.UUID
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True)

View file

@ -1,10 +1,14 @@
import logging
import litellm
from langchain_litellm import ChatLiteLLM
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import LLMConfig, User
from app.db import LLMConfig, UserSearchSpacePreference
# Configure litellm to automatically drop unsupported parameters
litellm.drop_params = True
logger = logging.getLogger(__name__)
@ -16,54 +20,67 @@ class LLMRole:
async def get_user_llm_instance(
session: AsyncSession, user_id: str, role: str
session: AsyncSession, user_id: str, search_space_id: int, role: str
) -> ChatLiteLLM | None:
"""
Get a ChatLiteLLM instance for a specific user and role.
Get a ChatLiteLLM instance for a specific user, search space, and role.
Args:
session: Database session
user_id: User ID
search_space_id: Search Space ID
role: LLM role ('long_context', 'fast', or 'strategic')
Returns:
ChatLiteLLM instance or None if not found
"""
try:
# Get user with their LLM preferences
result = await session.execute(select(User).where(User.id == user_id))
user = result.scalars().first()
# Get user's LLM preferences for this search space
result = await session.execute(
select(UserSearchSpacePreference).where(
UserSearchSpacePreference.user_id == user_id,
UserSearchSpacePreference.search_space_id == search_space_id,
)
)
preference = result.scalars().first()
if not user:
logger.error(f"User {user_id} not found")
if not preference:
logger.error(
f"No LLM preferences found for user {user_id} in search space {search_space_id}"
)
return None
# Get the appropriate LLM config ID based on role
llm_config_id = None
if role == LLMRole.LONG_CONTEXT:
llm_config_id = user.long_context_llm_id
llm_config_id = preference.long_context_llm_id
elif role == LLMRole.FAST:
llm_config_id = user.fast_llm_id
llm_config_id = preference.fast_llm_id
elif role == LLMRole.STRATEGIC:
llm_config_id = user.strategic_llm_id
llm_config_id = preference.strategic_llm_id
else:
logger.error(f"Invalid LLM role: {role}")
return None
if not llm_config_id:
logger.error(f"No {role} LLM configured for user {user_id}")
logger.error(
f"No {role} LLM configured for user {user_id} in search space {search_space_id}"
)
return None
# Get the LLM configuration
result = await session.execute(
select(LLMConfig).where(
LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id
LLMConfig.id == llm_config_id,
LLMConfig.search_space_id == search_space_id,
)
)
llm_config = result.scalars().first()
if not llm_config:
logger.error(f"LLM config {llm_config_id} not found for user {user_id}")
logger.error(
f"LLM config {llm_config_id} not found in search space {search_space_id}"
)
return None
# Build the model string for litellm
@ -113,19 +130,25 @@ async def get_user_llm_instance(
async def get_user_long_context_llm(
session: AsyncSession, user_id: str
session: AsyncSession, user_id: str, search_space_id: int
) -> ChatLiteLLM | None:
"""Get user's long context LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
"""Get user's long context LLM instance for a specific search space."""
return await get_user_llm_instance(
session, user_id, search_space_id, LLMRole.LONG_CONTEXT
)
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None:
"""Get user's fast LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.FAST)
async def get_user_fast_llm(
session: AsyncSession, user_id: str, search_space_id: int
) -> ChatLiteLLM | None:
"""Get user's fast LLM instance for a specific search space."""
return await get_user_llm_instance(session, user_id, search_space_id, LLMRole.FAST)
async def get_user_strategic_llm(
session: AsyncSession, user_id: str
session: AsyncSession, user_id: str, search_space_id: int
) -> ChatLiteLLM | None:
"""Get user's strategic LLM instance."""
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
"""Get user's strategic LLM instance for a specific search space."""
return await get_user_llm_instance(
session, user_id, search_space_id, LLMRole.STRATEGIC
)

View file

@ -17,6 +17,7 @@ class QueryService:
user_query: str,
session: AsyncSession,
user_id: str,
search_space_id: int,
chat_history_str: str | None = None,
) -> str:
"""
@ -27,6 +28,7 @@ class QueryService:
user_query: The original user query
session: Database session for accessing user LLM configs
user_id: User ID to get their specific LLM configuration
search_space_id: Search Space ID to get user's LLM preferences
chat_history_str: Optional chat history string
Returns:
@ -37,10 +39,10 @@ class QueryService:
try:
# Get the user's strategic LLM instance
llm = await get_user_strategic_llm(session, user_id)
llm = await get_user_strategic_llm(session, user_id, search_space_id)
if not llm:
print(
f"Warning: No strategic LLM configured for user {user_id}. Using original query."
f"Warning: No strategic LLM configured for user {user_id} in search space {search_space_id}. Using original query."
)
return user_query

View file

@ -260,7 +260,9 @@ async def index_airtable_records(
continue
# Generate document summary
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm:
document_metadata = {

View file

@ -222,7 +222,9 @@ async def index_clickup_tasks(
continue
# Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm:
document_metadata = {

View file

@ -233,7 +233,9 @@ async def index_confluence_pages(
continue
# Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
comment_count = len(comments)
if user_llm:

View file

@ -325,7 +325,9 @@ async def index_discord_messages(
continue
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if not user_llm:
logger.error(
f"No long context LLM configured for user {user_id}"

View file

@ -213,7 +213,9 @@ async def index_github_repos(
continue
# Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm:
# Extract file extension from file path
file_extension = (

View file

@ -266,7 +266,9 @@ async def index_google_calendar_events(
continue
# Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm:
document_metadata = {

View file

@ -210,7 +210,9 @@ async def index_google_gmail_messages(
continue
# Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm:
document_metadata = {

View file

@ -216,7 +216,9 @@ async def index_jira_issues(
continue
# Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
comment_count = len(formatted_issue.get("comments", []))
if user_llm:

View file

@ -228,7 +228,9 @@ async def index_linear_issues(
continue
# Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
state = formatted_issue.get("state", "Unknown")
description = formatted_issue.get("description", "")
comment_count = len(formatted_issue.get("comments", []))

View file

@ -270,7 +270,9 @@ async def index_luma_events(
continue
# Generate summary with metadata
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if user_llm:
document_metadata = {

View file

@ -299,7 +299,9 @@ async def index_notion_pages(
continue
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(
session, user_id, search_space_id
)
if not user_llm:
logger.error(f"No long context LLM configured for user {user_id}")
skipped_pages.append(f"{page_title} (no LLM configured)")

View file

@ -104,9 +104,11 @@ async def add_extension_received_document(
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary with metadata
document_metadata = {

View file

@ -60,9 +60,11 @@ async def add_received_file_document_using_unstructured(
# TODO: Check if file_markdown exceeds token limit of embedding model
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary with metadata
document_metadata = {
@ -140,9 +142,11 @@ async def add_received_file_document_using_llamacloud(
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary with metadata
document_metadata = {
@ -221,9 +225,11 @@ async def add_received_file_document_using_docling(
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary using chunked processing for large documents
from app.services.docling_service import create_docling_service

View file

@ -75,9 +75,11 @@ async def add_received_markdown_file_document(
return existing_document
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary with metadata
document_metadata = {

View file

@ -161,9 +161,11 @@ async def add_crawled_url_document(
)
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary
await task_logger.log_task_progress(

View file

@ -234,9 +234,11 @@ async def add_youtube_video_document(
)
# Get user's long context LLM
user_llm = await get_user_long_context_llm(session, user_id)
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
if not user_llm:
raise RuntimeError(f"No long context LLM configured for user {user_id}")
raise RuntimeError(
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
)
# Generate summary
await task_logger.log_task_progress(

View file

@ -98,6 +98,7 @@ async def generate_chat_podcast(
"configurable": {
"podcast_title": "SurfSense",
"user_id": str(user_id),
"search_space_id": search_space_id,
}
}
# Initialize state with database session and streaming service

View file

@ -16,89 +16,80 @@ from fastapi import HTTPException
def validate_search_space_id(search_space_id: Any) -> int:
"""
Validate and convert search_space_id to integer.
Args:
search_space_id: The search space ID to validate
Returns:
int: Validated search space ID
Raises:
HTTPException: If validation fails
"""
if search_space_id is None:
raise HTTPException(
status_code=400,
detail="search_space_id is required"
)
raise HTTPException(status_code=400, detail="search_space_id is required")
if isinstance(search_space_id, bool):
raise HTTPException(
status_code=400,
detail="search_space_id must be an integer, not a boolean"
status_code=400, detail="search_space_id must be an integer, not a boolean"
)
if isinstance(search_space_id, int):
if search_space_id <= 0:
raise HTTPException(
status_code=400,
detail="search_space_id must be a positive integer"
status_code=400, detail="search_space_id must be a positive integer"
)
return search_space_id
if isinstance(search_space_id, str):
# Check if it's a valid integer string
if not search_space_id.strip():
raise HTTPException(
status_code=400,
detail="search_space_id cannot be empty"
status_code=400, detail="search_space_id cannot be empty"
)
# Check for valid integer format (no leading zeros, no decimal points)
if not re.match(r'^[1-9]\d*$', search_space_id.strip()):
if not re.match(r"^[1-9]\d*$", search_space_id.strip()):
raise HTTPException(
status_code=400,
detail="search_space_id must be a valid positive integer"
detail="search_space_id must be a valid positive integer",
)
value = int(search_space_id.strip())
# Regex already guarantees value > 0, but check retained for clarity
if value <= 0:
raise HTTPException(
status_code=400,
detail="search_space_id must be a positive integer"
status_code=400, detail="search_space_id must be a positive integer"
)
return value
raise HTTPException(
status_code=400,
detail="search_space_id must be an integer or string representation of an integer"
detail="search_space_id must be an integer or string representation of an integer",
)
def validate_document_ids(document_ids: Any) -> list[int]:
"""
Validate and convert document_ids to list of integers.
Args:
document_ids: The document IDs to validate
Returns:
List[int]: Validated list of document IDs
Raises:
HTTPException: If validation fails
"""
if document_ids is None:
return []
if not isinstance(document_ids, list):
raise HTTPException(
status_code=400,
detail="document_ids_to_add_in_context must be a list"
status_code=400, detail="document_ids_to_add_in_context must be a list"
)
validated_ids = []
for i, doc_id in enumerate(document_ids):
if isinstance(doc_id, bool):
@ -111,119 +102,110 @@ def validate_document_ids(document_ids: Any) -> list[int]:
if doc_id <= 0:
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer"
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer",
)
validated_ids.append(doc_id)
elif isinstance(doc_id, str):
if not doc_id.strip():
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] cannot be empty"
detail=f"document_ids_to_add_in_context[{i}] cannot be empty",
)
if not re.match(r'^[1-9]\d*$', doc_id.strip()):
if not re.match(r"^[1-9]\d*$", doc_id.strip()):
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer"
detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer",
)
value = int(doc_id.strip())
# Regex already guarantees value > 0
if value <= 0:
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer"
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer",
)
validated_ids.append(value)
else:
raise HTTPException(
status_code=400,
detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer"
detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer",
)
return validated_ids
def validate_connectors(connectors: Any) -> list[str]:
"""
Validate selected_connectors list.
Args:
connectors: The connectors to validate
Returns:
List[str]: Validated list of connector names
Raises:
HTTPException: If validation fails
"""
if connectors is None:
return []
if not isinstance(connectors, list):
raise HTTPException(
status_code=400,
detail="selected_connectors must be a list"
status_code=400, detail="selected_connectors must be a list"
)
validated_connectors = []
for i, connector in enumerate(connectors):
if not isinstance(connector, str):
raise HTTPException(
status_code=400,
detail=f"selected_connectors[{i}] must be a string"
status_code=400, detail=f"selected_connectors[{i}] must be a string"
)
if not connector.strip():
raise HTTPException(
status_code=400,
detail=f"selected_connectors[{i}] cannot be empty"
status_code=400, detail=f"selected_connectors[{i}] cannot be empty"
)
trimmed = connector.strip()
if not re.fullmatch(r'[\w\-_]+', trimmed):
if not re.fullmatch(r"[\w\-_]+", trimmed):
raise HTTPException(
status_code=400,
detail=f"selected_connectors[{i}] contains invalid characters"
detail=f"selected_connectors[{i}] contains invalid characters",
)
validated_connectors.append(trimmed)
return validated_connectors
def validate_research_mode(research_mode: Any) -> str:
"""
Validate research_mode parameter.
Args:
research_mode: The research mode to validate
Returns:
str: Validated research mode
Raises:
HTTPException: If validation fails
"""
if research_mode is None:
return "QNA" # Default value
if not isinstance(research_mode, str):
raise HTTPException(
status_code=400,
detail="research_mode must be a string"
)
raise HTTPException(status_code=400, detail="research_mode must be a string")
normalized_mode = research_mode.strip().upper()
if not normalized_mode:
raise HTTPException(
status_code=400,
detail="research_mode cannot be empty"
)
raise HTTPException(status_code=400, detail="research_mode cannot be empty")
valid_modes = ["REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER", "QNA"]
if normalized_mode not in valid_modes:
raise HTTPException(
status_code=400,
detail=f"research_mode must be one of: {', '.join(valid_modes)}"
detail=f"research_mode must be one of: {', '.join(valid_modes)}",
)
return normalized_mode
@ -231,36 +213,30 @@ def validate_research_mode(research_mode: Any) -> str:
def validate_search_mode(search_mode: Any) -> str:
"""
Validate search_mode parameter.
Args:
search_mode: The search mode to validate
Returns:
str: Validated search mode
Raises:
HTTPException: If validation fails
"""
if search_mode is None:
return "CHUNKS" # Default value
if not isinstance(search_mode, str):
raise HTTPException(
status_code=400,
detail="search_mode must be a string"
)
raise HTTPException(status_code=400, detail="search_mode must be a string")
normalized_mode = search_mode.strip().upper()
if not normalized_mode:
raise HTTPException(
status_code=400,
detail="search_mode cannot be empty"
)
raise HTTPException(status_code=400, detail="search_mode cannot be empty")
valid_modes = ["CHUNKS", "DOCUMENTS"]
if normalized_mode not in valid_modes:
raise HTTPException(
status_code=400,
detail=f"search_mode must be one of: {', '.join(valid_modes)}"
detail=f"search_mode must be one of: {', '.join(valid_modes)}",
)
return normalized_mode
@ -268,185 +244,155 @@ def validate_search_mode(search_mode: Any) -> str:
def validate_messages(messages: Any) -> list[dict]:
"""
Validate messages structure.
Args:
messages: The messages to validate
Returns:
List[dict]: Validated messages
Raises:
HTTPException: If validation fails
"""
if not isinstance(messages, list):
raise HTTPException(
status_code=400,
detail="messages must be a list"
)
raise HTTPException(status_code=400, detail="messages must be a list")
if not messages:
raise HTTPException(
status_code=400,
detail="messages cannot be empty"
)
raise HTTPException(status_code=400, detail="messages cannot be empty")
validated_messages = []
for i, message in enumerate(messages):
if not isinstance(message, dict):
raise HTTPException(
status_code=400,
detail=f"messages[{i}] must be a dictionary"
status_code=400, detail=f"messages[{i}] must be a dictionary"
)
if "role" not in message:
raise HTTPException(
status_code=400,
detail=f"messages[{i}] must have a 'role' field"
status_code=400, detail=f"messages[{i}] must have a 'role' field"
)
if "content" not in message:
raise HTTPException(
status_code=400,
detail=f"messages[{i}] must have a 'content' field"
status_code=400, detail=f"messages[{i}] must have a 'content' field"
)
role = message["role"]
if not isinstance(role, str) or role not in ["user", "assistant", "system"]:
raise HTTPException(
status_code=400,
detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'"
detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'",
)
content = message["content"]
if not isinstance(content, str):
raise HTTPException(
status_code=400,
detail=f"messages[{i}].content must be a string"
status_code=400, detail=f"messages[{i}].content must be a string"
)
if not content.strip():
raise HTTPException(
status_code=400,
detail=f"messages[{i}].content cannot be empty"
status_code=400, detail=f"messages[{i}].content cannot be empty"
)
# Trim content and enforce max length (10,000 chars)
sanitized_content = content.strip()
if len(sanitized_content) > 10000: # Reasonable limit
raise HTTPException(
status_code=400,
detail=f"messages[{i}].content is too long (max 10000 characters)"
detail=f"messages[{i}].content is too long (max 10000 characters)",
)
validated_messages.append({
"role": role,
"content": sanitized_content
})
validated_messages.append({"role": role, "content": sanitized_content})
return validated_messages
def validate_email(email: str) -> str:
"""
Validate email address using the `validators` library.
Args:
email: The email address to validate
Returns:
str: Validated email address
Raises:
HTTPException: If validation fails
"""
if not email or not email.strip():
raise HTTPException(
status_code=400,
detail="Email address is required"
)
raise HTTPException(status_code=400, detail="Email address is required")
email = email.strip()
if not validators.email(email):
raise HTTPException(
status_code=400,
detail="Invalid email address format"
)
raise HTTPException(status_code=400, detail="Invalid email address format")
return email
def validate_url(url: str) -> str:
"""
Validate URL using the `validators` library.
Args:
url: The URL to validate
Returns:
str: Validated URL
Raises:
HTTPException: If validation fails
"""
if not url or not url.strip():
raise HTTPException(
status_code=400,
detail="URL is required"
)
raise HTTPException(status_code=400, detail="URL is required")
url = url.strip()
if not validators.url(url):
raise HTTPException(
status_code=400,
detail="Invalid URL format"
)
raise HTTPException(status_code=400, detail="Invalid URL format")
return url
def validate_uuid(uuid_string: str) -> str:
"""
Validate UUID using the `validators` library.
Args:
uuid_string: The UUID string to validate
Returns:
str: Validated UUID string
Raises:
HTTPException: If validation fails
"""
if not uuid_string or not uuid_string.strip():
raise HTTPException(
status_code=400,
detail="UUID is required"
)
raise HTTPException(status_code=400, detail="UUID is required")
uuid_string = uuid_string.strip()
if not validators.uuid(uuid_string):
raise HTTPException(
status_code=400,
detail="Invalid UUID format"
)
raise HTTPException(status_code=400, detail="Invalid UUID format")
return uuid_string
def validate_connector_config(connector_type: str | Any, config: dict[str, Any]) -> dict[str, Any]:
def validate_connector_config(
connector_type: str | Any, config: dict[str, Any]
) -> dict[str, Any]:
"""
Validate connector configuration based on connector type.
Args:
connector_type: The type of connector (string or enum)
config: The configuration dictionary to validate
Returns:
dict: Validated configuration
Raises:
ValueError: If validation fails
"""
@ -454,76 +400,69 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
raise ValueError("config must be a dictionary of connector settings")
# Convert enum to string if needed
connector_type_str = str(connector_type).split('.')[-1] if hasattr(connector_type, 'value') else str(connector_type)
connector_type_str = (
str(connector_type).split(".")[-1]
if hasattr(connector_type, "value")
else str(connector_type)
)
# Validation function helpers
def validate_email_field(key: str, connector_name: str) -> None:
if not validators.email(config.get(key, "")):
raise ValueError(f"Invalid email format for {connector_name} connector")
def validate_url_field(key: str, connector_name: str) -> None:
if not validators.url(config.get(key, "")):
raise ValueError(f"Invalid base URL format for {connector_name} connector")
def validate_list_field(key: str, field_name: str) -> None:
value = config.get(key)
if not isinstance(value, list) or not value:
raise ValueError(f"{field_name} must be a non-empty list of strings")
# Lookup table for connector validation rules
connector_rules = {
"SERPER_API": {
"required": ["SERPER_API_KEY"],
"validators": {}
},
"TAVILY_API": {
"required": ["TAVILY_API_KEY"],
"validators": {}
},
"LINKUP_API": {
"required": ["LINKUP_API_KEY"],
"validators": {}
},
"SLACK_CONNECTOR": {
"required": ["SLACK_BOT_TOKEN"],
"validators": {}
},
"SERPER_API": {"required": ["SERPER_API_KEY"], "validators": {}},
"TAVILY_API": {"required": ["TAVILY_API_KEY"], "validators": {}},
"LINKUP_API": {"required": ["LINKUP_API_KEY"], "validators": {}},
"SLACK_CONNECTOR": {"required": ["SLACK_BOT_TOKEN"], "validators": {}},
"NOTION_CONNECTOR": {
"required": ["NOTION_INTEGRATION_TOKEN"],
"validators": {}
"validators": {},
},
"GITHUB_CONNECTOR": {
"required": ["GITHUB_PAT", "repo_full_names"],
"validators": {
"repo_full_names": lambda: validate_list_field("repo_full_names", "repo_full_names")
}
},
"LINEAR_CONNECTOR": {
"required": ["LINEAR_API_KEY"],
"validators": {}
},
"DISCORD_CONNECTOR": {
"required": ["DISCORD_BOT_TOKEN"],
"validators": {}
"repo_full_names": lambda: validate_list_field(
"repo_full_names", "repo_full_names"
)
},
},
"LINEAR_CONNECTOR": {"required": ["LINEAR_API_KEY"], "validators": {}},
"DISCORD_CONNECTOR": {"required": ["DISCORD_BOT_TOKEN"], "validators": {}},
"JIRA_CONNECTOR": {
"required": ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"],
"validators": {
"JIRA_EMAIL": lambda: validate_email_field("JIRA_EMAIL", "JIRA"),
"JIRA_BASE_URL": lambda: validate_url_field("JIRA_BASE_URL", "JIRA")
}
"JIRA_BASE_URL": lambda: validate_url_field("JIRA_BASE_URL", "JIRA"),
},
},
"CONFLUENCE_CONNECTOR": {
"required": ["CONFLUENCE_BASE_URL", "CONFLUENCE_EMAIL", "CONFLUENCE_API_TOKEN"],
"required": [
"CONFLUENCE_BASE_URL",
"CONFLUENCE_EMAIL",
"CONFLUENCE_API_TOKEN",
],
"validators": {
"CONFLUENCE_EMAIL": lambda: validate_email_field("CONFLUENCE_EMAIL", "Confluence"),
"CONFLUENCE_BASE_URL": lambda: validate_url_field("CONFLUENCE_BASE_URL", "Confluence")
}
},
"CLICKUP_CONNECTOR": {
"required": ["CLICKUP_API_TOKEN"],
"validators": {}
"CONFLUENCE_EMAIL": lambda: validate_email_field(
"CONFLUENCE_EMAIL", "Confluence"
),
"CONFLUENCE_BASE_URL": lambda: validate_url_field(
"CONFLUENCE_BASE_URL", "Confluence"
),
},
},
"CLICKUP_CONNECTOR": {"required": ["CLICKUP_API_TOKEN"], "validators": {}},
# "GOOGLE_CALENDAR_CONNECTOR": {
# "required": ["token", "refresh_token", "token_uri", "client_id", "expiry", "scopes", "client_secret"],
# "validators": {},
@ -538,26 +477,23 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
# "required": ["AIRTABLE_API_KEY", "AIRTABLE_BASE_ID"],
# "validators": {}
# },
"LUMA_CONNECTOR": {
"required": ["LUMA_API_KEY"],
"validators": {}
}
"LUMA_CONNECTOR": {"required": ["LUMA_API_KEY"], "validators": {}},
}
rules = connector_rules.get(connector_type_str)
if not rules:
return config # Unknown connector type, pass through
# Validate required keys match exactly
if set(config.keys()) != set(rules["required"]):
raise ValueError(
f"For {connector_type_str} connector type, config must only contain these keys: {rules['required']}"
)
# Apply custom validators first (these check format before emptiness)
for validator_func in rules["validators"].values():
validator_func()
# Validate each field is not empty
for key in rules["required"]:
# Special handling for Google connectors that don't allow None or empty strings
@ -568,5 +504,5 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
# Standard check: field must have a truthy value
if not config.get(key):
raise ValueError(f"{key} cannot be empty")
return config

View file

@ -1,12 +1,16 @@
"use client";
import { Loader2 } from "lucide-react";
import { usePathname, useRouter } from "next/navigation";
import type React from "react";
import { useState } from "react";
import { useEffect, useState } from "react";
import { DashboardBreadcrumb } from "@/components/dashboard-breadcrumb";
import { AppSidebarProvider } from "@/components/sidebar/AppSidebarProvider";
import { ThemeTogglerComponent } from "@/components/theme/theme-toggle";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Separator } from "@/components/ui/separator";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { useLLMPreferences } from "@/hooks/use-llm-configs";
export function DashboardClientLayout({
children,
@ -19,6 +23,16 @@ export function DashboardClientLayout({
navSecondary: any[];
navMain: any[];
}) {
const router = useRouter();
const pathname = usePathname();
const searchSpaceIdNum = Number(searchSpaceId);
const { loading, error, isOnboardingComplete } = useLLMPreferences(searchSpaceIdNum);
const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false);
// Skip onboarding check if we're already on the onboarding page
const isOnboardingPage = pathname?.includes("/onboard");
const [open, setOpen] = useState<boolean>(() => {
try {
const match = document.cookie.match(/(?:^|; )sidebar_state=([^;]+)/);
@ -29,6 +43,68 @@ export function DashboardClientLayout({
return true;
});
useEffect(() => {
// Skip check if already on onboarding page
if (isOnboardingPage) {
setHasCheckedOnboarding(true);
return;
}
// Only check once after preferences have loaded
if (!loading && !hasCheckedOnboarding) {
const onboardingComplete = isOnboardingComplete();
if (!onboardingComplete) {
router.push(`/dashboard/${searchSpaceId}/onboard`);
}
setHasCheckedOnboarding(true);
}
}, [
loading,
isOnboardingComplete,
isOnboardingPage,
router,
searchSpaceId,
hasCheckedOnboarding,
]);
// Show loading screen while checking onboarding status (only on first load)
if (!hasCheckedOnboarding && loading && !isOnboardingPage) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">Loading Configuration</CardTitle>
<CardDescription>Checking your LLM preferences...</CardDescription>
</CardHeader>
<CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" />
</CardContent>
</Card>
</div>
);
}
// Show error screen if there's an error loading preferences (but not on onboarding page)
if (error && !hasCheckedOnboarding && !isOnboardingPage) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[400px] bg-background/60 backdrop-blur-sm border-destructive/20">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium text-destructive">
Configuration Error
</CardTitle>
<CardDescription>Failed to load your LLM configuration</CardDescription>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">{error}</p>
</CardContent>
</Card>
</div>
);
}
return (
<SidebarProvider open={open} onOpenChange={setOpen}>
{/* Use AppSidebarProvider which fetches user, search space, and recent chats */}

View file

@ -33,6 +33,12 @@ export default function DashboardLayout({
icon: "SquareTerminal",
items: [],
},
{
title: "Manage LLMs",
url: `/dashboard/${search_space_id}/settings`,
icon: "Settings2",
items: [],
},
{
title: "Documents",

View file

@ -2,7 +2,7 @@
import { ArrowLeft, ArrowRight, Bot, CheckCircle, Sparkles } from "lucide-react";
import { AnimatePresence, motion } from "motion/react";
import { useRouter } from "next/navigation";
import { useParams, useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { Logo } from "@/components/Logo";
import { AddProviderStep } from "@/components/onboard/add-provider-step";
@ -17,13 +17,16 @@ const TOTAL_STEPS = 3;
const OnboardPage = () => {
const router = useRouter();
const { llmConfigs, loading: configsLoading, refreshConfigs } = useLLMConfigs();
const params = useParams();
const searchSpaceId = Number(params.search_space_id);
const { llmConfigs, loading: configsLoading, refreshConfigs } = useLLMConfigs(searchSpaceId);
const {
preferences,
loading: preferencesLoading,
isOnboardingComplete,
refreshPreferences,
} = useLLMPreferences();
} = useLLMPreferences(searchSpaceId);
const [currentStep, setCurrentStep] = useState(1);
const [hasUserProgressed, setHasUserProgressed] = useState(false);
@ -44,11 +47,23 @@ const OnboardPage = () => {
}, [currentStep]);
// Redirect to dashboard if onboarding is already complete and user hasn't progressed (fresh page load)
// But only check once to avoid redirect loops
useEffect(() => {
if (!preferencesLoading && isOnboardingComplete() && !hasUserProgressed) {
router.push("/dashboard");
if (!preferencesLoading && !configsLoading && isOnboardingComplete() && !hasUserProgressed) {
// Small delay to ensure the check is stable
const timer = setTimeout(() => {
router.push(`/dashboard/${searchSpaceId}`);
}, 100);
return () => clearTimeout(timer);
}
}, [preferencesLoading, isOnboardingComplete, hasUserProgressed, router]);
}, [
preferencesLoading,
configsLoading,
isOnboardingComplete,
hasUserProgressed,
router,
searchSpaceId,
]);
const progress = (currentStep / TOTAL_STEPS) * 100;
@ -80,7 +95,7 @@ const OnboardPage = () => {
};
const handleComplete = () => {
router.push("/dashboard");
router.push(`/dashboard/${searchSpaceId}/documents`);
};
if (configsLoading || preferencesLoading) {
@ -184,12 +199,18 @@ const OnboardPage = () => {
>
{currentStep === 1 && (
<AddProviderStep
searchSpaceId={searchSpaceId}
onConfigCreated={refreshConfigs}
onConfigDeleted={refreshConfigs}
/>
)}
{currentStep === 2 && <AssignRolesStep onPreferencesUpdated={refreshPreferences} />}
{currentStep === 3 && <CompletionStep />}
{currentStep === 2 && (
<AssignRolesStep
searchSpaceId={searchSpaceId}
onPreferencesUpdated={refreshPreferences}
/>
)}
{currentStep === 3 && <CompletionStep searchSpaceId={searchSpaceId} />}
</motion.div>
</AnimatePresence>
</CardContent>

View file

@ -1,14 +1,16 @@
"use client";
import { ArrowLeft, Bot, Brain, Settings } from "lucide-react"; // Import ArrowLeft icon
import { useRouter } from "next/navigation"; // Add this import
import { ArrowLeft, Bot, Brain, Settings } from "lucide-react";
import { useParams, useRouter } from "next/navigation";
import { LLMRoleManager } from "@/components/settings/llm-role-manager";
import { ModelConfigManager } from "@/components/settings/model-config-manager";
import { Separator } from "@/components/ui/separator";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
export default function SettingsPage() {
const router = useRouter(); // Initialize router
const router = useRouter();
const params = useParams();
const searchSpaceId = Number(params.search_space_id);
return (
<div className="min-h-screen bg-background">
@ -19,7 +21,7 @@ export default function SettingsPage() {
<div className="flex items-center space-x-4">
{/* Back Button */}
<button
onClick={() => router.push("/dashboard")}
onClick={() => router.push(`/dashboard/${searchSpaceId}`)}
className="flex items-center justify-center h-10 w-10 rounded-lg bg-primary/10 hover:bg-primary/20 transition-colors"
aria-label="Back to Dashboard"
type="button"
@ -32,7 +34,7 @@ export default function SettingsPage() {
<div className="space-y-1">
<h1 className="text-3xl font-bold tracking-tight">Settings</h1>
<p className="text-lg text-muted-foreground">
Manage your LLM configurations and role assignments.
Manage your LLM configurations and role assignments for this search space.
</p>
</div>
</div>
@ -57,11 +59,11 @@ export default function SettingsPage() {
</div>
<TabsContent value="models" className="space-y-6">
<ModelConfigManager />
<ModelConfigManager searchSpaceId={searchSpaceId} />
</TabsContent>
<TabsContent value="roles" className="space-y-6">
<LLMRoleManager />
<LLMRoleManager searchSpaceId={searchSpaceId} />
</TabsContent>
</Tabs>
</div>

View file

@ -4,7 +4,6 @@ import { Loader2 } from "lucide-react";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { useLLMPreferences } from "@/hooks/use-llm-configs";
interface DashboardLayoutProps {
children: React.ReactNode;
@ -12,7 +11,6 @@ interface DashboardLayoutProps {
export default function DashboardLayout({ children }: DashboardLayoutProps) {
const router = useRouter();
const { loading, error, isOnboardingComplete } = useLLMPreferences();
const [isCheckingAuth, setIsCheckingAuth] = useState(true);
useEffect(() => {
@ -25,23 +23,14 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
setIsCheckingAuth(false);
}, [router]);
useEffect(() => {
// Wait for preferences to load, then check if onboarding is complete
if (!loading && !error && !isCheckingAuth) {
if (!isOnboardingComplete()) {
router.push("/onboard");
}
}
}, [loading, error, isCheckingAuth, isOnboardingComplete, router]);
// Show loading screen while checking authentication or loading preferences
if (isCheckingAuth || loading) {
// Show loading screen while checking authentication
if (isCheckingAuth) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">Loading Dashboard</CardTitle>
<CardDescription>Checking your configuration...</CardDescription>
<CardDescription>Checking authentication...</CardDescription>
</CardHeader>
<CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" />
@ -51,42 +40,5 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
);
}
// Show error screen if there's an error loading preferences
if (error) {
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[400px] bg-background/60 backdrop-blur-sm border-destructive/20">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium text-destructive">
Configuration Error
</CardTitle>
<CardDescription>Failed to load your LLM configuration</CardDescription>
</CardHeader>
<CardContent>
<p className="text-sm text-muted-foreground">{error}</p>
</CardContent>
</Card>
</div>
);
}
// Only render children if onboarding is complete
if (isOnboardingComplete()) {
return <>{children}</>;
}
// This should not be reached due to redirect, but just in case
return (
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
<CardHeader className="pb-2">
<CardTitle className="text-xl font-medium">Redirecting...</CardTitle>
<CardDescription>Taking you to complete your setup</CardDescription>
</CardHeader>
<CardContent className="flex justify-center py-6">
<Loader2 className="h-12 w-12 text-primary animate-spin" />
</CardContent>
</Card>
</div>
);
return <>{children}</>;
}

View file

@ -66,10 +66,6 @@ export function UserDropdown({
</DropdownMenuItem>
</DropdownMenuGroup>
<DropdownMenuSeparator />
<DropdownMenuItem onClick={() => router.push(`/settings`)}>
<Settings className="mr-2 h-4 w-4" />
Settings
</DropdownMenuItem>
<DropdownMenuItem onClick={handleLogout}>
<LogOut className="mr-2 h-4 w-4" />
Log out

View file

@ -332,8 +332,11 @@ const ResearchModeSelector = React.memo(
ResearchModeSelector.displayName = "ResearchModeSelector";
const LLMSelector = React.memo(() => {
const { llmConfigs, loading: llmLoading, error } = useLLMConfigs();
const { preferences, updatePreferences, loading: preferencesLoading } = useLLMPreferences();
const { search_space_id } = useParams();
const searchSpaceId = Number(search_space_id);
const { llmConfigs, loading: llmLoading, error } = useLLMConfigs(searchSpaceId);
const { preferences, updatePreferences, loading: preferencesLoading } = useLLMPreferences(searchSpaceId);
const isLoading = llmLoading || preferencesLoading;

View file

@ -23,12 +23,17 @@ import { type CreateLLMConfig, useLLMConfigs } from "@/hooks/use-llm-configs";
import InferenceParamsEditor from "../inference-params-editor";
interface AddProviderStepProps {
searchSpaceId: number;
onConfigCreated?: () => void;
onConfigDeleted?: () => void;
}
export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProviderStepProps) {
const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs();
export function AddProviderStep({
searchSpaceId,
onConfigCreated,
onConfigDeleted,
}: AddProviderStepProps) {
const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs(searchSpaceId);
const [isAddingNew, setIsAddingNew] = useState(false);
const [formData, setFormData] = useState<CreateLLMConfig>({
name: "",
@ -38,6 +43,7 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
api_key: "",
api_base: "",
litellm_params: {},
search_space_id: searchSpaceId,
});
const [isSubmitting, setIsSubmitting] = useState(false);
@ -65,6 +71,7 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
api_key: "",
api_base: "",
litellm_params: {},
search_space_id: searchSpaceId,
});
setIsAddingNew(false);
// Notify parent component that a config was created
@ -253,7 +260,6 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
/>
</div>
<div className="flex gap-2 pt-4">
<Button type="submit" disabled={isSubmitting}>
{isSubmitting ? "Adding..." : "Add Provider"}

View file

@ -41,12 +41,13 @@ const ROLE_DESCRIPTIONS = {
};
interface AssignRolesStepProps {
searchSpaceId: number;
onPreferencesUpdated?: () => Promise<void>;
}
export function AssignRolesStep({ onPreferencesUpdated }: AssignRolesStepProps) {
const { llmConfigs } = useLLMConfigs();
const { preferences, updatePreferences } = useLLMPreferences();
export function AssignRolesStep({ searchSpaceId, onPreferencesUpdated }: AssignRolesStepProps) {
const { llmConfigs } = useLLMConfigs(searchSpaceId);
const { preferences, updatePreferences } = useLLMPreferences(searchSpaceId);
const [assignments, setAssignments] = useState({
long_context_llm_id: preferences.long_context_llm_id || "",

View file

@ -12,9 +12,13 @@ const ROLE_ICONS = {
strategic: Bot,
};
export function CompletionStep() {
const { llmConfigs } = useLLMConfigs();
const { preferences } = useLLMPreferences();
interface CompletionStepProps {
searchSpaceId: number;
}
export function CompletionStep({ searchSpaceId }: CompletionStepProps) {
const { llmConfigs } = useLLMConfigs(searchSpaceId);
const { preferences } = useLLMPreferences(searchSpaceId);
const assignedConfigs = {
long_context: llmConfigs.find((c) => c.id === preferences.long_context_llm_id),

View file

@ -56,20 +56,24 @@ const ROLE_DESCRIPTIONS = {
},
};
export function LLMRoleManager() {
interface LLMRoleManagerProps {
searchSpaceId: number;
}
export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
const {
llmConfigs,
loading: configsLoading,
error: configsError,
refreshConfigs,
} = useLLMConfigs();
} = useLLMConfigs(searchSpaceId);
const {
preferences,
loading: preferencesLoading,
error: preferencesError,
updatePreferences,
refreshPreferences,
} = useLLMPreferences();
} = useLLMPreferences(searchSpaceId);
const [assignments, setAssignments] = useState({
long_context_llm_id: preferences.long_context_llm_id || "",

View file

@ -41,7 +41,11 @@ import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers";
import { type CreateLLMConfig, type LLMConfig, useLLMConfigs } from "@/hooks/use-llm-configs";
import InferenceParamsEditor from "../inference-params-editor";
export function ModelConfigManager() {
interface ModelConfigManagerProps {
searchSpaceId: number;
}
export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
const {
llmConfigs,
loading,
@ -50,7 +54,7 @@ export function ModelConfigManager() {
updateLLMConfig,
deleteLLMConfig,
refreshConfigs,
} = useLLMConfigs();
} = useLLMConfigs(searchSpaceId);
const [isAddingNew, setIsAddingNew] = useState(false);
const [editingConfig, setEditingConfig] = useState<LLMConfig | null>(null);
const [showApiKey, setShowApiKey] = useState<Record<number, boolean>>({});
@ -62,6 +66,7 @@ export function ModelConfigManager() {
api_key: "",
api_base: "",
litellm_params: {},
search_space_id: searchSpaceId,
});
const [isSubmitting, setIsSubmitting] = useState(false);
@ -76,9 +81,10 @@ export function ModelConfigManager() {
api_key: editingConfig.api_key,
api_base: editingConfig.api_base || "",
litellm_params: editingConfig.litellm_params || {},
search_space_id: searchSpaceId,
});
}
}, [editingConfig]);
}, [editingConfig, searchSpaceId]);
const handleInputChange = (field: keyof CreateLLMConfig, value: string) => {
setFormData((prev) => ({ ...prev, [field]: value }));
@ -113,6 +119,7 @@ export function ModelConfigManager() {
api_key: "",
api_base: "",
litellm_params: {},
search_space_id: searchSpaceId,
});
setIsAddingNew(false);
setEditingConfig(null);
@ -426,6 +433,7 @@ export function ModelConfigManager() {
api_key: "",
api_base: "",
litellm_params: {},
search_space_id: searchSpaceId,
});
}
}}
@ -462,18 +470,12 @@ export function ModelConfigManager() {
value={formData.provider}
onValueChange={(value) => handleInputChange("provider", value)}
>
<SelectTrigger className="h-auto min-h-[2.5rem] py-2">
<SelectTrigger>
<SelectValue placeholder="Select a provider">
{formData.provider && (
<div className="flex items-center space-x-2 py-1">
<div className="font-medium">
{LLM_PROVIDERS.find((p) => p.value === formData.provider)?.label}
</div>
<div className="text-xs text-muted-foreground"></div>
<div className="text-xs text-muted-foreground">
{LLM_PROVIDERS.find((p) => p.value === formData.provider)?.description}
</div>
</div>
<span className="font-medium">
{LLM_PROVIDERS.find((p) => p.value === formData.provider)?.label}
</span>
)}
</SelectValue>
</SelectTrigger>
@ -549,7 +551,7 @@ export function ModelConfigManager() {
<InferenceParamsEditor
params={formData.litellm_params || {}}
setParams={(newParams) =>
setFormData((prev) => ({ ...prev, litellm_params: newParams }))
setFormData((prev) => ({ ...prev, litellm_params: newParams }))
}
/>
</div>
@ -578,6 +580,7 @@ export function ModelConfigManager() {
api_key: "",
api_base: "",
litellm_params: {},
search_space_id: searchSpaceId,
});
}}
disabled={isSubmitting}

View file

@ -12,7 +12,7 @@ export interface LLMConfig {
api_base?: string;
litellm_params?: Record<string, any>;
created_at: string;
user_id: string;
search_space_id: number;
}
export interface LLMPreferences {
@ -32,6 +32,7 @@ export interface CreateLLMConfig {
api_key: string;
api_base?: string;
litellm_params?: Record<string, any>;
search_space_id: number;
}
export interface UpdateLLMConfig {
@ -44,16 +45,21 @@ export interface UpdateLLMConfig {
litellm_params?: Record<string, any>;
}
export function useLLMConfigs() {
export function useLLMConfigs(searchSpaceId: number | null) {
const [llmConfigs, setLlmConfigs] = useState<LLMConfig[]>([]);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const fetchLLMConfigs = async () => {
if (!searchSpaceId) {
setLoading(false);
return;
}
try {
setLoading(true);
const response = await fetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`,
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/?search_space_id=${searchSpaceId}`,
{
headers: {
Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`,
@ -79,7 +85,7 @@ export function useLLMConfigs() {
useEffect(() => {
fetchLLMConfigs();
}, []);
}, [searchSpaceId]);
const createLLMConfig = async (config: CreateLLMConfig): Promise<LLMConfig | null> => {
try {
@ -181,16 +187,21 @@ export function useLLMConfigs() {
};
}
export function useLLMPreferences() {
export function useLLMPreferences(searchSpaceId: number | null) {
const [preferences, setPreferences] = useState<LLMPreferences>({});
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const fetchPreferences = async () => {
if (!searchSpaceId) {
setLoading(false);
return;
}
try {
setLoading(true);
const response = await fetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`,
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/llm-preferences`,
{
headers: {
Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`,
@ -216,12 +227,17 @@ export function useLLMPreferences() {
useEffect(() => {
fetchPreferences();
}, []);
}, [searchSpaceId]);
const updatePreferences = async (newPreferences: Partial<LLMPreferences>): Promise<boolean> => {
if (!searchSpaceId) {
toast.error("Search space ID is required");
return false;
}
try {
const response = await fetch(
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`,
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/llm-preferences`,
{
method: "PUT",
headers: {