Add system prompt merge logic, upsert override, and reset endpoints

This commit is contained in:
CREDO23 2026-03-31 16:47:51 +02:00
parent 329e979d48
commit 80d096db32

View file

@ -4,17 +4,35 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import Prompt, User, get_async_session
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS, SYSTEM_PROMPT_SLUGS
from app.schemas.prompts import (
PromptCreate,
PromptRead,
PromptUpdate,
PublicPromptRead,
SystemPromptUpdate,
)
from app.users import current_active_user
router = APIRouter(tags=["Prompts"])
def _prompt_to_read(prompt: Prompt) -> PromptRead:
source = "system" if prompt.system_prompt_slug else "custom"
return PromptRead(
id=prompt.id,
name=prompt.name,
prompt=prompt.prompt,
mode=prompt.mode.value if hasattr(prompt.mode, "value") else prompt.mode,
search_space_id=prompt.search_space_id,
is_public=prompt.is_public,
created_at=prompt.created_at,
source=source,
system_prompt_slug=prompt.system_prompt_slug,
is_modified=source == "system",
)
@router.get("/prompts", response_model=list[PromptRead])
async def list_prompts(
search_space_id: int | None = None,
@ -24,9 +42,35 @@ async def list_prompts(
query = select(Prompt).where(Prompt.user_id == user.id)
if search_space_id is not None:
query = query.where(Prompt.search_space_id == search_space_id)
query = query.order_by(Prompt.created_at.desc())
result = await session.execute(query)
return result.scalars().all()
user_prompts = result.scalars().all()
overrides = {p.system_prompt_slug: p for p in user_prompts if p.system_prompt_slug}
custom_prompts = [p for p in user_prompts if not p.system_prompt_slug]
merged: list[PromptRead] = []
for default in SYSTEM_PROMPT_DEFAULTS:
slug = default["slug"]
override = overrides.get(slug)
if override:
merged.append(_prompt_to_read(override))
else:
merged.append(
PromptRead(
id=None,
name=default["name"],
prompt=default["prompt"],
mode=default["mode"],
source="system",
system_prompt_slug=slug,
is_modified=False,
)
)
for p in sorted(custom_prompts, key=lambda x: x.created_at, reverse=True):
merged.append(_prompt_to_read(p))
return merged
@router.post("/prompts", response_model=PromptRead)
@ -45,7 +89,7 @@ async def create_prompt(
session.add(prompt)
await session.commit()
await session.refresh(prompt)
return prompt
return _prompt_to_read(prompt)
@router.put("/prompts/{prompt_id}", response_model=PromptRead)
@ -71,7 +115,7 @@ async def update_prompt(
session.add(prompt)
await session.commit()
await session.refresh(prompt)
return prompt
return _prompt_to_read(prompt)
@router.delete("/prompts/{prompt_id}")
@ -95,6 +139,69 @@ async def delete_prompt(
return {"success": True}
@router.put("/prompts/system/{slug}", response_model=PromptRead)
async def update_system_prompt(
slug: str,
body: SystemPromptUpdate,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if slug not in SYSTEM_PROMPT_SLUGS:
raise HTTPException(status_code=404, detail="System prompt not found")
result = await session.execute(
select(Prompt).where(
Prompt.user_id == user.id,
Prompt.system_prompt_slug == slug,
)
)
override = result.scalar_one_or_none()
default = next(d for d in SYSTEM_PROMPT_DEFAULTS if d["slug"] == slug)
if override:
for field, value in body.model_dump(exclude_unset=True).items():
setattr(override, field, value)
else:
updates = body.model_dump(exclude_unset=True)
override = Prompt(
user_id=user.id,
system_prompt_slug=slug,
name=updates.get("name", default["name"]),
prompt=updates.get("prompt", default["prompt"]),
mode=updates.get("mode", default["mode"]),
)
session.add(override)
await session.commit()
await session.refresh(override)
return _prompt_to_read(override)
@router.delete("/prompts/system/{slug}")
async def reset_system_prompt(
slug: str,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
if slug not in SYSTEM_PROMPT_SLUGS:
raise HTTPException(status_code=404, detail="System prompt not found")
result = await session.execute(
select(Prompt).where(
Prompt.user_id == user.id,
Prompt.system_prompt_slug == slug,
)
)
override = result.scalar_one_or_none()
if not override:
return {"success": True}
await session.delete(override)
await session.commit()
return {"success": True}
@router.get("/prompts/public", response_model=list[PublicPromptRead])
async def list_public_prompts(
session: AsyncSession = Depends(get_async_session),
@ -109,7 +216,7 @@ async def list_public_prompts(
prompts = result.scalars().all()
return [
PublicPromptRead(
**PromptRead.model_validate(p).model_dump(),
**_prompt_to_read(p).model_dump(),
author_name=p.user.email if p.user else None,
)
for p in prompts
@ -142,4 +249,4 @@ async def copy_public_prompt(
session.add(copy)
await session.commit()
await session.refresh(copy)
return copy
return _prompt_to_read(copy)