diff --git a/surfsense_backend/alembic/versions/40_move_llm_preferences_to_searchspace.py b/surfsense_backend/alembic/versions/40_move_llm_preferences_to_searchspace.py index 1067cffcc..5f6ccb852 100644 --- a/surfsense_backend/alembic/versions/40_move_llm_preferences_to_searchspace.py +++ b/surfsense_backend/alembic/versions/40_move_llm_preferences_to_searchspace.py @@ -1,19 +1,6 @@ -"""Move LLM preferences from user-level to search space level - -Revision ID: 40 -Revises: 39 -Create Date: 2024-11-27 - -This migration moves LLM preferences (long_context_llm_id, fast_llm_id, strategic_llm_id) -from the user_search_space_preferences table to the searchspaces table itself. - -This change supports the RBAC model where LLM preferences are shared by all members -of a search space, rather than being per-user. -""" - import sqlalchemy as sa - from alembic import op +from sqlalchemy import inspect # revision identifiers, used by Alembic. revision = "40" @@ -23,26 +10,32 @@ depends_on = None def upgrade(): - # Add LLM preference columns to searchspaces table - op.add_column( - "searchspaces", - sa.Column("long_context_llm_id", sa.Integer(), nullable=True), - ) - op.add_column( - "searchspaces", - sa.Column("fast_llm_id", sa.Integer(), nullable=True), - ) - op.add_column( - "searchspaces", - sa.Column("strategic_llm_id", sa.Integer(), nullable=True), - ) + conn = op.get_bind() + inspector = inspect(conn) - # Migrate existing preferences from user_search_space_preferences to searchspaces - # We take the owner's preferences (the user who created the search space) - connection = op.get_bind() + existing_cols = {col["name"] for col in inspector.get_columns("searchspaces")} - # Get all search spaces and their owner's preferences - connection.execute( + # Add columns only if they don't already exist + if "long_context_llm_id" not in existing_cols: + op.add_column( + "searchspaces", + sa.Column("long_context_llm_id", sa.Integer(), nullable=True), + ) + + if "fast_llm_id" not in existing_cols: + op.add_column( + "searchspaces", + sa.Column("fast_llm_id", sa.Integer(), nullable=True), + ) + + if "strategic_llm_id" not in existing_cols: + op.add_column( + "searchspaces", + sa.Column("strategic_llm_id", sa.Integer(), nullable=True), + ) + + # Migrate existing data + conn.execute( sa.text(""" UPDATE searchspaces ss SET @@ -57,7 +50,16 @@ def upgrade(): def downgrade(): - # Remove LLM preference columns from searchspaces table - op.drop_column("searchspaces", "strategic_llm_id") - op.drop_column("searchspaces", "fast_llm_id") - op.drop_column("searchspaces", "long_context_llm_id") + conn = op.get_bind() + inspector = inspect(conn) + existing_cols = {col["name"] for col in inspector.get_columns("searchspaces")} + + # Drop columns only if they exist + if "strategic_llm_id" in existing_cols: + op.drop_column("searchspaces", "strategic_llm_id") + + if "fast_llm_id" in existing_cols: + op.drop_column("searchspaces", "fast_llm_id") + + if "long_context_llm_id" in existing_cols: + op.drop_column("searchspaces", "long_context_llm_id")