Simplify prompts to seed-based CRUD with version tracking

This commit is contained in:
CREDO23 2026-03-31 18:05:42 +02:00
parent 80d096db32
commit 0c975a6f80
6 changed files with 56 additions and 148 deletions

View file

@ -4,35 +4,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import Prompt, User, get_async_session
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS, SYSTEM_PROMPT_SLUGS
from app.schemas.prompts import (
PromptCreate,
PromptRead,
PromptUpdate,
PublicPromptRead,
SystemPromptUpdate,
)
from app.users import current_active_user
router = APIRouter(tags=["Prompts"])
def _prompt_to_read(prompt: Prompt) -> PromptRead:
source = "system" if prompt.system_prompt_slug else "custom"
return PromptRead(
id=prompt.id,
name=prompt.name,
prompt=prompt.prompt,
mode=prompt.mode.value if hasattr(prompt.mode, "value") else prompt.mode,
search_space_id=prompt.search_space_id,
is_public=prompt.is_public,
created_at=prompt.created_at,
source=source,
system_prompt_slug=prompt.system_prompt_slug,
is_modified=source == "system",
)
@router.get("/prompts", response_model=list[PromptRead])
async def list_prompts(
search_space_id: int | None = None,
@ -42,35 +24,9 @@ async def list_prompts(
query = select(Prompt).where(Prompt.user_id == user.id)
if search_space_id is not None:
query = query.where(Prompt.search_space_id == search_space_id)
query = query.order_by(Prompt.created_at.desc())
result = await session.execute(query)
user_prompts = result.scalars().all()
overrides = {p.system_prompt_slug: p for p in user_prompts if p.system_prompt_slug}
custom_prompts = [p for p in user_prompts if not p.system_prompt_slug]
merged: list[PromptRead] = []
for default in SYSTEM_PROMPT_DEFAULTS:
slug = default["slug"]
override = overrides.get(slug)
if override:
merged.append(_prompt_to_read(override))
else:
merged.append(
PromptRead(
id=None,
name=default["name"],
prompt=default["prompt"],
mode=default["mode"],
source="system",
system_prompt_slug=slug,
is_modified=False,
)
)
for p in sorted(custom_prompts, key=lambda x: x.created_at, reverse=True):
merged.append(_prompt_to_read(p))
return merged
return result.scalars().all()
@router.post("/prompts", response_model=PromptRead)
@ -89,7 +45,7 @@ async def create_prompt(
session.add(prompt)
await session.commit()
await session.refresh(prompt)
return _prompt_to_read(prompt)
return prompt
@router.put("/prompts/{prompt_id}", response_model=PromptRead)
@ -112,10 +68,12 @@ async def update_prompt(
for field, value in body.model_dump(exclude_unset=True).items():
setattr(prompt, field, value)
prompt.version = (prompt.version or 0) + 1
session.add(prompt)
await session.commit()
await session.refresh(prompt)
return _prompt_to_read(prompt)
return prompt
@router.delete("/prompts/{prompt_id}")
@ -139,69 +97,6 @@ async def delete_prompt(
return {"success": True}
@router.put("/prompts/system/{slug}", response_model=PromptRead)
async def update_system_prompt(
slug: str,
body: SystemPromptUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if slug not in SYSTEM_PROMPT_SLUGS:
raise HTTPException(status_code=404, detail="System prompt not found")
result = await session.execute(
select(Prompt).where(
Prompt.user_id == user.id,
Prompt.system_prompt_slug == slug,
)
)
override = result.scalar_one_or_none()
default = next(d for d in SYSTEM_PROMPT_DEFAULTS if d["slug"] == slug)
if override:
for field, value in body.model_dump(exclude_unset=True).items():
setattr(override, field, value)
else:
updates = body.model_dump(exclude_unset=True)
override = Prompt(
user_id=user.id,
system_prompt_slug=slug,
name=updates.get("name", default["name"]),
prompt=updates.get("prompt", default["prompt"]),
mode=updates.get("mode", default["mode"]),
)
session.add(override)
await session.commit()
await session.refresh(override)
return _prompt_to_read(override)
@router.delete("/prompts/system/{slug}")
async def reset_system_prompt(
slug: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if slug not in SYSTEM_PROMPT_SLUGS:
raise HTTPException(status_code=404, detail="System prompt not found")
result = await session.execute(
select(Prompt).where(
Prompt.user_id == user.id,
Prompt.system_prompt_slug == slug,
)
)
override = result.scalar_one_or_none()
if not override:
return {"success": True}
await session.delete(override)
await session.commit()
return {"success": True}
@router.get("/prompts/public", response_model=list[PublicPromptRead])
async def list_public_prompts(
session: AsyncSession = Depends(get_async_session),
@ -216,7 +111,7 @@ async def list_public_prompts(
prompts = result.scalars().all()
return [
PublicPromptRead(
**_prompt_to_read(p).model_dump(),
**PromptRead.model_validate(p).model_dump(),
author_name=p.user.email if p.user else None,
)
for p in prompts
@ -249,4 +144,4 @@ async def copy_public_prompt(
session.add(copy)
await session.commit()
await session.refresh(copy)
return _prompt_to_read(copy)
return copy