mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-09 15:52:40 +02:00
Simplify prompts to seed-based CRUD with version tracking
This commit is contained in:
parent
80d096db32
commit
0c975a6f80
6 changed files with 56 additions and 148 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
"""add system_prompt_slug and drop icon from prompts
|
"""add default_prompt_slug, version and drop icon from prompts
|
||||||
|
|
||||||
Revision ID: 113
|
Revision ID: 113
|
||||||
Revises: 112
|
Revises: 112
|
||||||
|
|
@ -17,23 +17,30 @@ depends_on: str | Sequence[str] | None = None
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.execute(
|
op.execute(
|
||||||
"ALTER TABLE prompts ADD COLUMN IF NOT EXISTS"
|
"ALTER TABLE prompts ADD COLUMN IF NOT EXISTS"
|
||||||
" system_prompt_slug VARCHAR(100)"
|
" default_prompt_slug VARCHAR(100)"
|
||||||
)
|
)
|
||||||
op.execute(
|
op.execute(
|
||||||
"CREATE INDEX IF NOT EXISTS ix_prompts_system_prompt_slug"
|
"CREATE INDEX IF NOT EXISTS ix_prompts_default_prompt_slug"
|
||||||
" ON prompts (system_prompt_slug)"
|
" ON prompts (default_prompt_slug)"
|
||||||
)
|
)
|
||||||
op.execute(
|
op.execute(
|
||||||
"ALTER TABLE prompts ADD CONSTRAINT uq_prompt_user_system_slug"
|
"ALTER TABLE prompts ADD CONSTRAINT uq_prompt_user_default_slug"
|
||||||
" UNIQUE (user_id, system_prompt_slug)"
|
" UNIQUE (user_id, default_prompt_slug)"
|
||||||
|
)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE prompts ADD COLUMN IF NOT EXISTS"
|
||||||
|
" version INTEGER NOT NULL DEFAULT 1"
|
||||||
)
|
)
|
||||||
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS icon")
|
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS icon")
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.execute("ALTER TABLE prompts ADD COLUMN IF NOT EXISTS icon VARCHAR(50)")
|
op.execute("ALTER TABLE prompts ADD COLUMN IF NOT EXISTS icon VARCHAR(50)")
|
||||||
|
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS version")
|
||||||
op.execute(
|
op.execute(
|
||||||
"ALTER TABLE prompts DROP CONSTRAINT IF EXISTS uq_prompt_user_system_slug"
|
"ALTER TABLE prompts DROP CONSTRAINT IF EXISTS uq_prompt_user_default_slug"
|
||||||
|
)
|
||||||
|
op.execute("DROP INDEX IF EXISTS ix_prompts_default_prompt_slug")
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE prompts DROP COLUMN IF EXISTS default_prompt_slug"
|
||||||
)
|
)
|
||||||
op.execute("DROP INDEX IF EXISTS ix_prompts_system_prompt_slug")
|
|
||||||
op.execute("ALTER TABLE prompts DROP COLUMN IF EXISTS system_prompt_slug")
|
|
||||||
|
|
|
||||||
|
|
@ -1785,8 +1785,8 @@ class Prompt(BaseModel, TimestampMixin):
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"user_id",
|
"user_id",
|
||||||
"system_prompt_slug",
|
"default_prompt_slug",
|
||||||
name="uq_prompt_user_system_slug",
|
name="uq_prompt_user_default_slug",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1802,10 +1802,11 @@ class Prompt(BaseModel, TimestampMixin):
|
||||||
nullable=True,
|
nullable=True,
|
||||||
index=True,
|
index=True,
|
||||||
)
|
)
|
||||||
system_prompt_slug = Column(String(100), nullable=True, index=True)
|
default_prompt_slug = Column(String(100), nullable=True, index=True)
|
||||||
name = Column(String(200), nullable=False)
|
name = Column(String(200), nullable=False)
|
||||||
prompt = Column(Text, nullable=False)
|
prompt = Column(Text, nullable=False)
|
||||||
mode = Column(SQLAlchemyEnum(PromptMode), nullable=False)
|
mode = Column(SQLAlchemyEnum(PromptMode), nullable=False)
|
||||||
|
version = Column(Integer, nullable=False, default=1)
|
||||||
is_public = Column(Boolean, nullable=False, default=False)
|
is_public = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
user = relationship("User")
|
user = relationship("User")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
SYSTEM_PROMPT_DEFAULTS: list[dict] = [
|
SYSTEM_PROMPT_DEFAULTS: list[dict] = [
|
||||||
{
|
{
|
||||||
"slug": "fix-grammar",
|
"slug": "fix-grammar",
|
||||||
|
"version": 1,
|
||||||
"name": "Fix grammar",
|
"name": "Fix grammar",
|
||||||
"prompt": (
|
"prompt": (
|
||||||
"Fix the grammar and spelling in the following text."
|
"Fix the grammar and spelling in the following text."
|
||||||
|
|
@ -10,6 +11,7 @@ SYSTEM_PROMPT_DEFAULTS: list[dict] = [
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"slug": "make-shorter",
|
"slug": "make-shorter",
|
||||||
|
"version": 1,
|
||||||
"name": "Make shorter",
|
"name": "Make shorter",
|
||||||
"prompt": (
|
"prompt": (
|
||||||
"Make the following text more concise while preserving its meaning."
|
"Make the following text more concise while preserving its meaning."
|
||||||
|
|
@ -19,6 +21,7 @@ SYSTEM_PROMPT_DEFAULTS: list[dict] = [
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"slug": "translate",
|
"slug": "translate",
|
||||||
|
"version": 1,
|
||||||
"name": "Translate",
|
"name": "Translate",
|
||||||
"prompt": (
|
"prompt": (
|
||||||
"Translate the following text to English."
|
"Translate the following text to English."
|
||||||
|
|
@ -29,6 +32,7 @@ SYSTEM_PROMPT_DEFAULTS: list[dict] = [
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"slug": "rewrite",
|
"slug": "rewrite",
|
||||||
|
"version": 1,
|
||||||
"name": "Rewrite",
|
"name": "Rewrite",
|
||||||
"prompt": (
|
"prompt": (
|
||||||
"Rewrite the following text to improve clarity and readability."
|
"Rewrite the following text to improve clarity and readability."
|
||||||
|
|
@ -38,6 +42,7 @@ SYSTEM_PROMPT_DEFAULTS: list[dict] = [
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"slug": "summarize",
|
"slug": "summarize",
|
||||||
|
"version": 1,
|
||||||
"name": "Summarize",
|
"name": "Summarize",
|
||||||
"prompt": (
|
"prompt": (
|
||||||
"Summarize the following text concisely."
|
"Summarize the following text concisely."
|
||||||
|
|
@ -47,22 +52,23 @@ SYSTEM_PROMPT_DEFAULTS: list[dict] = [
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"slug": "explain",
|
"slug": "explain",
|
||||||
|
"version": 1,
|
||||||
"name": "Explain",
|
"name": "Explain",
|
||||||
"prompt": "Explain the following text in simple terms:\n\n{selection}",
|
"prompt": "Explain the following text in simple terms:\n\n{selection}",
|
||||||
"mode": "explore",
|
"mode": "explore",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"slug": "ask-knowledge-base",
|
"slug": "ask-knowledge-base",
|
||||||
|
"version": 1,
|
||||||
"name": "Ask my knowledge base",
|
"name": "Ask my knowledge base",
|
||||||
"prompt": "Search my knowledge base for information related to:\n\n{selection}",
|
"prompt": "Search my knowledge base for information related to:\n\n{selection}",
|
||||||
"mode": "explore",
|
"mode": "explore",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"slug": "look-up-web",
|
"slug": "look-up-web",
|
||||||
|
"version": 1,
|
||||||
"name": "Look up on the web",
|
"name": "Look up on the web",
|
||||||
"prompt": "Search the web for information about:\n\n{selection}",
|
"prompt": "Search the web for information about:\n\n{selection}",
|
||||||
"mode": "explore",
|
"mode": "explore",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
SYSTEM_PROMPT_SLUGS: set[str] = {p["slug"] for p in SYSTEM_PROMPT_DEFAULTS}
|
|
||||||
|
|
|
||||||
|
|
@ -4,35 +4,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.db import Prompt, User, get_async_session
|
from app.db import Prompt, User, get_async_session
|
||||||
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS, SYSTEM_PROMPT_SLUGS
|
|
||||||
from app.schemas.prompts import (
|
from app.schemas.prompts import (
|
||||||
PromptCreate,
|
PromptCreate,
|
||||||
PromptRead,
|
PromptRead,
|
||||||
PromptUpdate,
|
PromptUpdate,
|
||||||
PublicPromptRead,
|
PublicPromptRead,
|
||||||
SystemPromptUpdate,
|
|
||||||
)
|
)
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
|
|
||||||
router = APIRouter(tags=["Prompts"])
|
router = APIRouter(tags=["Prompts"])
|
||||||
|
|
||||||
|
|
||||||
def _prompt_to_read(prompt: Prompt) -> PromptRead:
|
|
||||||
source = "system" if prompt.system_prompt_slug else "custom"
|
|
||||||
return PromptRead(
|
|
||||||
id=prompt.id,
|
|
||||||
name=prompt.name,
|
|
||||||
prompt=prompt.prompt,
|
|
||||||
mode=prompt.mode.value if hasattr(prompt.mode, "value") else prompt.mode,
|
|
||||||
search_space_id=prompt.search_space_id,
|
|
||||||
is_public=prompt.is_public,
|
|
||||||
created_at=prompt.created_at,
|
|
||||||
source=source,
|
|
||||||
system_prompt_slug=prompt.system_prompt_slug,
|
|
||||||
is_modified=source == "system",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/prompts", response_model=list[PromptRead])
|
@router.get("/prompts", response_model=list[PromptRead])
|
||||||
async def list_prompts(
|
async def list_prompts(
|
||||||
search_space_id: int | None = None,
|
search_space_id: int | None = None,
|
||||||
|
|
@ -42,35 +24,9 @@ async def list_prompts(
|
||||||
query = select(Prompt).where(Prompt.user_id == user.id)
|
query = select(Prompt).where(Prompt.user_id == user.id)
|
||||||
if search_space_id is not None:
|
if search_space_id is not None:
|
||||||
query = query.where(Prompt.search_space_id == search_space_id)
|
query = query.where(Prompt.search_space_id == search_space_id)
|
||||||
|
query = query.order_by(Prompt.created_at.desc())
|
||||||
result = await session.execute(query)
|
result = await session.execute(query)
|
||||||
user_prompts = result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
overrides = {p.system_prompt_slug: p for p in user_prompts if p.system_prompt_slug}
|
|
||||||
custom_prompts = [p for p in user_prompts if not p.system_prompt_slug]
|
|
||||||
|
|
||||||
merged: list[PromptRead] = []
|
|
||||||
for default in SYSTEM_PROMPT_DEFAULTS:
|
|
||||||
slug = default["slug"]
|
|
||||||
override = overrides.get(slug)
|
|
||||||
if override:
|
|
||||||
merged.append(_prompt_to_read(override))
|
|
||||||
else:
|
|
||||||
merged.append(
|
|
||||||
PromptRead(
|
|
||||||
id=None,
|
|
||||||
name=default["name"],
|
|
||||||
prompt=default["prompt"],
|
|
||||||
mode=default["mode"],
|
|
||||||
source="system",
|
|
||||||
system_prompt_slug=slug,
|
|
||||||
is_modified=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for p in sorted(custom_prompts, key=lambda x: x.created_at, reverse=True):
|
|
||||||
merged.append(_prompt_to_read(p))
|
|
||||||
|
|
||||||
return merged
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/prompts", response_model=PromptRead)
|
@router.post("/prompts", response_model=PromptRead)
|
||||||
|
|
@ -89,7 +45,7 @@ async def create_prompt(
|
||||||
session.add(prompt)
|
session.add(prompt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(prompt)
|
await session.refresh(prompt)
|
||||||
return _prompt_to_read(prompt)
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
@router.put("/prompts/{prompt_id}", response_model=PromptRead)
|
@router.put("/prompts/{prompt_id}", response_model=PromptRead)
|
||||||
|
|
@ -112,10 +68,12 @@ async def update_prompt(
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
setattr(prompt, field, value)
|
setattr(prompt, field, value)
|
||||||
|
|
||||||
|
prompt.version = (prompt.version or 0) + 1
|
||||||
|
|
||||||
session.add(prompt)
|
session.add(prompt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(prompt)
|
await session.refresh(prompt)
|
||||||
return _prompt_to_read(prompt)
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/prompts/{prompt_id}")
|
@router.delete("/prompts/{prompt_id}")
|
||||||
|
|
@ -139,69 +97,6 @@ async def delete_prompt(
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
|
|
||||||
|
|
||||||
@router.put("/prompts/system/{slug}", response_model=PromptRead)
|
|
||||||
async def update_system_prompt(
|
|
||||||
slug: str,
|
|
||||||
body: SystemPromptUpdate,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
if slug not in SYSTEM_PROMPT_SLUGS:
|
|
||||||
raise HTTPException(status_code=404, detail="System prompt not found")
|
|
||||||
|
|
||||||
result = await session.execute(
|
|
||||||
select(Prompt).where(
|
|
||||||
Prompt.user_id == user.id,
|
|
||||||
Prompt.system_prompt_slug == slug,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
override = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
default = next(d for d in SYSTEM_PROMPT_DEFAULTS if d["slug"] == slug)
|
|
||||||
|
|
||||||
if override:
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(override, field, value)
|
|
||||||
else:
|
|
||||||
updates = body.model_dump(exclude_unset=True)
|
|
||||||
override = Prompt(
|
|
||||||
user_id=user.id,
|
|
||||||
system_prompt_slug=slug,
|
|
||||||
name=updates.get("name", default["name"]),
|
|
||||||
prompt=updates.get("prompt", default["prompt"]),
|
|
||||||
mode=updates.get("mode", default["mode"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
session.add(override)
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(override)
|
|
||||||
return _prompt_to_read(override)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/prompts/system/{slug}")
|
|
||||||
async def reset_system_prompt(
|
|
||||||
slug: str,
|
|
||||||
session: AsyncSession = Depends(get_async_session),
|
|
||||||
user: User = Depends(current_active_user),
|
|
||||||
):
|
|
||||||
if slug not in SYSTEM_PROMPT_SLUGS:
|
|
||||||
raise HTTPException(status_code=404, detail="System prompt not found")
|
|
||||||
|
|
||||||
result = await session.execute(
|
|
||||||
select(Prompt).where(
|
|
||||||
Prompt.user_id == user.id,
|
|
||||||
Prompt.system_prompt_slug == slug,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
override = result.scalar_one_or_none()
|
|
||||||
if not override:
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
await session.delete(override)
|
|
||||||
await session.commit()
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/prompts/public", response_model=list[PublicPromptRead])
|
@router.get("/prompts/public", response_model=list[PublicPromptRead])
|
||||||
async def list_public_prompts(
|
async def list_public_prompts(
|
||||||
session: AsyncSession = Depends(get_async_session),
|
session: AsyncSession = Depends(get_async_session),
|
||||||
|
|
@ -216,7 +111,7 @@ async def list_public_prompts(
|
||||||
prompts = result.scalars().all()
|
prompts = result.scalars().all()
|
||||||
return [
|
return [
|
||||||
PublicPromptRead(
|
PublicPromptRead(
|
||||||
**_prompt_to_read(p).model_dump(),
|
**PromptRead.model_validate(p).model_dump(),
|
||||||
author_name=p.user.email if p.user else None,
|
author_name=p.user.email if p.user else None,
|
||||||
)
|
)
|
||||||
for p in prompts
|
for p in prompts
|
||||||
|
|
@ -249,4 +144,4 @@ async def copy_public_prompt(
|
||||||
session.add(copy)
|
session.add(copy)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(copy)
|
await session.refresh(copy)
|
||||||
return _prompt_to_read(copy)
|
return copy
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
@ -19,23 +18,15 @@ class PromptUpdate(BaseModel):
|
||||||
is_public: bool | None = None
|
is_public: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class SystemPromptUpdate(BaseModel):
|
|
||||||
name: str | None = Field(None, min_length=1, max_length=200)
|
|
||||||
prompt: str | None = Field(None, min_length=1)
|
|
||||||
mode: str | None = Field(None, pattern="^(transform|explore)$")
|
|
||||||
|
|
||||||
|
|
||||||
class PromptRead(BaseModel):
|
class PromptRead(BaseModel):
|
||||||
id: int | None
|
id: int
|
||||||
name: str
|
name: str
|
||||||
prompt: str
|
prompt: str
|
||||||
mode: str
|
mode: str
|
||||||
search_space_id: int | None = None
|
search_space_id: int | None
|
||||||
is_public: bool = False
|
is_public: bool
|
||||||
created_at: datetime | None = None
|
version: int
|
||||||
source: Literal["system", "custom"]
|
created_at: datetime
|
||||||
system_prompt_slug: str | None = None
|
|
||||||
is_modified: bool = False
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,16 @@ import { z } from "zod";
|
||||||
export type PromptMode = "transform" | "explore";
|
export type PromptMode = "transform" | "explore";
|
||||||
|
|
||||||
export const promptRead = z.object({
|
export const promptRead = z.object({
|
||||||
id: z.number(),
|
id: z.number().nullable(),
|
||||||
name: z.string(),
|
name: z.string(),
|
||||||
prompt: z.string(),
|
prompt: z.string(),
|
||||||
mode: z.enum(["transform", "explore"]),
|
mode: z.enum(["transform", "explore"]),
|
||||||
icon: z.string().nullable(),
|
search_space_id: z.number().nullable().optional(),
|
||||||
search_space_id: z.number().nullable(),
|
is_public: z.boolean().optional(),
|
||||||
is_public: z.boolean(),
|
created_at: z.string().nullable().optional(),
|
||||||
created_at: z.string(),
|
source: z.enum(["system", "custom"]),
|
||||||
|
system_prompt_slug: z.string().nullable().optional(),
|
||||||
|
is_modified: z.boolean().optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export type PromptRead = z.infer<typeof promptRead>;
|
export type PromptRead = z.infer<typeof promptRead>;
|
||||||
|
|
@ -29,7 +31,6 @@ export const promptCreateRequest = z.object({
|
||||||
name: z.string().min(1).max(200),
|
name: z.string().min(1).max(200),
|
||||||
prompt: z.string().min(1),
|
prompt: z.string().min(1),
|
||||||
mode: z.enum(["transform", "explore"]),
|
mode: z.enum(["transform", "explore"]),
|
||||||
icon: z.string().max(50).nullable().optional(),
|
|
||||||
search_space_id: z.number().nullable().optional(),
|
search_space_id: z.number().nullable().optional(),
|
||||||
is_public: z.boolean().optional(),
|
is_public: z.boolean().optional(),
|
||||||
});
|
});
|
||||||
|
|
@ -40,12 +41,19 @@ export const promptUpdateRequest = z.object({
|
||||||
name: z.string().min(1).max(200).optional(),
|
name: z.string().min(1).max(200).optional(),
|
||||||
prompt: z.string().min(1).optional(),
|
prompt: z.string().min(1).optional(),
|
||||||
mode: z.enum(["transform", "explore"]).optional(),
|
mode: z.enum(["transform", "explore"]).optional(),
|
||||||
icon: z.string().max(50).nullable().optional(),
|
|
||||||
is_public: z.boolean().optional(),
|
is_public: z.boolean().optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
export type PromptUpdateRequest = z.infer<typeof promptUpdateRequest>;
|
export type PromptUpdateRequest = z.infer<typeof promptUpdateRequest>;
|
||||||
|
|
||||||
|
export const systemPromptUpdateRequest = z.object({
|
||||||
|
name: z.string().min(1).max(200).optional(),
|
||||||
|
prompt: z.string().min(1).optional(),
|
||||||
|
mode: z.enum(["transform", "explore"]).optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export type SystemPromptUpdateRequest = z.infer<typeof systemPromptUpdateRequest>;
|
||||||
|
|
||||||
export const promptDeleteResponse = z.object({
|
export const promptDeleteResponse = z.object({
|
||||||
success: z.boolean(),
|
success: z.boolean(),
|
||||||
});
|
});
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue