mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: moved LLMConfigs from User to SearchSpaces
- RBAC soon?? - Updated various services and routes to handle search space-specific LLM preferences. - Modified frontend components to pass search space ID for LLM configuration management. - Removed onboarding page and settings page as part of the refactor.
This commit is contained in:
parent
a1b1db3895
commit
633ea3ac0f
44 changed files with 1075 additions and 518 deletions
|
|
@ -0,0 +1,352 @@
|
|||
"""Migrate LLM configs to search spaces and add user preferences
|
||||
|
||||
Revision ID: 25
|
||||
Revises: 24
|
||||
Create Date: 2025-01-10 14:00:00.000000
|
||||
|
||||
Changes:
|
||||
1. Migrate llm_configs from user association to search_space association
|
||||
2. Create user_search_space_preferences table for per-user LLM preferences
|
||||
3. Migrate existing user LLM preferences to user_search_space_preferences
|
||||
4. Remove LLM preference columns from user table
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "25"
|
||||
down_revision: str | None = "24"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Upgrade schema to support collaborative search spaces with per-user preferences.
|
||||
|
||||
Migration steps:
|
||||
1. Add search_space_id to llm_configs
|
||||
2. Migrate existing llm_configs to first search space of their user
|
||||
3. Replace user_id with search_space_id in llm_configs
|
||||
4. Create user_search_space_preferences table
|
||||
5. Migrate user LLM preferences to user_search_space_preferences
|
||||
6. Remove LLM preference columns from user table
|
||||
"""
|
||||
|
||||
from sqlalchemy import inspect
|
||||
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
|
||||
# Get existing columns
|
||||
llm_config_columns = [col["name"] for col in inspector.get_columns("llm_configs")]
|
||||
user_columns = [col["name"] for col in inspector.get_columns("user")]
|
||||
|
||||
# ===== STEP 1: Add search_space_id to llm_configs =====
|
||||
if "search_space_id" not in llm_config_columns:
|
||||
op.add_column(
|
||||
"llm_configs",
|
||||
sa.Column("search_space_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
# ===== STEP 2: Populate search_space_id with user's first search space =====
|
||||
# This ensures existing LLM configs are assigned to a valid search space
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE llm_configs lc
|
||||
SET search_space_id = (
|
||||
SELECT id
|
||||
FROM searchspaces ss
|
||||
WHERE ss.user_id = lc.user_id
|
||||
ORDER BY ss.created_at ASC
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE search_space_id IS NULL AND user_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# ===== STEP 3: Make search_space_id NOT NULL and add FK constraint =====
|
||||
op.alter_column(
|
||||
"llm_configs",
|
||||
"search_space_id",
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Add foreign key constraint
|
||||
foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("llm_configs")]
|
||||
if "fk_llm_configs_search_space_id" not in foreign_keys:
|
||||
op.create_foreign_key(
|
||||
"fk_llm_configs_search_space_id",
|
||||
"llm_configs",
|
||||
"searchspaces",
|
||||
["search_space_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Drop old user_id foreign key if it exists
|
||||
if "fk_llm_configs_user_id_user" in foreign_keys:
|
||||
op.drop_constraint(
|
||||
"fk_llm_configs_user_id_user",
|
||||
"llm_configs",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Remove user_id column
|
||||
if "user_id" in llm_config_columns:
|
||||
op.drop_column("llm_configs", "user_id")
|
||||
|
||||
# ===== STEP 4: Create user_search_space_preferences table =====
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = 'user_search_space_preferences'
|
||||
) THEN
|
||||
CREATE TABLE user_search_space_preferences (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
user_id UUID NOT NULL REFERENCES "user"(id) ON DELETE CASCADE,
|
||||
search_space_id INTEGER NOT NULL REFERENCES searchspaces(id) ON DELETE CASCADE,
|
||||
long_context_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
|
||||
fast_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
|
||||
strategic_llm_id INTEGER REFERENCES llm_configs(id) ON DELETE SET NULL,
|
||||
CONSTRAINT uq_user_searchspace UNIQUE (user_id, search_space_id)
|
||||
);
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_indexes
|
||||
WHERE tablename = 'user_search_space_preferences'
|
||||
AND indexname = 'ix_user_search_space_preferences_id'
|
||||
) THEN
|
||||
CREATE INDEX ix_user_search_space_preferences_id
|
||||
ON user_search_space_preferences(id);
|
||||
END IF;
|
||||
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM pg_indexes
|
||||
WHERE tablename = 'user_search_space_preferences'
|
||||
AND indexname = 'ix_user_search_space_preferences_created_at'
|
||||
) THEN
|
||||
CREATE INDEX ix_user_search_space_preferences_created_at
|
||||
ON user_search_space_preferences(created_at);
|
||||
END IF;
|
||||
END$$;
|
||||
"""
|
||||
)
|
||||
|
||||
# ===== STEP 5: Migrate user LLM preferences to user_search_space_preferences =====
|
||||
# For each user, create preferences for each of their search spaces
|
||||
if all(
|
||||
col in user_columns
|
||||
for col in ["long_context_llm_id", "fast_llm_id", "strategic_llm_id"]
|
||||
):
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO user_search_space_preferences
|
||||
(user_id, search_space_id, long_context_llm_id, fast_llm_id, strategic_llm_id, created_at)
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
ss.id as search_space_id,
|
||||
u.long_context_llm_id,
|
||||
u.fast_llm_id,
|
||||
u.strategic_llm_id,
|
||||
NOW() as created_at
|
||||
FROM "user" u
|
||||
CROSS JOIN searchspaces ss
|
||||
WHERE ss.user_id = u.id
|
||||
ON CONFLICT (user_id, search_space_id) DO NOTHING
|
||||
"""
|
||||
)
|
||||
|
||||
# ===== STEP 6: Remove LLM preference columns from user table =====
|
||||
# Get fresh list of foreign keys after previous operations
|
||||
user_foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("user")]
|
||||
|
||||
# Drop foreign key constraints if they exist
|
||||
if "fk_user_long_context_llm_id_llm_configs" in user_foreign_keys:
|
||||
op.drop_constraint(
|
||||
"fk_user_long_context_llm_id_llm_configs",
|
||||
"user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
if "fk_user_fast_llm_id_llm_configs" in user_foreign_keys:
|
||||
op.drop_constraint(
|
||||
"fk_user_fast_llm_id_llm_configs",
|
||||
"user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
if "fk_user_strategic_llm_id_llm_configs" in user_foreign_keys:
|
||||
op.drop_constraint(
|
||||
"fk_user_strategic_llm_id_llm_configs",
|
||||
"user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Drop columns from user table
|
||||
if "long_context_llm_id" in user_columns:
|
||||
op.drop_column("user", "long_context_llm_id")
|
||||
|
||||
if "fast_llm_id" in user_columns:
|
||||
op.drop_column("user", "fast_llm_id")
|
||||
|
||||
if "strategic_llm_id" in user_columns:
|
||||
op.drop_column("user", "strategic_llm_id")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
Downgrade schema back to user-owned LLM configs.
|
||||
|
||||
WARNING: This downgrade will result in data loss:
|
||||
- LLM configs will be moved back to user ownership (first occurrence kept)
|
||||
- Per-search-space user preferences will be consolidated to user level
|
||||
- Additional LLM configs in search spaces beyond the first will be deleted
|
||||
"""
|
||||
|
||||
from sqlalchemy import inspect
|
||||
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
|
||||
# Get existing columns and constraints
|
||||
llm_config_columns = [col["name"] for col in inspector.get_columns("llm_configs")]
|
||||
user_columns = [col["name"] for col in inspector.get_columns("user")]
|
||||
|
||||
# ===== STEP 1: Add LLM preference columns back to user table =====
|
||||
if "long_context_llm_id" not in user_columns:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("long_context_llm_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
if "fast_llm_id" not in user_columns:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("fast_llm_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
if "strategic_llm_id" not in user_columns:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("strategic_llm_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
# ===== STEP 2: Migrate preferences back to user table =====
|
||||
# Take the first preference for each user
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user" u
|
||||
SET
|
||||
long_context_llm_id = ussp.long_context_llm_id,
|
||||
fast_llm_id = ussp.fast_llm_id,
|
||||
strategic_llm_id = ussp.strategic_llm_id
|
||||
FROM (
|
||||
SELECT DISTINCT ON (user_id)
|
||||
user_id,
|
||||
long_context_llm_id,
|
||||
fast_llm_id,
|
||||
strategic_llm_id
|
||||
FROM user_search_space_preferences
|
||||
ORDER BY user_id, created_at ASC
|
||||
) ussp
|
||||
WHERE u.id = ussp.user_id
|
||||
"""
|
||||
)
|
||||
|
||||
# ===== STEP 3: Add foreign key constraints back to user table =====
|
||||
op.create_foreign_key(
|
||||
"fk_user_long_context_llm_id_llm_configs",
|
||||
"user",
|
||||
"llm_configs",
|
||||
["long_context_llm_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"fk_user_fast_llm_id_llm_configs",
|
||||
"user",
|
||||
"llm_configs",
|
||||
["fast_llm_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"fk_user_strategic_llm_id_llm_configs",
|
||||
"user",
|
||||
"llm_configs",
|
||||
["strategic_llm_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
# ===== STEP 4: Drop user_search_space_preferences table =====
|
||||
op.execute("DROP TABLE IF EXISTS user_search_space_preferences CASCADE")
|
||||
|
||||
# ===== STEP 5: Add user_id back to llm_configs =====
|
||||
if "user_id" not in llm_config_columns:
|
||||
op.add_column(
|
||||
"llm_configs",
|
||||
sa.Column("user_id", postgresql.UUID(), nullable=True),
|
||||
)
|
||||
|
||||
# Populate user_id from search_space
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE llm_configs lc
|
||||
SET user_id = ss.user_id
|
||||
FROM searchspaces ss
|
||||
WHERE lc.search_space_id = ss.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Make user_id NOT NULL
|
||||
op.alter_column(
|
||||
"llm_configs",
|
||||
"user_id",
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Add foreign key constraint for user_id
|
||||
op.create_foreign_key(
|
||||
"fk_llm_configs_user_id_user",
|
||||
"llm_configs",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# ===== STEP 6: Remove search_space_id from llm_configs =====
|
||||
# Drop foreign key constraint
|
||||
foreign_keys = [fk["name"] for fk in inspector.get_foreign_keys("llm_configs")]
|
||||
if "fk_llm_configs_search_space_id" in foreign_keys:
|
||||
op.drop_constraint(
|
||||
"fk_llm_configs_search_space_id",
|
||||
"llm_configs",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Drop search_space_id column
|
||||
if "search_space_id" in llm_config_columns:
|
||||
op.drop_column("llm_configs", "search_space_id")
|
||||
|
|
@ -17,6 +17,7 @@ class Configuration:
|
|||
# and when you invoke the graph
|
||||
podcast_title: str
|
||||
user_id: str
|
||||
search_space_id: int
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
|
|
|
|||
|
|
@ -28,11 +28,12 @@ async def create_podcast_transcript(
|
|||
# Get configuration from runnable config
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
user_id = configuration.user_id
|
||||
search_space_id = configuration.search_space_id
|
||||
|
||||
# Get user's long context LLM
|
||||
llm = await get_user_long_context_llm(state.db_session, user_id)
|
||||
llm = await get_user_long_context_llm(state.db_session, user_id, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No long context LLM configured for user {user_id}"
|
||||
error_message = f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -577,6 +577,7 @@ async def write_answer_outline(
|
|||
user_query = configuration.user_query
|
||||
num_sections = configuration.num_sections
|
||||
user_id = configuration.user_id
|
||||
search_space_id = configuration.search_space_id
|
||||
|
||||
writer(
|
||||
{
|
||||
|
|
@ -587,9 +588,9 @@ async def write_answer_outline(
|
|||
)
|
||||
|
||||
# Get user's strategic LLM
|
||||
llm = await get_user_strategic_llm(state.db_session, user_id)
|
||||
llm = await get_user_strategic_llm(state.db_session, user_id, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No strategic LLM configured for user {user_id}"
|
||||
error_message = f"No strategic LLM configured for user {user_id} in search space {search_space_id}"
|
||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
|
@ -1854,6 +1855,7 @@ async def reformulate_user_query(
|
|||
user_query=user_query,
|
||||
session=state.db_session,
|
||||
user_id=configuration.user_id,
|
||||
search_space_id=configuration.search_space_id,
|
||||
chat_history_str=chat_history_str,
|
||||
)
|
||||
|
||||
|
|
@ -2093,6 +2095,7 @@ async def generate_further_questions(
|
|||
configuration = Configuration.from_runnable_config(config)
|
||||
chat_history = state.chat_history
|
||||
user_id = configuration.user_id
|
||||
search_space_id = configuration.search_space_id
|
||||
streaming_service = state.streaming_service
|
||||
|
||||
# Get reranked documents from the state (will be populated by sub-agents)
|
||||
|
|
@ -2107,9 +2110,9 @@ async def generate_further_questions(
|
|||
)
|
||||
|
||||
# Get user's fast LLM
|
||||
llm = await get_user_fast_llm(state.db_session, user_id)
|
||||
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No fast LLM configured for user {user_id}"
|
||||
error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
|
||||
print(error_message)
|
||||
writer({"yield_value": streaming_service.format_error(error_message)})
|
||||
|
||||
|
|
|
|||
|
|
@ -101,11 +101,12 @@ async def answer_question(state: State, config: RunnableConfig) -> dict[str, Any
|
|||
documents = state.reranked_documents
|
||||
user_query = configuration.user_query
|
||||
user_id = configuration.user_id
|
||||
search_space_id = configuration.search_space_id
|
||||
|
||||
# Get user's fast LLM
|
||||
llm = await get_user_fast_llm(state.db_session, user_id)
|
||||
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No fast LLM configured for user {user_id}"
|
||||
error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -107,11 +107,12 @@ async def write_sub_section(state: State, config: RunnableConfig) -> dict[str, A
|
|||
configuration = Configuration.from_runnable_config(config)
|
||||
documents = state.reranked_documents
|
||||
user_id = configuration.user_id
|
||||
search_space_id = configuration.search_space_id
|
||||
|
||||
# Get user's fast LLM
|
||||
llm = await get_user_fast_llm(state.db_session, user_id)
|
||||
llm = await get_user_fast_llm(state.db_session, user_id, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No fast LLM configured for user {user_id}"
|
||||
error_message = f"No fast LLM configured for user {user_id} in search space {search_space_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
|
|
|
|||
|
|
@ -240,6 +240,17 @@ class SearchSpace(BaseModel, TimestampMixin):
|
|||
order_by="SearchSourceConnector.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
llm_configs = relationship(
|
||||
"LLMConfig",
|
||||
back_populates="search_space",
|
||||
order_by="LLMConfig.id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
user_preferences = relationship(
|
||||
"UserSearchSpacePreference",
|
||||
back_populates="search_space",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class SearchSourceConnector(BaseModel, TimestampMixin):
|
||||
|
|
@ -288,10 +299,54 @@ class LLMConfig(BaseModel, TimestampMixin):
|
|||
# For any other parameters that litellm supports
|
||||
litellm_params = Column(JSON, nullable=True, default={})
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="llm_configs")
|
||||
|
||||
|
||||
class UserSearchSpacePreference(BaseModel, TimestampMixin):
|
||||
__tablename__ = "user_search_space_preferences"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id",
|
||||
"search_space_id",
|
||||
name="uq_user_searchspace",
|
||||
),
|
||||
)
|
||||
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
user = relationship("User", back_populates="llm_configs", foreign_keys=[user_id])
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# User-specific LLM preferences for this search space
|
||||
long_context_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
fast_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
strategic_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
# Future RBAC fields can be added here
|
||||
# role = Column(String(50), nullable=True) # e.g., 'owner', 'editor', 'viewer'
|
||||
# permissions = Column(JSON, nullable=True)
|
||||
|
||||
user = relationship("User", back_populates="search_space_preferences")
|
||||
search_space = relationship("SearchSpace", back_populates="user_preferences")
|
||||
|
||||
long_context_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||||
)
|
||||
fast_llm = relationship("LLMConfig", foreign_keys=[fast_llm_id], post_update=True)
|
||||
strategic_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||||
)
|
||||
|
||||
|
||||
class Log(BaseModel, TimestampMixin):
|
||||
|
|
@ -321,64 +376,22 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
"OAuthAccount", lazy="joined"
|
||||
)
|
||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||
llm_configs = relationship(
|
||||
"LLMConfig",
|
||||
search_space_preferences = relationship(
|
||||
"UserSearchSpacePreference",
|
||||
back_populates="user",
|
||||
foreign_keys="LLMConfig.user_id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
long_context_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
fast_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
strategic_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
long_context_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||||
)
|
||||
fast_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
|
||||
)
|
||||
strategic_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
search_spaces = relationship("SearchSpace", back_populates="user")
|
||||
llm_configs = relationship(
|
||||
"LLMConfig",
|
||||
search_space_preferences = relationship(
|
||||
"UserSearchSpacePreference",
|
||||
back_populates="user",
|
||||
foreign_keys="LLMConfig.user_id",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
long_context_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
fast_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
strategic_llm_id = Column(
|
||||
Integer, ForeignKey("llm_configs.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
long_context_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[long_context_llm_id], post_update=True
|
||||
)
|
||||
fast_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[fast_llm_id], post_update=True
|
||||
)
|
||||
strategic_llm = relationship(
|
||||
"LLMConfig", foreign_keys=[strategic_llm_id], post_update=True
|
||||
)
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
|
|
|||
|
|
@ -17,19 +17,17 @@ from app.tasks.stream_connector_search_results import stream_connector_search_re
|
|||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from app.utils.validators import (
|
||||
validate_search_space_id,
|
||||
validate_document_ids,
|
||||
validate_connectors,
|
||||
validate_document_ids,
|
||||
validate_messages,
|
||||
validate_research_mode,
|
||||
validate_search_mode,
|
||||
validate_messages,
|
||||
validate_search_space_id,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def handle_chat_data(
|
||||
request: AISDKChatRequest,
|
||||
|
|
@ -38,20 +36,22 @@ async def handle_chat_data(
|
|||
):
|
||||
# Validate and sanitize all input data
|
||||
messages = validate_messages(request.messages)
|
||||
|
||||
|
||||
if messages[-1]["role"] != "user":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Last message must be a user message"
|
||||
)
|
||||
|
||||
user_query = messages[-1]["content"]
|
||||
|
||||
|
||||
# Extract and validate data from request
|
||||
request_data = request.data or {}
|
||||
search_space_id = validate_search_space_id(request_data.get("search_space_id"))
|
||||
research_mode = validate_research_mode(request_data.get("research_mode"))
|
||||
selected_connectors = validate_connectors(request_data.get("selected_connectors"))
|
||||
document_ids_to_add_in_context = validate_document_ids(request_data.get("document_ids_to_add_in_context"))
|
||||
document_ids_to_add_in_context = validate_document_ids(
|
||||
request_data.get("document_ids_to_add_in_context")
|
||||
)
|
||||
search_mode_str = validate_search_mode(request_data.get("search_mode"))
|
||||
|
||||
# Check if the search space belongs to the current user
|
||||
|
|
@ -132,21 +132,16 @@ async def read_chats(
|
|||
# Validate pagination parameters
|
||||
if skip < 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="skip must be a non-negative integer"
|
||||
status_code=400, detail="skip must be a non-negative integer"
|
||||
)
|
||||
|
||||
|
||||
if limit <= 0 or limit > 1000: # Reasonable upper limit
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="limit must be between 1 and 1000"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="limit must be between 1 and 1000")
|
||||
|
||||
# Validate search_space_id if provided
|
||||
if search_space_id is not None and search_space_id <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be a positive integer"
|
||||
status_code=400, detail="search_space_id must be a positive integer"
|
||||
)
|
||||
try:
|
||||
# Select specific fields excluding messages
|
||||
|
|
|
|||
|
|
@ -2,15 +2,72 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import LLMConfig, User, get_async_session
|
||||
from app.db import (
|
||||
LLMConfig,
|
||||
SearchSpace,
|
||||
User,
|
||||
UserSearchSpacePreference,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import LLMConfigCreate, LLMConfigRead, LLMConfigUpdate
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Helper function to check search space access
|
||||
async def check_search_space_access(
|
||||
session: AsyncSession, search_space_id: int, user: User
|
||||
) -> SearchSpace:
|
||||
"""Verify that the user has access to the search space"""
|
||||
result = await session.execute(
|
||||
select(SearchSpace).filter(
|
||||
SearchSpace.id == search_space_id, SearchSpace.user_id == user.id
|
||||
)
|
||||
)
|
||||
search_space = result.scalars().first()
|
||||
if not search_space:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Search space not found or you don't have permission to access it",
|
||||
)
|
||||
return search_space
|
||||
|
||||
|
||||
# Helper function to get or create user search space preference
|
||||
async def get_or_create_user_preference(
|
||||
session: AsyncSession, user_id, search_space_id: int
|
||||
) -> UserSearchSpacePreference:
|
||||
"""Get or create user preference for a search space"""
|
||||
result = await session.execute(
|
||||
select(UserSearchSpacePreference)
|
||||
.filter(
|
||||
UserSearchSpacePreference.user_id == user_id,
|
||||
UserSearchSpacePreference.search_space_id == search_space_id,
|
||||
)
|
||||
.options(
|
||||
selectinload(UserSearchSpacePreference.long_context_llm),
|
||||
selectinload(UserSearchSpacePreference.fast_llm),
|
||||
selectinload(UserSearchSpacePreference.strategic_llm),
|
||||
)
|
||||
)
|
||||
preference = result.scalars().first()
|
||||
|
||||
if not preference:
|
||||
# Create new preference entry
|
||||
preference = UserSearchSpacePreference(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
session.add(preference)
|
||||
await session.commit()
|
||||
await session.refresh(preference)
|
||||
|
||||
return preference
|
||||
|
||||
|
||||
class LLMPreferencesUpdate(BaseModel):
|
||||
"""Schema for updating user LLM preferences"""
|
||||
|
||||
|
|
@ -36,9 +93,12 @@ async def create_llm_config(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Create a new LLM configuration for the authenticated user"""
|
||||
"""Create a new LLM configuration for a search space"""
|
||||
try:
|
||||
db_llm_config = LLMConfig(**llm_config.model_dump(), user_id=user.id)
|
||||
# Verify user has access to the search space
|
||||
await check_search_space_access(session, llm_config.search_space_id, user)
|
||||
|
||||
db_llm_config = LLMConfig(**llm_config.model_dump())
|
||||
session.add(db_llm_config)
|
||||
await session.commit()
|
||||
await session.refresh(db_llm_config)
|
||||
|
|
@ -54,20 +114,26 @@ async def create_llm_config(
|
|||
|
||||
@router.get("/llm-configs/", response_model=list[LLMConfigRead])
|
||||
async def read_llm_configs(
|
||||
search_space_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 200,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get all LLM configurations for the authenticated user"""
|
||||
"""Get all LLM configurations for a search space"""
|
||||
try:
|
||||
# Verify user has access to the search space
|
||||
await check_search_space_access(session, search_space_id, user)
|
||||
|
||||
result = await session.execute(
|
||||
select(LLMConfig)
|
||||
.filter(LLMConfig.user_id == user.id)
|
||||
.filter(LLMConfig.search_space_id == search_space_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch LLM configurations: {e!s}"
|
||||
|
|
@ -82,7 +148,18 @@ async def read_llm_config(
|
|||
):
|
||||
"""Get a specific LLM configuration by ID"""
|
||||
try:
|
||||
llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
|
||||
# Get the LLM config
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
|
||||
if not llm_config:
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
|
||||
# Verify user has access to the search space
|
||||
await check_search_space_access(session, llm_config.search_space_id, user)
|
||||
|
||||
return llm_config
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -101,7 +178,18 @@ async def update_llm_config(
|
|||
):
|
||||
"""Update an existing LLM configuration"""
|
||||
try:
|
||||
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
|
||||
# Get the LLM config
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
|
||||
)
|
||||
db_llm_config = result.scalars().first()
|
||||
|
||||
if not db_llm_config:
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
|
||||
# Verify user has access to the search space
|
||||
await check_search_space_access(session, db_llm_config.search_space_id, user)
|
||||
|
||||
update_data = llm_config_update.model_dump(exclude_unset=True)
|
||||
|
||||
for key, value in update_data.items():
|
||||
|
|
@ -127,7 +215,18 @@ async def delete_llm_config(
|
|||
):
|
||||
"""Delete an LLM configuration"""
|
||||
try:
|
||||
db_llm_config = await check_ownership(session, LLMConfig, llm_config_id, user)
|
||||
# Get the LLM config
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(LLMConfig.id == llm_config_id)
|
||||
)
|
||||
db_llm_config = result.scalars().first()
|
||||
|
||||
if not db_llm_config:
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
|
||||
# Verify user has access to the search space
|
||||
await check_search_space_access(session, db_llm_config.search_space_id, user)
|
||||
|
||||
await session.delete(db_llm_config)
|
||||
await session.commit()
|
||||
return {"message": "LLM configuration deleted successfully"}
|
||||
|
|
@ -143,99 +242,101 @@ async def delete_llm_config(
|
|||
# User LLM Preferences endpoints
|
||||
|
||||
|
||||
@router.get("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
||||
@router.get(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
)
|
||||
async def get_user_llm_preferences(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Get the current user's LLM preferences"""
|
||||
"""Get the current user's LLM preferences for a specific search space"""
|
||||
try:
|
||||
# Refresh user to get latest relationships
|
||||
await session.refresh(user)
|
||||
# Verify user has access to the search space
|
||||
await check_search_space_access(session, search_space_id, user)
|
||||
|
||||
result = {
|
||||
"long_context_llm_id": user.long_context_llm_id,
|
||||
"fast_llm_id": user.fast_llm_id,
|
||||
"strategic_llm_id": user.strategic_llm_id,
|
||||
"long_context_llm": None,
|
||||
"fast_llm": None,
|
||||
"strategic_llm": None,
|
||||
# Get or create user preference for this search space
|
||||
preference = await get_or_create_user_preference(
|
||||
session, user.id, search_space_id
|
||||
)
|
||||
|
||||
return {
|
||||
"long_context_llm_id": preference.long_context_llm_id,
|
||||
"fast_llm_id": preference.fast_llm_id,
|
||||
"strategic_llm_id": preference.strategic_llm_id,
|
||||
"long_context_llm": preference.long_context_llm,
|
||||
"fast_llm": preference.fast_llm,
|
||||
"strategic_llm": preference.strategic_llm,
|
||||
}
|
||||
|
||||
# Fetch the actual LLM configs if they exist
|
||||
if user.long_context_llm_id:
|
||||
long_context_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.long_context_llm_id,
|
||||
LLMConfig.user_id == user.id,
|
||||
)
|
||||
)
|
||||
llm_config = long_context_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["long_context_llm"] = llm_config
|
||||
|
||||
if user.fast_llm_id:
|
||||
fast_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.fast_llm_id, LLMConfig.user_id == user.id
|
||||
)
|
||||
)
|
||||
llm_config = fast_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["fast_llm"] = llm_config
|
||||
|
||||
if user.strategic_llm_id:
|
||||
strategic_llm = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == user.strategic_llm_id, LLMConfig.user_id == user.id
|
||||
)
|
||||
)
|
||||
llm_config = strategic_llm.scalars().first()
|
||||
if llm_config:
|
||||
result["strategic_llm"] = llm_config
|
||||
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch LLM preferences: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.put("/users/me/llm-preferences", response_model=LLMPreferencesRead)
|
||||
@router.put(
|
||||
"/search-spaces/{search_space_id}/llm-preferences",
|
||||
response_model=LLMPreferencesRead,
|
||||
)
|
||||
async def update_user_llm_preferences(
|
||||
search_space_id: int,
|
||||
preferences: LLMPreferencesUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Update the current user's LLM preferences"""
|
||||
"""Update the current user's LLM preferences for a specific search space"""
|
||||
try:
|
||||
# Validate that all provided LLM config IDs belong to the user
|
||||
# Verify user has access to the search space
|
||||
await check_search_space_access(session, search_space_id, user)
|
||||
|
||||
# Get or create user preference for this search space
|
||||
preference = await get_or_create_user_preference(
|
||||
session, user.id, search_space_id
|
||||
)
|
||||
|
||||
# Validate that all provided LLM config IDs belong to the search space
|
||||
update_data = preferences.model_dump(exclude_unset=True)
|
||||
|
||||
for _key, llm_config_id in update_data.items():
|
||||
if llm_config_id is not None:
|
||||
# Verify ownership of the LLM config
|
||||
# Verify the LLM config belongs to the search space
|
||||
result = await session.execute(
|
||||
select(LLMConfig).filter(
|
||||
LLMConfig.id == llm_config_id, LLMConfig.user_id == user.id
|
||||
LLMConfig.id == llm_config_id,
|
||||
LLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
if not llm_config:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"LLM configuration {llm_config_id} not found or you don't have permission to access it",
|
||||
detail=f"LLM configuration {llm_config_id} not found in this search space",
|
||||
)
|
||||
|
||||
# Update user preferences
|
||||
for key, value in update_data.items():
|
||||
setattr(user, key, value)
|
||||
setattr(preference, key, value)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
await session.refresh(preference)
|
||||
|
||||
# Reload relationships
|
||||
await session.refresh(
|
||||
preference, ["long_context_llm", "fast_llm", "strategic_llm"]
|
||||
)
|
||||
|
||||
# Return updated preferences
|
||||
return await get_user_llm_preferences(session, user)
|
||||
return {
|
||||
"long_context_llm_id": preference.long_context_llm_id,
|
||||
"fast_llm_id": preference.fast_llm_id,
|
||||
"strategic_llm_id": preference.strategic_llm_id,
|
||||
"long_context_llm": preference.long_context_llm,
|
||||
"fast_llm": preference.fast_llm,
|
||||
"strategic_llm": preference.strategic_llm,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -30,7 +29,9 @@ class LLMConfigBase(BaseModel):
|
|||
|
||||
|
||||
class LLMConfigCreate(LLMConfigBase):
|
||||
pass
|
||||
search_space_id: int = Field(
|
||||
..., description="Search space ID to associate the LLM config with"
|
||||
)
|
||||
|
||||
|
||||
class LLMConfigUpdate(BaseModel):
|
||||
|
|
@ -56,6 +57,6 @@ class LLMConfigUpdate(BaseModel):
|
|||
class LLMConfigRead(LLMConfigBase, IDModel, TimestampModel):
|
||||
id: int
|
||||
created_at: datetime
|
||||
user_id: uuid.UUID
|
||||
search_space_id: int
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -24,4 +24,4 @@ class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
|
|||
created_at: datetime
|
||||
user_id: uuid.UUID
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
import logging
|
||||
|
||||
import litellm
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import LLMConfig, User
|
||||
from app.db import LLMConfig, UserSearchSpacePreference
|
||||
|
||||
# Configure litellm to automatically drop unsupported parameters
|
||||
litellm.drop_params = True
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -16,54 +20,67 @@ class LLMRole:
|
|||
|
||||
|
||||
async def get_user_llm_instance(
|
||||
session: AsyncSession, user_id: str, role: str
|
||||
session: AsyncSession, user_id: str, search_space_id: int, role: str
|
||||
) -> ChatLiteLLM | None:
|
||||
"""
|
||||
Get a ChatLiteLLM instance for a specific user and role.
|
||||
Get a ChatLiteLLM instance for a specific user, search space, and role.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
user_id: User ID
|
||||
search_space_id: Search Space ID
|
||||
role: LLM role ('long_context', 'fast', or 'strategic')
|
||||
|
||||
Returns:
|
||||
ChatLiteLLM instance or None if not found
|
||||
"""
|
||||
try:
|
||||
# Get user with their LLM preferences
|
||||
result = await session.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalars().first()
|
||||
# Get user's LLM preferences for this search space
|
||||
result = await session.execute(
|
||||
select(UserSearchSpacePreference).where(
|
||||
UserSearchSpacePreference.user_id == user_id,
|
||||
UserSearchSpacePreference.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
preference = result.scalars().first()
|
||||
|
||||
if not user:
|
||||
logger.error(f"User {user_id} not found")
|
||||
if not preference:
|
||||
logger.error(
|
||||
f"No LLM preferences found for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the appropriate LLM config ID based on role
|
||||
llm_config_id = None
|
||||
if role == LLMRole.LONG_CONTEXT:
|
||||
llm_config_id = user.long_context_llm_id
|
||||
llm_config_id = preference.long_context_llm_id
|
||||
elif role == LLMRole.FAST:
|
||||
llm_config_id = user.fast_llm_id
|
||||
llm_config_id = preference.fast_llm_id
|
||||
elif role == LLMRole.STRATEGIC:
|
||||
llm_config_id = user.strategic_llm_id
|
||||
llm_config_id = preference.strategic_llm_id
|
||||
else:
|
||||
logger.error(f"Invalid LLM role: {role}")
|
||||
return None
|
||||
|
||||
if not llm_config_id:
|
||||
logger.error(f"No {role} LLM configured for user {user_id}")
|
||||
logger.error(
|
||||
f"No {role} LLM configured for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Get the LLM configuration
|
||||
result = await session.execute(
|
||||
select(LLMConfig).where(
|
||||
LLMConfig.id == llm_config_id, LLMConfig.user_id == user_id
|
||||
LLMConfig.id == llm_config_id,
|
||||
LLMConfig.search_space_id == search_space_id,
|
||||
)
|
||||
)
|
||||
llm_config = result.scalars().first()
|
||||
|
||||
if not llm_config:
|
||||
logger.error(f"LLM config {llm_config_id} not found for user {user_id}")
|
||||
logger.error(
|
||||
f"LLM config {llm_config_id} not found in search space {search_space_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Build the model string for litellm
|
||||
|
|
@ -113,19 +130,25 @@ async def get_user_llm_instance(
|
|||
|
||||
|
||||
async def get_user_long_context_llm(
|
||||
session: AsyncSession, user_id: str
|
||||
session: AsyncSession, user_id: str, search_space_id: int
|
||||
) -> ChatLiteLLM | None:
|
||||
"""Get user's long context LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.LONG_CONTEXT)
|
||||
"""Get user's long context LLM instance for a specific search space."""
|
||||
return await get_user_llm_instance(
|
||||
session, user_id, search_space_id, LLMRole.LONG_CONTEXT
|
||||
)
|
||||
|
||||
|
||||
async def get_user_fast_llm(session: AsyncSession, user_id: str) -> ChatLiteLLM | None:
|
||||
"""Get user's fast LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.FAST)
|
||||
async def get_user_fast_llm(
|
||||
session: AsyncSession, user_id: str, search_space_id: int
|
||||
) -> ChatLiteLLM | None:
|
||||
"""Get user's fast LLM instance for a specific search space."""
|
||||
return await get_user_llm_instance(session, user_id, search_space_id, LLMRole.FAST)
|
||||
|
||||
|
||||
async def get_user_strategic_llm(
|
||||
session: AsyncSession, user_id: str
|
||||
session: AsyncSession, user_id: str, search_space_id: int
|
||||
) -> ChatLiteLLM | None:
|
||||
"""Get user's strategic LLM instance."""
|
||||
return await get_user_llm_instance(session, user_id, LLMRole.STRATEGIC)
|
||||
"""Get user's strategic LLM instance for a specific search space."""
|
||||
return await get_user_llm_instance(
|
||||
session, user_id, search_space_id, LLMRole.STRATEGIC
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ class QueryService:
|
|||
user_query: str,
|
||||
session: AsyncSession,
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
chat_history_str: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
|
|
@ -27,6 +28,7 @@ class QueryService:
|
|||
user_query: The original user query
|
||||
session: Database session for accessing user LLM configs
|
||||
user_id: User ID to get their specific LLM configuration
|
||||
search_space_id: Search Space ID to get user's LLM preferences
|
||||
chat_history_str: Optional chat history string
|
||||
|
||||
Returns:
|
||||
|
|
@ -37,10 +39,10 @@ class QueryService:
|
|||
|
||||
try:
|
||||
# Get the user's strategic LLM instance
|
||||
llm = await get_user_strategic_llm(session, user_id)
|
||||
llm = await get_user_strategic_llm(session, user_id, search_space_id)
|
||||
if not llm:
|
||||
print(
|
||||
f"Warning: No strategic LLM configured for user {user_id}. Using original query."
|
||||
f"Warning: No strategic LLM configured for user {user_id} in search space {search_space_id}. Using original query."
|
||||
)
|
||||
return user_query
|
||||
|
||||
|
|
|
|||
|
|
@ -260,7 +260,9 @@ async def index_airtable_records(
|
|||
continue
|
||||
|
||||
# Generate document summary
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
document_metadata = {
|
||||
|
|
|
|||
|
|
@ -222,7 +222,9 @@ async def index_clickup_tasks(
|
|||
continue
|
||||
|
||||
# Generate summary with metadata
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
document_metadata = {
|
||||
|
|
|
|||
|
|
@ -233,7 +233,9 @@ async def index_confluence_pages(
|
|||
continue
|
||||
|
||||
# Generate summary with metadata
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
comment_count = len(comments)
|
||||
|
||||
if user_llm:
|
||||
|
|
|
|||
|
|
@ -325,7 +325,9 @@ async def index_discord_messages(
|
|||
continue
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
if not user_llm:
|
||||
logger.error(
|
||||
f"No long context LLM configured for user {user_id}"
|
||||
|
|
|
|||
|
|
@ -213,7 +213,9 @@ async def index_github_repos(
|
|||
continue
|
||||
|
||||
# Generate summary with metadata
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
if user_llm:
|
||||
# Extract file extension from file path
|
||||
file_extension = (
|
||||
|
|
|
|||
|
|
@ -266,7 +266,9 @@ async def index_google_calendar_events(
|
|||
continue
|
||||
|
||||
# Generate summary with metadata
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
document_metadata = {
|
||||
|
|
|
|||
|
|
@ -210,7 +210,9 @@ async def index_google_gmail_messages(
|
|||
continue
|
||||
|
||||
# Generate summary with metadata
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
document_metadata = {
|
||||
|
|
|
|||
|
|
@ -216,7 +216,9 @@ async def index_jira_issues(
|
|||
continue
|
||||
|
||||
# Generate summary with metadata
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
comment_count = len(formatted_issue.get("comments", []))
|
||||
|
||||
if user_llm:
|
||||
|
|
|
|||
|
|
@ -228,7 +228,9 @@ async def index_linear_issues(
|
|||
continue
|
||||
|
||||
# Generate summary with metadata
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
state = formatted_issue.get("state", "Unknown")
|
||||
description = formatted_issue.get("description", "")
|
||||
comment_count = len(formatted_issue.get("comments", []))
|
||||
|
|
|
|||
|
|
@ -270,7 +270,9 @@ async def index_luma_events(
|
|||
continue
|
||||
|
||||
# Generate summary with metadata
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
|
||||
if user_llm:
|
||||
document_metadata = {
|
||||
|
|
|
|||
|
|
@ -299,7 +299,9 @@ async def index_notion_pages(
|
|||
continue
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(
|
||||
session, user_id, search_space_id
|
||||
)
|
||||
if not user_llm:
|
||||
logger.error(f"No long context LLM configured for user {user_id}")
|
||||
skipped_pages.append(f"{page_title} (no LLM configured)")
|
||||
|
|
|
|||
|
|
@ -104,9 +104,11 @@ async def add_extension_received_document(
|
|||
return existing_document
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(f"No long context LLM configured for user {user_id}")
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary with metadata
|
||||
document_metadata = {
|
||||
|
|
|
|||
|
|
@ -60,9 +60,11 @@ async def add_received_file_document_using_unstructured(
|
|||
# TODO: Check if file_markdown exceeds token limit of embedding model
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(f"No long context LLM configured for user {user_id}")
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary with metadata
|
||||
document_metadata = {
|
||||
|
|
@ -140,9 +142,11 @@ async def add_received_file_document_using_llamacloud(
|
|||
return existing_document
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(f"No long context LLM configured for user {user_id}")
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary with metadata
|
||||
document_metadata = {
|
||||
|
|
@ -221,9 +225,11 @@ async def add_received_file_document_using_docling(
|
|||
return existing_document
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(f"No long context LLM configured for user {user_id}")
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary using chunked processing for large documents
|
||||
from app.services.docling_service import create_docling_service
|
||||
|
|
|
|||
|
|
@ -75,9 +75,11 @@ async def add_received_markdown_file_document(
|
|||
return existing_document
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(f"No long context LLM configured for user {user_id}")
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary with metadata
|
||||
document_metadata = {
|
||||
|
|
|
|||
|
|
@ -161,9 +161,11 @@ async def add_crawled_url_document(
|
|||
)
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(f"No long context LLM configured for user {user_id}")
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary
|
||||
await task_logger.log_task_progress(
|
||||
|
|
|
|||
|
|
@ -234,9 +234,11 @@ async def add_youtube_video_document(
|
|||
)
|
||||
|
||||
# Get user's long context LLM
|
||||
user_llm = await get_user_long_context_llm(session, user_id)
|
||||
user_llm = await get_user_long_context_llm(session, user_id, search_space_id)
|
||||
if not user_llm:
|
||||
raise RuntimeError(f"No long context LLM configured for user {user_id}")
|
||||
raise RuntimeError(
|
||||
f"No long context LLM configured for user {user_id} in search space {search_space_id}"
|
||||
)
|
||||
|
||||
# Generate summary
|
||||
await task_logger.log_task_progress(
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ async def generate_chat_podcast(
|
|||
"configurable": {
|
||||
"podcast_title": "SurfSense",
|
||||
"user_id": str(user_id),
|
||||
"search_space_id": search_space_id,
|
||||
}
|
||||
}
|
||||
# Initialize state with database session and streaming service
|
||||
|
|
|
|||
|
|
@ -16,89 +16,80 @@ from fastapi import HTTPException
|
|||
def validate_search_space_id(search_space_id: Any) -> int:
|
||||
"""
|
||||
Validate and convert search_space_id to integer.
|
||||
|
||||
|
||||
Args:
|
||||
search_space_id: The search space ID to validate
|
||||
|
||||
|
||||
Returns:
|
||||
int: Validated search space ID
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if search_space_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id is required"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="search_space_id is required")
|
||||
|
||||
if isinstance(search_space_id, bool):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be an integer, not a boolean"
|
||||
status_code=400, detail="search_space_id must be an integer, not a boolean"
|
||||
)
|
||||
|
||||
|
||||
if isinstance(search_space_id, int):
|
||||
if search_space_id <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be a positive integer"
|
||||
|
||||
status_code=400, detail="search_space_id must be a positive integer"
|
||||
)
|
||||
return search_space_id
|
||||
|
||||
|
||||
if isinstance(search_space_id, str):
|
||||
# Check if it's a valid integer string
|
||||
if not search_space_id.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id cannot be empty"
|
||||
status_code=400, detail="search_space_id cannot be empty"
|
||||
)
|
||||
|
||||
|
||||
# Check for valid integer format (no leading zeros, no decimal points)
|
||||
if not re.match(r'^[1-9]\d*$', search_space_id.strip()):
|
||||
if not re.match(r"^[1-9]\d*$", search_space_id.strip()):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be a valid positive integer"
|
||||
detail="search_space_id must be a valid positive integer",
|
||||
)
|
||||
|
||||
|
||||
value = int(search_space_id.strip())
|
||||
# Regex already guarantees value > 0, but check retained for clarity
|
||||
if value <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be a positive integer"
|
||||
status_code=400, detail="search_space_id must be a positive integer"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_space_id must be an integer or string representation of an integer"
|
||||
detail="search_space_id must be an integer or string representation of an integer",
|
||||
)
|
||||
|
||||
|
||||
def validate_document_ids(document_ids: Any) -> list[int]:
|
||||
"""
|
||||
Validate and convert document_ids to list of integers.
|
||||
|
||||
|
||||
Args:
|
||||
document_ids: The document IDs to validate
|
||||
|
||||
|
||||
Returns:
|
||||
List[int]: Validated list of document IDs
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if document_ids is None:
|
||||
return []
|
||||
|
||||
|
||||
if not isinstance(document_ids, list):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="document_ids_to_add_in_context must be a list"
|
||||
status_code=400, detail="document_ids_to_add_in_context must be a list"
|
||||
)
|
||||
|
||||
|
||||
validated_ids = []
|
||||
for i, doc_id in enumerate(document_ids):
|
||||
if isinstance(doc_id, bool):
|
||||
|
|
@ -111,119 +102,110 @@ def validate_document_ids(document_ids: Any) -> list[int]:
|
|||
if doc_id <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer"
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer",
|
||||
)
|
||||
validated_ids.append(doc_id)
|
||||
elif isinstance(doc_id, str):
|
||||
if not doc_id.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] cannot be empty"
|
||||
detail=f"document_ids_to_add_in_context[{i}] cannot be empty",
|
||||
)
|
||||
|
||||
if not re.match(r'^[1-9]\d*$', doc_id.strip()):
|
||||
|
||||
if not re.match(r"^[1-9]\d*$", doc_id.strip()):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer"
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a valid positive integer",
|
||||
)
|
||||
|
||||
|
||||
value = int(doc_id.strip())
|
||||
# Regex already guarantees value > 0
|
||||
if value <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer"
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be a positive integer",
|
||||
)
|
||||
validated_ids.append(value)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer"
|
||||
detail=f"document_ids_to_add_in_context[{i}] must be an integer or string representation of an integer",
|
||||
)
|
||||
|
||||
|
||||
return validated_ids
|
||||
|
||||
|
||||
def validate_connectors(connectors: Any) -> list[str]:
|
||||
"""
|
||||
Validate selected_connectors list.
|
||||
|
||||
|
||||
Args:
|
||||
connectors: The connectors to validate
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: Validated list of connector names
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if connectors is None:
|
||||
return []
|
||||
|
||||
|
||||
if not isinstance(connectors, list):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="selected_connectors must be a list"
|
||||
status_code=400, detail="selected_connectors must be a list"
|
||||
)
|
||||
|
||||
|
||||
validated_connectors = []
|
||||
for i, connector in enumerate(connectors):
|
||||
if not isinstance(connector, str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"selected_connectors[{i}] must be a string"
|
||||
status_code=400, detail=f"selected_connectors[{i}] must be a string"
|
||||
)
|
||||
|
||||
|
||||
if not connector.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"selected_connectors[{i}] cannot be empty"
|
||||
status_code=400, detail=f"selected_connectors[{i}] cannot be empty"
|
||||
)
|
||||
|
||||
|
||||
trimmed = connector.strip()
|
||||
if not re.fullmatch(r'[\w\-_]+', trimmed):
|
||||
if not re.fullmatch(r"[\w\-_]+", trimmed):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"selected_connectors[{i}] contains invalid characters"
|
||||
detail=f"selected_connectors[{i}] contains invalid characters",
|
||||
)
|
||||
validated_connectors.append(trimmed)
|
||||
|
||||
|
||||
return validated_connectors
|
||||
|
||||
|
||||
def validate_research_mode(research_mode: Any) -> str:
|
||||
"""
|
||||
Validate research_mode parameter.
|
||||
|
||||
|
||||
Args:
|
||||
research_mode: The research mode to validate
|
||||
|
||||
|
||||
Returns:
|
||||
str: Validated research mode
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if research_mode is None:
|
||||
return "QNA" # Default value
|
||||
|
||||
|
||||
if not isinstance(research_mode, str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="research_mode must be a string"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="research_mode must be a string")
|
||||
normalized_mode = research_mode.strip().upper()
|
||||
if not normalized_mode:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="research_mode cannot be empty"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="research_mode cannot be empty")
|
||||
|
||||
valid_modes = ["REPORT_GENERAL", "REPORT_DEEP", "REPORT_DEEPER", "QNA"]
|
||||
if normalized_mode not in valid_modes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"research_mode must be one of: {', '.join(valid_modes)}"
|
||||
detail=f"research_mode must be one of: {', '.join(valid_modes)}",
|
||||
)
|
||||
return normalized_mode
|
||||
|
||||
|
|
@ -231,36 +213,30 @@ def validate_research_mode(research_mode: Any) -> str:
|
|||
def validate_search_mode(search_mode: Any) -> str:
|
||||
"""
|
||||
Validate search_mode parameter.
|
||||
|
||||
|
||||
Args:
|
||||
search_mode: The search mode to validate
|
||||
|
||||
|
||||
Returns:
|
||||
str: Validated search mode
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if search_mode is None:
|
||||
return "CHUNKS" # Default value
|
||||
|
||||
|
||||
if not isinstance(search_mode, str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_mode must be a string"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="search_mode must be a string")
|
||||
normalized_mode = search_mode.strip().upper()
|
||||
if not normalized_mode:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="search_mode cannot be empty"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="search_mode cannot be empty")
|
||||
|
||||
valid_modes = ["CHUNKS", "DOCUMENTS"]
|
||||
if normalized_mode not in valid_modes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"search_mode must be one of: {', '.join(valid_modes)}"
|
||||
detail=f"search_mode must be one of: {', '.join(valid_modes)}",
|
||||
)
|
||||
return normalized_mode
|
||||
|
||||
|
|
@ -268,185 +244,155 @@ def validate_search_mode(search_mode: Any) -> str:
|
|||
def validate_messages(messages: Any) -> list[dict]:
|
||||
"""
|
||||
Validate messages structure.
|
||||
|
||||
|
||||
Args:
|
||||
messages: The messages to validate
|
||||
|
||||
|
||||
Returns:
|
||||
List[dict]: Validated messages
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if not isinstance(messages, list):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="messages must be a list"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="messages must be a list")
|
||||
|
||||
if not messages:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="messages cannot be empty"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="messages cannot be empty")
|
||||
|
||||
validated_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
if not isinstance(message, dict):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}] must be a dictionary"
|
||||
status_code=400, detail=f"messages[{i}] must be a dictionary"
|
||||
)
|
||||
|
||||
|
||||
if "role" not in message:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}] must have a 'role' field"
|
||||
status_code=400, detail=f"messages[{i}] must have a 'role' field"
|
||||
)
|
||||
|
||||
|
||||
if "content" not in message:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}] must have a 'content' field"
|
||||
status_code=400, detail=f"messages[{i}] must have a 'content' field"
|
||||
)
|
||||
|
||||
|
||||
role = message["role"]
|
||||
if not isinstance(role, str) or role not in ["user", "assistant", "system"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'"
|
||||
detail=f"messages[{i}].role must be 'user', 'assistant', or 'system'",
|
||||
)
|
||||
|
||||
|
||||
content = message["content"]
|
||||
if not isinstance(content, str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}].content must be a string"
|
||||
status_code=400, detail=f"messages[{i}].content must be a string"
|
||||
)
|
||||
|
||||
|
||||
if not content.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}].content cannot be empty"
|
||||
status_code=400, detail=f"messages[{i}].content cannot be empty"
|
||||
)
|
||||
|
||||
|
||||
# Trim content and enforce max length (10,000 chars)
|
||||
sanitized_content = content.strip()
|
||||
if len(sanitized_content) > 10000: # Reasonable limit
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"messages[{i}].content is too long (max 10000 characters)"
|
||||
detail=f"messages[{i}].content is too long (max 10000 characters)",
|
||||
)
|
||||
|
||||
validated_messages.append({
|
||||
"role": role,
|
||||
"content": sanitized_content
|
||||
})
|
||||
|
||||
|
||||
validated_messages.append({"role": role, "content": sanitized_content})
|
||||
|
||||
return validated_messages
|
||||
|
||||
|
||||
def validate_email(email: str) -> str:
|
||||
"""
|
||||
Validate email address using pyvalidators library.
|
||||
|
||||
|
||||
Args:
|
||||
email: The email address to validate
|
||||
|
||||
|
||||
Returns:
|
||||
str: Validated email address
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if not email or not email.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Email address is required"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Email address is required")
|
||||
|
||||
email = email.strip()
|
||||
|
||||
|
||||
if not validators.email(email):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid email address format"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Invalid email address format")
|
||||
|
||||
return email
|
||||
|
||||
|
||||
def validate_url(url: str) -> str:
|
||||
"""
|
||||
Validate URL using pyvalidators library.
|
||||
|
||||
|
||||
Args:
|
||||
url: The URL to validate
|
||||
|
||||
|
||||
Returns:
|
||||
str: Validated URL
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if not url or not url.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="URL is required"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="URL is required")
|
||||
|
||||
url = url.strip()
|
||||
|
||||
|
||||
if not validators.url(url):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid URL format"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Invalid URL format")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
def validate_uuid(uuid_string: str) -> str:
|
||||
"""
|
||||
Validate UUID using pyvalidators library.
|
||||
|
||||
|
||||
Args:
|
||||
uuid_string: The UUID string to validate
|
||||
|
||||
|
||||
Returns:
|
||||
str: Validated UUID string
|
||||
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if not uuid_string or not uuid_string.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="UUID is required"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="UUID is required")
|
||||
|
||||
uuid_string = uuid_string.strip()
|
||||
|
||||
|
||||
if not validators.uuid(uuid_string):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid UUID format"
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail="Invalid UUID format")
|
||||
|
||||
return uuid_string
|
||||
|
||||
|
||||
def validate_connector_config(connector_type: str | Any, config: dict[str, Any]) -> dict[str, Any]:
|
||||
def validate_connector_config(
|
||||
connector_type: str | Any, config: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Validate connector configuration based on connector type.
|
||||
|
||||
|
||||
Args:
|
||||
connector_type: The type of connector (string or enum)
|
||||
config: The configuration dictionary to validate
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Validated configuration
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails
|
||||
"""
|
||||
|
|
@ -454,76 +400,69 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
|
|||
raise ValueError("config must be a dictionary of connector settings")
|
||||
|
||||
# Convert enum to string if needed
|
||||
connector_type_str = str(connector_type).split('.')[-1] if hasattr(connector_type, 'value') else str(connector_type)
|
||||
|
||||
connector_type_str = (
|
||||
str(connector_type).split(".")[-1]
|
||||
if hasattr(connector_type, "value")
|
||||
else str(connector_type)
|
||||
)
|
||||
|
||||
# Validation function helpers
|
||||
def validate_email_field(key: str, connector_name: str) -> None:
|
||||
if not validators.email(config.get(key, "")):
|
||||
raise ValueError(f"Invalid email format for {connector_name} connector")
|
||||
|
||||
|
||||
def validate_url_field(key: str, connector_name: str) -> None:
|
||||
if not validators.url(config.get(key, "")):
|
||||
raise ValueError(f"Invalid base URL format for {connector_name} connector")
|
||||
|
||||
|
||||
def validate_list_field(key: str, field_name: str) -> None:
|
||||
value = config.get(key)
|
||||
if not isinstance(value, list) or not value:
|
||||
raise ValueError(f"{field_name} must be a non-empty list of strings")
|
||||
|
||||
|
||||
# Lookup table for connector validation rules
|
||||
connector_rules = {
|
||||
"SERPER_API": {
|
||||
"required": ["SERPER_API_KEY"],
|
||||
"validators": {}
|
||||
},
|
||||
"TAVILY_API": {
|
||||
"required": ["TAVILY_API_KEY"],
|
||||
"validators": {}
|
||||
},
|
||||
"LINKUP_API": {
|
||||
"required": ["LINKUP_API_KEY"],
|
||||
"validators": {}
|
||||
},
|
||||
"SLACK_CONNECTOR": {
|
||||
"required": ["SLACK_BOT_TOKEN"],
|
||||
"validators": {}
|
||||
},
|
||||
"SERPER_API": {"required": ["SERPER_API_KEY"], "validators": {}},
|
||||
"TAVILY_API": {"required": ["TAVILY_API_KEY"], "validators": {}},
|
||||
"LINKUP_API": {"required": ["LINKUP_API_KEY"], "validators": {}},
|
||||
"SLACK_CONNECTOR": {"required": ["SLACK_BOT_TOKEN"], "validators": {}},
|
||||
"NOTION_CONNECTOR": {
|
||||
"required": ["NOTION_INTEGRATION_TOKEN"],
|
||||
"validators": {}
|
||||
"validators": {},
|
||||
},
|
||||
"GITHUB_CONNECTOR": {
|
||||
"required": ["GITHUB_PAT", "repo_full_names"],
|
||||
"validators": {
|
||||
"repo_full_names": lambda: validate_list_field("repo_full_names", "repo_full_names")
|
||||
}
|
||||
},
|
||||
"LINEAR_CONNECTOR": {
|
||||
"required": ["LINEAR_API_KEY"],
|
||||
"validators": {}
|
||||
},
|
||||
"DISCORD_CONNECTOR": {
|
||||
"required": ["DISCORD_BOT_TOKEN"],
|
||||
"validators": {}
|
||||
"repo_full_names": lambda: validate_list_field(
|
||||
"repo_full_names", "repo_full_names"
|
||||
)
|
||||
},
|
||||
},
|
||||
"LINEAR_CONNECTOR": {"required": ["LINEAR_API_KEY"], "validators": {}},
|
||||
"DISCORD_CONNECTOR": {"required": ["DISCORD_BOT_TOKEN"], "validators": {}},
|
||||
"JIRA_CONNECTOR": {
|
||||
"required": ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"],
|
||||
"validators": {
|
||||
"JIRA_EMAIL": lambda: validate_email_field("JIRA_EMAIL", "JIRA"),
|
||||
"JIRA_BASE_URL": lambda: validate_url_field("JIRA_BASE_URL", "JIRA")
|
||||
}
|
||||
"JIRA_BASE_URL": lambda: validate_url_field("JIRA_BASE_URL", "JIRA"),
|
||||
},
|
||||
},
|
||||
"CONFLUENCE_CONNECTOR": {
|
||||
"required": ["CONFLUENCE_BASE_URL", "CONFLUENCE_EMAIL", "CONFLUENCE_API_TOKEN"],
|
||||
"required": [
|
||||
"CONFLUENCE_BASE_URL",
|
||||
"CONFLUENCE_EMAIL",
|
||||
"CONFLUENCE_API_TOKEN",
|
||||
],
|
||||
"validators": {
|
||||
"CONFLUENCE_EMAIL": lambda: validate_email_field("CONFLUENCE_EMAIL", "Confluence"),
|
||||
"CONFLUENCE_BASE_URL": lambda: validate_url_field("CONFLUENCE_BASE_URL", "Confluence")
|
||||
}
|
||||
},
|
||||
"CLICKUP_CONNECTOR": {
|
||||
"required": ["CLICKUP_API_TOKEN"],
|
||||
"validators": {}
|
||||
"CONFLUENCE_EMAIL": lambda: validate_email_field(
|
||||
"CONFLUENCE_EMAIL", "Confluence"
|
||||
),
|
||||
"CONFLUENCE_BASE_URL": lambda: validate_url_field(
|
||||
"CONFLUENCE_BASE_URL", "Confluence"
|
||||
),
|
||||
},
|
||||
},
|
||||
"CLICKUP_CONNECTOR": {"required": ["CLICKUP_API_TOKEN"], "validators": {}},
|
||||
# "GOOGLE_CALENDAR_CONNECTOR": {
|
||||
# "required": ["token", "refresh_token", "token_uri", "client_id", "expiry", "scopes", "client_secret"],
|
||||
# "validators": {},
|
||||
|
|
@ -538,26 +477,23 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
|
|||
# "required": ["AIRTABLE_API_KEY", "AIRTABLE_BASE_ID"],
|
||||
# "validators": {}
|
||||
# },
|
||||
"LUMA_CONNECTOR": {
|
||||
"required": ["LUMA_API_KEY"],
|
||||
"validators": {}
|
||||
}
|
||||
"LUMA_CONNECTOR": {"required": ["LUMA_API_KEY"], "validators": {}},
|
||||
}
|
||||
|
||||
|
||||
rules = connector_rules.get(connector_type_str)
|
||||
if not rules:
|
||||
return config # Unknown connector type, pass through
|
||||
|
||||
|
||||
# Validate required keys match exactly
|
||||
if set(config.keys()) != set(rules["required"]):
|
||||
raise ValueError(
|
||||
f"For {connector_type_str} connector type, config must only contain these keys: {rules['required']}"
|
||||
)
|
||||
|
||||
|
||||
# Apply custom validators first (these check format before emptiness)
|
||||
for validator_func in rules["validators"].values():
|
||||
validator_func()
|
||||
|
||||
|
||||
# Validate each field is not empty
|
||||
for key in rules["required"]:
|
||||
# Special handling for Google connectors that don't allow None or empty strings
|
||||
|
|
@ -568,5 +504,5 @@ def validate_connector_config(connector_type: str | Any, config: dict[str, Any])
|
|||
# Standard check: field must have a truthy value
|
||||
if not config.get(key):
|
||||
raise ValueError(f"{key} cannot be empty")
|
||||
|
||||
|
||||
return config
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
"use client";
|
||||
|
||||
import { Loader2 } from "lucide-react";
|
||||
import { usePathname, useRouter } from "next/navigation";
|
||||
import type React from "react";
|
||||
import { useState } from "react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { DashboardBreadcrumb } from "@/components/dashboard-breadcrumb";
|
||||
import { AppSidebarProvider } from "@/components/sidebar/AppSidebarProvider";
|
||||
import { ThemeTogglerComponent } from "@/components/theme/theme-toggle";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
||||
import { useLLMPreferences } from "@/hooks/use-llm-configs";
|
||||
|
||||
export function DashboardClientLayout({
|
||||
children,
|
||||
|
|
@ -19,6 +23,16 @@ export function DashboardClientLayout({
|
|||
navSecondary: any[];
|
||||
navMain: any[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
const searchSpaceIdNum = Number(searchSpaceId);
|
||||
|
||||
const { loading, error, isOnboardingComplete } = useLLMPreferences(searchSpaceIdNum);
|
||||
const [hasCheckedOnboarding, setHasCheckedOnboarding] = useState(false);
|
||||
|
||||
// Skip onboarding check if we're already on the onboarding page
|
||||
const isOnboardingPage = pathname?.includes("/onboard");
|
||||
|
||||
const [open, setOpen] = useState<boolean>(() => {
|
||||
try {
|
||||
const match = document.cookie.match(/(?:^|; )sidebar_state=([^;]+)/);
|
||||
|
|
@ -29,6 +43,68 @@ export function DashboardClientLayout({
|
|||
return true;
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
// Skip check if already on onboarding page
|
||||
if (isOnboardingPage) {
|
||||
setHasCheckedOnboarding(true);
|
||||
return;
|
||||
}
|
||||
|
||||
// Only check once after preferences have loaded
|
||||
if (!loading && !hasCheckedOnboarding) {
|
||||
const onboardingComplete = isOnboardingComplete();
|
||||
|
||||
if (!onboardingComplete) {
|
||||
router.push(`/dashboard/${searchSpaceId}/onboard`);
|
||||
}
|
||||
|
||||
setHasCheckedOnboarding(true);
|
||||
}
|
||||
}, [
|
||||
loading,
|
||||
isOnboardingComplete,
|
||||
isOnboardingPage,
|
||||
router,
|
||||
searchSpaceId,
|
||||
hasCheckedOnboarding,
|
||||
]);
|
||||
|
||||
// Show loading screen while checking onboarding status (only on first load)
|
||||
if (!hasCheckedOnboarding && loading && !isOnboardingPage) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
|
||||
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
|
||||
<CardHeader className="pb-2">
|
||||
<CardTitle className="text-xl font-medium">Loading Configuration</CardTitle>
|
||||
<CardDescription>Checking your LLM preferences...</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="flex justify-center py-6">
|
||||
<Loader2 className="h-12 w-12 text-primary animate-spin" />
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show error screen if there's an error loading preferences (but not on onboarding page)
|
||||
if (error && !hasCheckedOnboarding && !isOnboardingPage) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
|
||||
<Card className="w-[400px] bg-background/60 backdrop-blur-sm border-destructive/20">
|
||||
<CardHeader className="pb-2">
|
||||
<CardTitle className="text-xl font-medium text-destructive">
|
||||
Configuration Error
|
||||
</CardTitle>
|
||||
<CardDescription>Failed to load your LLM configuration</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="text-sm text-muted-foreground">{error}</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<SidebarProvider open={open} onOpenChange={setOpen}>
|
||||
{/* Use AppSidebarProvider which fetches user, search space, and recent chats */}
|
||||
|
|
|
|||
|
|
@ -33,6 +33,12 @@ export default function DashboardLayout({
|
|||
icon: "SquareTerminal",
|
||||
items: [],
|
||||
},
|
||||
{
|
||||
title: "Manage LLMs",
|
||||
url: `/dashboard/${search_space_id}/settings`,
|
||||
icon: "Settings2",
|
||||
items: [],
|
||||
},
|
||||
|
||||
{
|
||||
title: "Documents",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import { ArrowLeft, ArrowRight, Bot, CheckCircle, Sparkles } from "lucide-react";
|
||||
import { AnimatePresence, motion } from "motion/react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { AddProviderStep } from "@/components/onboard/add-provider-step";
|
||||
|
|
@ -17,13 +17,16 @@ const TOTAL_STEPS = 3;
|
|||
|
||||
const OnboardPage = () => {
|
||||
const router = useRouter();
|
||||
const { llmConfigs, loading: configsLoading, refreshConfigs } = useLLMConfigs();
|
||||
const params = useParams();
|
||||
const searchSpaceId = Number(params.search_space_id);
|
||||
|
||||
const { llmConfigs, loading: configsLoading, refreshConfigs } = useLLMConfigs(searchSpaceId);
|
||||
const {
|
||||
preferences,
|
||||
loading: preferencesLoading,
|
||||
isOnboardingComplete,
|
||||
refreshPreferences,
|
||||
} = useLLMPreferences();
|
||||
} = useLLMPreferences(searchSpaceId);
|
||||
const [currentStep, setCurrentStep] = useState(1);
|
||||
const [hasUserProgressed, setHasUserProgressed] = useState(false);
|
||||
|
||||
|
|
@ -44,11 +47,23 @@ const OnboardPage = () => {
|
|||
}, [currentStep]);
|
||||
|
||||
// Redirect to dashboard if onboarding is already complete and user hasn't progressed (fresh page load)
|
||||
// But only check once to avoid redirect loops
|
||||
useEffect(() => {
|
||||
if (!preferencesLoading && isOnboardingComplete() && !hasUserProgressed) {
|
||||
router.push("/dashboard");
|
||||
if (!preferencesLoading && !configsLoading && isOnboardingComplete() && !hasUserProgressed) {
|
||||
// Small delay to ensure the check is stable
|
||||
const timer = setTimeout(() => {
|
||||
router.push(`/dashboard/${searchSpaceId}`);
|
||||
}, 100);
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [preferencesLoading, isOnboardingComplete, hasUserProgressed, router]);
|
||||
}, [
|
||||
preferencesLoading,
|
||||
configsLoading,
|
||||
isOnboardingComplete,
|
||||
hasUserProgressed,
|
||||
router,
|
||||
searchSpaceId,
|
||||
]);
|
||||
|
||||
const progress = (currentStep / TOTAL_STEPS) * 100;
|
||||
|
||||
|
|
@ -80,7 +95,7 @@ const OnboardPage = () => {
|
|||
};
|
||||
|
||||
const handleComplete = () => {
|
||||
router.push("/dashboard");
|
||||
router.push(`/dashboard/${searchSpaceId}/documents`);
|
||||
};
|
||||
|
||||
if (configsLoading || preferencesLoading) {
|
||||
|
|
@ -184,12 +199,18 @@ const OnboardPage = () => {
|
|||
>
|
||||
{currentStep === 1 && (
|
||||
<AddProviderStep
|
||||
searchSpaceId={searchSpaceId}
|
||||
onConfigCreated={refreshConfigs}
|
||||
onConfigDeleted={refreshConfigs}
|
||||
/>
|
||||
)}
|
||||
{currentStep === 2 && <AssignRolesStep onPreferencesUpdated={refreshPreferences} />}
|
||||
{currentStep === 3 && <CompletionStep />}
|
||||
{currentStep === 2 && (
|
||||
<AssignRolesStep
|
||||
searchSpaceId={searchSpaceId}
|
||||
onPreferencesUpdated={refreshPreferences}
|
||||
/>
|
||||
)}
|
||||
{currentStep === 3 && <CompletionStep searchSpaceId={searchSpaceId} />}
|
||||
</motion.div>
|
||||
</AnimatePresence>
|
||||
</CardContent>
|
||||
|
|
@ -1,14 +1,16 @@
|
|||
"use client";
|
||||
|
||||
import { ArrowLeft, Bot, Brain, Settings } from "lucide-react"; // Import ArrowLeft icon
|
||||
import { useRouter } from "next/navigation"; // Add this import
|
||||
import { ArrowLeft, Bot, Brain, Settings } from "lucide-react";
|
||||
import { useParams, useRouter } from "next/navigation";
|
||||
import { LLMRoleManager } from "@/components/settings/llm-role-manager";
|
||||
import { ModelConfigManager } from "@/components/settings/model-config-manager";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
|
||||
|
||||
export default function SettingsPage() {
|
||||
const router = useRouter(); // Initialize router
|
||||
const router = useRouter();
|
||||
const params = useParams();
|
||||
const searchSpaceId = Number(params.search_space_id);
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-background">
|
||||
|
|
@ -19,7 +21,7 @@ export default function SettingsPage() {
|
|||
<div className="flex items-center space-x-4">
|
||||
{/* Back Button */}
|
||||
<button
|
||||
onClick={() => router.push("/dashboard")}
|
||||
onClick={() => router.push(`/dashboard/${searchSpaceId}`)}
|
||||
className="flex items-center justify-center h-10 w-10 rounded-lg bg-primary/10 hover:bg-primary/20 transition-colors"
|
||||
aria-label="Back to Dashboard"
|
||||
type="button"
|
||||
|
|
@ -32,7 +34,7 @@ export default function SettingsPage() {
|
|||
<div className="space-y-1">
|
||||
<h1 className="text-3xl font-bold tracking-tight">Settings</h1>
|
||||
<p className="text-lg text-muted-foreground">
|
||||
Manage your LLM configurations and role assignments.
|
||||
Manage your LLM configurations and role assignments for this search space.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -57,11 +59,11 @@ export default function SettingsPage() {
|
|||
</div>
|
||||
|
||||
<TabsContent value="models" className="space-y-6">
|
||||
<ModelConfigManager />
|
||||
<ModelConfigManager searchSpaceId={searchSpaceId} />
|
||||
</TabsContent>
|
||||
|
||||
<TabsContent value="roles" className="space-y-6">
|
||||
<LLMRoleManager />
|
||||
<LLMRoleManager searchSpaceId={searchSpaceId} />
|
||||
</TabsContent>
|
||||
</Tabs>
|
||||
</div>
|
||||
|
|
@ -4,7 +4,6 @@ import { Loader2 } from "lucide-react";
|
|||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import { useLLMPreferences } from "@/hooks/use-llm-configs";
|
||||
|
||||
interface DashboardLayoutProps {
|
||||
children: React.ReactNode;
|
||||
|
|
@ -12,7 +11,6 @@ interface DashboardLayoutProps {
|
|||
|
||||
export default function DashboardLayout({ children }: DashboardLayoutProps) {
|
||||
const router = useRouter();
|
||||
const { loading, error, isOnboardingComplete } = useLLMPreferences();
|
||||
const [isCheckingAuth, setIsCheckingAuth] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
|
|
@ -25,23 +23,14 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
|
|||
setIsCheckingAuth(false);
|
||||
}, [router]);
|
||||
|
||||
useEffect(() => {
|
||||
// Wait for preferences to load, then check if onboarding is complete
|
||||
if (!loading && !error && !isCheckingAuth) {
|
||||
if (!isOnboardingComplete()) {
|
||||
router.push("/onboard");
|
||||
}
|
||||
}
|
||||
}, [loading, error, isCheckingAuth, isOnboardingComplete, router]);
|
||||
|
||||
// Show loading screen while checking authentication or loading preferences
|
||||
if (isCheckingAuth || loading) {
|
||||
// Show loading screen while checking authentication
|
||||
if (isCheckingAuth) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
|
||||
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
|
||||
<CardHeader className="pb-2">
|
||||
<CardTitle className="text-xl font-medium">Loading Dashboard</CardTitle>
|
||||
<CardDescription>Checking your configuration...</CardDescription>
|
||||
<CardDescription>Checking authentication...</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="flex justify-center py-6">
|
||||
<Loader2 className="h-12 w-12 text-primary animate-spin" />
|
||||
|
|
@ -51,42 +40,5 @@ export default function DashboardLayout({ children }: DashboardLayoutProps) {
|
|||
);
|
||||
}
|
||||
|
||||
// Show error screen if there's an error loading preferences
|
||||
if (error) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
|
||||
<Card className="w-[400px] bg-background/60 backdrop-blur-sm border-destructive/20">
|
||||
<CardHeader className="pb-2">
|
||||
<CardTitle className="text-xl font-medium text-destructive">
|
||||
Configuration Error
|
||||
</CardTitle>
|
||||
<CardDescription>Failed to load your LLM configuration</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="text-sm text-muted-foreground">{error}</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Only render children if onboarding is complete
|
||||
if (isOnboardingComplete()) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
// This should not be reached due to redirect, but just in case
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center min-h-screen space-y-4">
|
||||
<Card className="w-[350px] bg-background/60 backdrop-blur-sm">
|
||||
<CardHeader className="pb-2">
|
||||
<CardTitle className="text-xl font-medium">Redirecting...</CardTitle>
|
||||
<CardDescription>Taking you to complete your setup</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="flex justify-center py-6">
|
||||
<Loader2 className="h-12 w-12 text-primary animate-spin" />
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,10 +66,6 @@ export function UserDropdown({
|
|||
</DropdownMenuItem>
|
||||
</DropdownMenuGroup>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem onClick={() => router.push(`/settings`)}>
|
||||
<Settings className="mr-2 h-4 w-4" />
|
||||
Settings
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={handleLogout}>
|
||||
<LogOut className="mr-2 h-4 w-4" />
|
||||
Log out
|
||||
|
|
|
|||
|
|
@ -332,8 +332,11 @@ const ResearchModeSelector = React.memo(
|
|||
ResearchModeSelector.displayName = "ResearchModeSelector";
|
||||
|
||||
const LLMSelector = React.memo(() => {
|
||||
const { llmConfigs, loading: llmLoading, error } = useLLMConfigs();
|
||||
const { preferences, updatePreferences, loading: preferencesLoading } = useLLMPreferences();
|
||||
const { search_space_id } = useParams();
|
||||
const searchSpaceId = Number(search_space_id);
|
||||
|
||||
const { llmConfigs, loading: llmLoading, error } = useLLMConfigs(searchSpaceId);
|
||||
const { preferences, updatePreferences, loading: preferencesLoading } = useLLMPreferences(searchSpaceId);
|
||||
|
||||
const isLoading = llmLoading || preferencesLoading;
|
||||
|
||||
|
|
|
|||
|
|
@ -23,12 +23,17 @@ import { type CreateLLMConfig, useLLMConfigs } from "@/hooks/use-llm-configs";
|
|||
import InferenceParamsEditor from "../inference-params-editor";
|
||||
|
||||
interface AddProviderStepProps {
|
||||
searchSpaceId: number;
|
||||
onConfigCreated?: () => void;
|
||||
onConfigDeleted?: () => void;
|
||||
}
|
||||
|
||||
export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProviderStepProps) {
|
||||
const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs();
|
||||
export function AddProviderStep({
|
||||
searchSpaceId,
|
||||
onConfigCreated,
|
||||
onConfigDeleted,
|
||||
}: AddProviderStepProps) {
|
||||
const { llmConfigs, createLLMConfig, deleteLLMConfig } = useLLMConfigs(searchSpaceId);
|
||||
const [isAddingNew, setIsAddingNew] = useState(false);
|
||||
const [formData, setFormData] = useState<CreateLLMConfig>({
|
||||
name: "",
|
||||
|
|
@ -38,6 +43,7 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
|
|||
api_key: "",
|
||||
api_base: "",
|
||||
litellm_params: {},
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
|
|
@ -65,6 +71,7 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
|
|||
api_key: "",
|
||||
api_base: "",
|
||||
litellm_params: {},
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
setIsAddingNew(false);
|
||||
// Notify parent component that a config was created
|
||||
|
|
@ -253,7 +260,6 @@ export function AddProviderStep({ onConfigCreated, onConfigDeleted }: AddProvide
|
|||
/>
|
||||
</div>
|
||||
|
||||
|
||||
<div className="flex gap-2 pt-4">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
{isSubmitting ? "Adding..." : "Add Provider"}
|
||||
|
|
|
|||
|
|
@ -41,12 +41,13 @@ const ROLE_DESCRIPTIONS = {
|
|||
};
|
||||
|
||||
interface AssignRolesStepProps {
|
||||
searchSpaceId: number;
|
||||
onPreferencesUpdated?: () => Promise<void>;
|
||||
}
|
||||
|
||||
export function AssignRolesStep({ onPreferencesUpdated }: AssignRolesStepProps) {
|
||||
const { llmConfigs } = useLLMConfigs();
|
||||
const { preferences, updatePreferences } = useLLMPreferences();
|
||||
export function AssignRolesStep({ searchSpaceId, onPreferencesUpdated }: AssignRolesStepProps) {
|
||||
const { llmConfigs } = useLLMConfigs(searchSpaceId);
|
||||
const { preferences, updatePreferences } = useLLMPreferences(searchSpaceId);
|
||||
|
||||
const [assignments, setAssignments] = useState({
|
||||
long_context_llm_id: preferences.long_context_llm_id || "",
|
||||
|
|
|
|||
|
|
@ -12,9 +12,13 @@ const ROLE_ICONS = {
|
|||
strategic: Bot,
|
||||
};
|
||||
|
||||
export function CompletionStep() {
|
||||
const { llmConfigs } = useLLMConfigs();
|
||||
const { preferences } = useLLMPreferences();
|
||||
interface CompletionStepProps {
|
||||
searchSpaceId: number;
|
||||
}
|
||||
|
||||
export function CompletionStep({ searchSpaceId }: CompletionStepProps) {
|
||||
const { llmConfigs } = useLLMConfigs(searchSpaceId);
|
||||
const { preferences } = useLLMPreferences(searchSpaceId);
|
||||
|
||||
const assignedConfigs = {
|
||||
long_context: llmConfigs.find((c) => c.id === preferences.long_context_llm_id),
|
||||
|
|
|
|||
|
|
@ -56,20 +56,24 @@ const ROLE_DESCRIPTIONS = {
|
|||
},
|
||||
};
|
||||
|
||||
export function LLMRoleManager() {
|
||||
interface LLMRoleManagerProps {
|
||||
searchSpaceId: number;
|
||||
}
|
||||
|
||||
export function LLMRoleManager({ searchSpaceId }: LLMRoleManagerProps) {
|
||||
const {
|
||||
llmConfigs,
|
||||
loading: configsLoading,
|
||||
error: configsError,
|
||||
refreshConfigs,
|
||||
} = useLLMConfigs();
|
||||
} = useLLMConfigs(searchSpaceId);
|
||||
const {
|
||||
preferences,
|
||||
loading: preferencesLoading,
|
||||
error: preferencesError,
|
||||
updatePreferences,
|
||||
refreshPreferences,
|
||||
} = useLLMPreferences();
|
||||
} = useLLMPreferences(searchSpaceId);
|
||||
|
||||
const [assignments, setAssignments] = useState({
|
||||
long_context_llm_id: preferences.long_context_llm_id || "",
|
||||
|
|
|
|||
|
|
@ -41,7 +41,11 @@ import { LLM_PROVIDERS } from "@/contracts/enums/llm-providers";
|
|||
import { type CreateLLMConfig, type LLMConfig, useLLMConfigs } from "@/hooks/use-llm-configs";
|
||||
import InferenceParamsEditor from "../inference-params-editor";
|
||||
|
||||
export function ModelConfigManager() {
|
||||
interface ModelConfigManagerProps {
|
||||
searchSpaceId: number;
|
||||
}
|
||||
|
||||
export function ModelConfigManager({ searchSpaceId }: ModelConfigManagerProps) {
|
||||
const {
|
||||
llmConfigs,
|
||||
loading,
|
||||
|
|
@ -50,7 +54,7 @@ export function ModelConfigManager() {
|
|||
updateLLMConfig,
|
||||
deleteLLMConfig,
|
||||
refreshConfigs,
|
||||
} = useLLMConfigs();
|
||||
} = useLLMConfigs(searchSpaceId);
|
||||
const [isAddingNew, setIsAddingNew] = useState(false);
|
||||
const [editingConfig, setEditingConfig] = useState<LLMConfig | null>(null);
|
||||
const [showApiKey, setShowApiKey] = useState<Record<number, boolean>>({});
|
||||
|
|
@ -62,6 +66,7 @@ export function ModelConfigManager() {
|
|||
api_key: "",
|
||||
api_base: "",
|
||||
litellm_params: {},
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
|
|
@ -76,9 +81,10 @@ export function ModelConfigManager() {
|
|||
api_key: editingConfig.api_key,
|
||||
api_base: editingConfig.api_base || "",
|
||||
litellm_params: editingConfig.litellm_params || {},
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
}
|
||||
}, [editingConfig]);
|
||||
}, [editingConfig, searchSpaceId]);
|
||||
|
||||
const handleInputChange = (field: keyof CreateLLMConfig, value: string) => {
|
||||
setFormData((prev) => ({ ...prev, [field]: value }));
|
||||
|
|
@ -113,6 +119,7 @@ export function ModelConfigManager() {
|
|||
api_key: "",
|
||||
api_base: "",
|
||||
litellm_params: {},
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
setIsAddingNew(false);
|
||||
setEditingConfig(null);
|
||||
|
|
@ -426,6 +433,7 @@ export function ModelConfigManager() {
|
|||
api_key: "",
|
||||
api_base: "",
|
||||
litellm_params: {},
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
}
|
||||
}}
|
||||
|
|
@ -462,18 +470,12 @@ export function ModelConfigManager() {
|
|||
value={formData.provider}
|
||||
onValueChange={(value) => handleInputChange("provider", value)}
|
||||
>
|
||||
<SelectTrigger className="h-auto min-h-[2.5rem] py-2">
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select a provider">
|
||||
{formData.provider && (
|
||||
<div className="flex items-center space-x-2 py-1">
|
||||
<div className="font-medium">
|
||||
{LLM_PROVIDERS.find((p) => p.value === formData.provider)?.label}
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground">•</div>
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{LLM_PROVIDERS.find((p) => p.value === formData.provider)?.description}
|
||||
</div>
|
||||
</div>
|
||||
<span className="font-medium">
|
||||
{LLM_PROVIDERS.find((p) => p.value === formData.provider)?.label}
|
||||
</span>
|
||||
)}
|
||||
</SelectValue>
|
||||
</SelectTrigger>
|
||||
|
|
@ -549,7 +551,7 @@ export function ModelConfigManager() {
|
|||
<InferenceParamsEditor
|
||||
params={formData.litellm_params || {}}
|
||||
setParams={(newParams) =>
|
||||
setFormData((prev) => ({ ...prev, litellm_params: newParams }))
|
||||
setFormData((prev) => ({ ...prev, litellm_params: newParams }))
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
|
|
@ -578,6 +580,7 @@ export function ModelConfigManager() {
|
|||
api_key: "",
|
||||
api_base: "",
|
||||
litellm_params: {},
|
||||
search_space_id: searchSpaceId,
|
||||
});
|
||||
}}
|
||||
disabled={isSubmitting}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ export interface LLMConfig {
|
|||
api_base?: string;
|
||||
litellm_params?: Record<string, any>;
|
||||
created_at: string;
|
||||
user_id: string;
|
||||
search_space_id: number;
|
||||
}
|
||||
|
||||
export interface LLMPreferences {
|
||||
|
|
@ -32,6 +32,7 @@ export interface CreateLLMConfig {
|
|||
api_key: string;
|
||||
api_base?: string;
|
||||
litellm_params?: Record<string, any>;
|
||||
search_space_id: number;
|
||||
}
|
||||
|
||||
export interface UpdateLLMConfig {
|
||||
|
|
@ -44,16 +45,21 @@ export interface UpdateLLMConfig {
|
|||
litellm_params?: Record<string, any>;
|
||||
}
|
||||
|
||||
export function useLLMConfigs() {
|
||||
export function useLLMConfigs(searchSpaceId: number | null) {
|
||||
const [llmConfigs, setLlmConfigs] = useState<LLMConfig[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const fetchLLMConfigs = async () => {
|
||||
if (!searchSpaceId) {
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setLoading(true);
|
||||
const response = await fetch(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/`,
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/llm-configs/?search_space_id=${searchSpaceId}`,
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`,
|
||||
|
|
@ -79,7 +85,7 @@ export function useLLMConfigs() {
|
|||
|
||||
useEffect(() => {
|
||||
fetchLLMConfigs();
|
||||
}, []);
|
||||
}, [searchSpaceId]);
|
||||
|
||||
const createLLMConfig = async (config: CreateLLMConfig): Promise<LLMConfig | null> => {
|
||||
try {
|
||||
|
|
@ -181,16 +187,21 @@ export function useLLMConfigs() {
|
|||
};
|
||||
}
|
||||
|
||||
export function useLLMPreferences() {
|
||||
export function useLLMPreferences(searchSpaceId: number | null) {
|
||||
const [preferences, setPreferences] = useState<LLMPreferences>({});
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const fetchPreferences = async () => {
|
||||
if (!searchSpaceId) {
|
||||
setLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setLoading(true);
|
||||
const response = await fetch(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`,
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/llm-preferences`,
|
||||
{
|
||||
headers: {
|
||||
Authorization: `Bearer ${localStorage.getItem("surfsense_bearer_token")}`,
|
||||
|
|
@ -216,12 +227,17 @@ export function useLLMPreferences() {
|
|||
|
||||
useEffect(() => {
|
||||
fetchPreferences();
|
||||
}, []);
|
||||
}, [searchSpaceId]);
|
||||
|
||||
const updatePreferences = async (newPreferences: Partial<LLMPreferences>): Promise<boolean> => {
|
||||
if (!searchSpaceId) {
|
||||
toast.error("Search space ID is required");
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/users/me/llm-preferences`,
|
||||
`${process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL}/api/v1/search-spaces/${searchSpaceId}/llm-preferences`,
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue