mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Seed default prompts on registration and for existing users
This commit is contained in:
parent
0c975a6f80
commit
11387268a7
2 changed files with 94 additions and 6 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
"""add default_prompt_slug, version and drop icon from prompts
|
"""add default_prompt_slug, version, drop icon, seed defaults
|
||||||
|
|
||||||
Revision ID: 113
|
Revision ID: 113
|
||||||
Revises: 112
|
Revises: 112
|
||||||
|
|
@ -6,6 +6,8 @@ Revises: 112
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
revision: str = "113"
|
revision: str = "113"
|
||||||
|
|
@ -13,11 +15,61 @@ down_revision: str | None = "112"
|
||||||
branch_labels: str | Sequence[str] | None = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: str | Sequence[str] | None = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
DEFAULTS = [
|
||||||
|
(
|
||||||
|
"fix-grammar",
|
||||||
|
"Fix grammar",
|
||||||
|
"Fix the grammar and spelling in the following text. Return only the corrected text, nothing else.\n\n{selection}",
|
||||||
|
"transform",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"make-shorter",
|
||||||
|
"Make shorter",
|
||||||
|
"Make the following text more concise while preserving its meaning. Return only the shortened text, nothing else.\n\n{selection}",
|
||||||
|
"transform",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"translate",
|
||||||
|
"Translate",
|
||||||
|
"Translate the following text to English. If it is already in English, translate it to French. Return only the translation, nothing else.\n\n{selection}",
|
||||||
|
"transform",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"rewrite",
|
||||||
|
"Rewrite",
|
||||||
|
"Rewrite the following text to improve clarity and readability. Return only the rewritten text, nothing else.\n\n{selection}",
|
||||||
|
"transform",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"summarize",
|
||||||
|
"Summarize",
|
||||||
|
"Summarize the following text concisely. Return only the summary, nothing else.\n\n{selection}",
|
||||||
|
"transform",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"explain",
|
||||||
|
"Explain",
|
||||||
|
"Explain the following text in simple terms:\n\n{selection}",
|
||||||
|
"explore",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ask-knowledge-base",
|
||||||
|
"Ask my knowledge base",
|
||||||
|
"Search my knowledge base for information related to:\n\n{selection}",
|
||||||
|
"explore",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"look-up-web",
|
||||||
|
"Look up on the web",
|
||||||
|
"Search the web for information about:\n\n{selection}",
|
||||||
|
"explore",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.execute(
|
op.execute(
|
||||||
"ALTER TABLE prompts ADD COLUMN IF NOT EXISTS"
|
"ALTER TABLE prompts ADD COLUMN IF NOT EXISTS default_prompt_slug VARCHAR(100)"
|
||||||
" default_prompt_slug VARCHAR(100)"
|
|
||||||
)
|
)
|
||||||
op.execute(
|
op.execute(
|
||||||
"CREATE INDEX IF NOT EXISTS ix_prompts_default_prompt_slug"
|
"CREATE INDEX IF NOT EXISTS ix_prompts_default_prompt_slug"
|
||||||
|
|
@ -33,14 +85,36 @@ def upgrade() -> None:
|
||||||
)
|
)
|
||||||
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS icon")
|
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS icon")
|
||||||
|
|
||||||
|
conn = op.get_bind()
|
||||||
|
users = conn.execute(sa.text('SELECT id FROM "user"')).fetchall()
|
||||||
|
|
||||||
|
for user_row in users:
|
||||||
|
user_id = user_row[0]
|
||||||
|
for slug, name, prompt, mode in DEFAULTS:
|
||||||
|
conn.execute(
|
||||||
|
sa.text(
|
||||||
|
"INSERT INTO prompts"
|
||||||
|
" (user_id, default_prompt_slug, name, prompt, mode, version, is_public, created_at)"
|
||||||
|
" VALUES (:user_id, :slug, :name, :prompt, :mode::prompt_mode, :version, false, now())"
|
||||||
|
" ON CONFLICT (user_id, default_prompt_slug) DO NOTHING"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"slug": slug,
|
||||||
|
"name": name,
|
||||||
|
"prompt": prompt,
|
||||||
|
"mode": mode,
|
||||||
|
"version": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
|
op.execute("DELETE FROM prompts WHERE default_prompt_slug IS NOT NULL")
|
||||||
op.execute("ALTER TABLE prompts ADD COLUMN IF NOT EXISTS icon VARCHAR(50)")
|
op.execute("ALTER TABLE prompts ADD COLUMN IF NOT EXISTS icon VARCHAR(50)")
|
||||||
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS version")
|
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS version")
|
||||||
op.execute(
|
op.execute(
|
||||||
"ALTER TABLE prompts DROP CONSTRAINT IF EXISTS uq_prompt_user_default_slug"
|
"ALTER TABLE prompts DROP CONSTRAINT IF EXISTS uq_prompt_user_default_slug"
|
||||||
)
|
)
|
||||||
op.execute("DROP INDEX IF EXISTS ix_prompts_default_prompt_slug")
|
op.execute("DROP INDEX IF EXISTS ix_prompts_default_prompt_slug")
|
||||||
op.execute(
|
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS default_prompt_slug")
|
||||||
"ALTER TABLE prompts DROP COLUMN IF EXISTS default_prompt_slug"
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ from sqlalchemy import update
|
||||||
|
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.db import (
|
from app.db import (
|
||||||
|
Prompt,
|
||||||
SearchSpace,
|
SearchSpace,
|
||||||
SearchSpaceMembership,
|
SearchSpaceMembership,
|
||||||
SearchSpaceRole,
|
SearchSpaceRole,
|
||||||
|
|
@ -25,6 +26,7 @@ from app.db import (
|
||||||
get_default_roles_config,
|
get_default_roles_config,
|
||||||
get_user_db,
|
get_user_db,
|
||||||
)
|
)
|
||||||
|
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS
|
||||||
from app.utils.refresh_tokens import create_refresh_token
|
from app.utils.refresh_tokens import create_refresh_token
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -188,6 +190,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||||
)
|
)
|
||||||
session.add(owner_membership)
|
session.add(owner_membership)
|
||||||
|
|
||||||
|
for default in SYSTEM_PROMPT_DEFAULTS:
|
||||||
|
session.add(
|
||||||
|
Prompt(
|
||||||
|
user_id=user.id,
|
||||||
|
default_prompt_slug=default["slug"],
|
||||||
|
name=default["name"],
|
||||||
|
prompt=default["prompt"],
|
||||||
|
mode=default["mode"],
|
||||||
|
version=default["version"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created default search space (ID: {default_search_space.id}) for user {user.id}"
|
f"Created default search space (ID: {default_search_space.id}) for user {user.id}"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue