diff --git a/surfsense_backend/alembic/versions/113_add_system_prompt_slug_to_prompts.py b/surfsense_backend/alembic/versions/113_add_system_prompt_slug_to_prompts.py index 9a30d4c04..ed020518b 100644 --- a/surfsense_backend/alembic/versions/113_add_system_prompt_slug_to_prompts.py +++ b/surfsense_backend/alembic/versions/113_add_system_prompt_slug_to_prompts.py @@ -1,4 +1,4 @@ -"""add default_prompt_slug, version and drop icon from prompts +"""add default_prompt_slug, version, drop icon, seed defaults Revision ID: 113 Revises: 112 @@ -6,6 +6,8 @@ Revises: 112 from collections.abc import Sequence +import sqlalchemy as sa + from alembic import op revision: str = "113" @@ -13,11 +15,61 @@ down_revision: str | None = "112" branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None +DEFAULTS = [ + ( + "fix-grammar", + "Fix grammar", + "Fix the grammar and spelling in the following text. Return only the corrected text, nothing else.\n\n{selection}", + "transform", + ), + ( + "make-shorter", + "Make shorter", + "Make the following text more concise while preserving its meaning. Return only the shortened text, nothing else.\n\n{selection}", + "transform", + ), + ( + "translate", + "Translate", + "Translate the following text to English. If it is already in English, translate it to French. Return only the translation, nothing else.\n\n{selection}", + "transform", + ), + ( + "rewrite", + "Rewrite", + "Rewrite the following text to improve clarity and readability. Return only the rewritten text, nothing else.\n\n{selection}", + "transform", + ), + ( + "summarize", + "Summarize", + "Summarize the following text concisely. Return only the summary, nothing else.\n\n{selection}", + "transform", + ), + ( + "explain", + "Explain", + "Explain the following text in simple terms:\n\n{selection}", + "explore", + ), + ( + "ask-knowledge-base", + "Ask my knowledge base", + "Search my knowledge base for information related to:\n\n{selection}", + "explore", + ), + ( + "look-up-web", + "Look up on the web", + "Search the web for information about:\n\n{selection}", + "explore", + ), +] + def upgrade() -> None: op.execute( - "ALTER TABLE prompts ADD COLUMN IF NOT EXISTS" - " default_prompt_slug VARCHAR(100)" + "ALTER TABLE prompts ADD COLUMN IF NOT EXISTS default_prompt_slug VARCHAR(100)" ) op.execute( "CREATE INDEX IF NOT EXISTS ix_prompts_default_prompt_slug" @@ -33,14 +85,36 @@ def upgrade() -> None: ) op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS icon") + conn = op.get_bind() + users = conn.execute(sa.text('SELECT id FROM "user"')).fetchall() + + for user_row in users: + user_id = user_row[0] + for slug, name, prompt, mode in DEFAULTS: + conn.execute( + sa.text( + "INSERT INTO prompts" + " (user_id, default_prompt_slug, name, prompt, mode, version, is_public, created_at)" + " VALUES (:user_id, :slug, :name, :prompt, :mode::prompt_mode, :version, false, now())" + " ON CONFLICT (user_id, default_prompt_slug) DO NOTHING" + ), + { + "user_id": user_id, + "slug": slug, + "name": name, + "prompt": prompt, + "mode": mode, + "version": 1, + }, + ) + def downgrade() -> None: + op.execute("DELETE FROM prompts WHERE default_prompt_slug IS NOT NULL") op.execute("ALTER TABLE prompts ADD COLUMN IF NOT EXISTS icon VARCHAR(50)") op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS version") op.execute( "ALTER TABLE prompts DROP CONSTRAINT IF EXISTS uq_prompt_user_default_slug" ) op.execute("DROP INDEX IF EXISTS ix_prompts_default_prompt_slug") - op.execute( - "ALTER TABLE prompts DROP COLUMN IF EXISTS default_prompt_slug" - ) + op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS default_prompt_slug") diff --git a/surfsense_backend/app/users.py b/surfsense_backend/app/users.py index d24a6faf1..66e0cc8dd 100644 --- a/surfsense_backend/app/users.py +++ b/surfsense_backend/app/users.py @@ -17,6 +17,7 @@ from sqlalchemy import update from app.config import config from app.db import ( + Prompt, SearchSpace, SearchSpaceMembership, SearchSpaceRole, @@ -25,6 +26,7 @@ from app.db import ( get_default_roles_config, get_user_db, ) +from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS from app.utils.refresh_tokens import create_refresh_token logger = logging.getLogger(__name__) @@ -188,6 +190,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): ) session.add(owner_membership) + for default in SYSTEM_PROMPT_DEFAULTS: + session.add( + Prompt( + user_id=user.id, + default_prompt_slug=default["slug"], + name=default["name"], + prompt=default["prompt"], + mode=default["mode"], + version=default["version"], + ) + ) + await session.commit() logger.info( f"Created default search space (ID: {default_search_space.id}) for user {user.id}"