2026-03-29 00:07:08 +02:00
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2026-03-30 19:36:54 +02:00
|
|
|
from sqlalchemy.orm import selectinload
|
2026-03-29 00:07:08 +02:00
|
|
|
|
|
|
|
|
from app.db import Prompt, User, get_async_session
|
2026-03-31 16:47:51 +02:00
|
|
|
from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS, SYSTEM_PROMPT_SLUGS
|
2026-03-29 00:07:08 +02:00
|
|
|
from app.schemas.prompts import (
|
|
|
|
|
PromptCreate,
|
|
|
|
|
PromptRead,
|
|
|
|
|
PromptUpdate,
|
2026-03-30 19:36:54 +02:00
|
|
|
PublicPromptRead,
|
2026-03-31 16:47:51 +02:00
|
|
|
SystemPromptUpdate,
|
2026-03-29 00:07:08 +02:00
|
|
|
)
|
|
|
|
|
from app.users import current_active_user
|
|
|
|
|
|
|
|
|
|
# Router for the prompt endpoints in this module: user prompts, system-prompt
# overrides, and public prompt sharing. Routes are grouped under the "Prompts"
# tag in the generated OpenAPI docs.
router = APIRouter(tags=["Prompts"])
|
|
|
|
|
|
|
|
|
|
|
2026-03-31 16:47:51 +02:00
|
|
|
def _prompt_to_read(prompt: Prompt) -> PromptRead:
    """Serialize a stored ``Prompt`` row into the ``PromptRead`` API schema.

    A row with a truthy ``system_prompt_slug`` is the user's override of a
    built-in system prompt. It is reported with ``source="system"`` and
    ``is_modified=True``. Any other row is a user-authored prompt, reported
    with ``source="custom"`` and ``is_modified=False``.
    """
    is_system_override = bool(prompt.system_prompt_slug)
    # ``mode`` may be an Enum member or already a plain value; always expose
    # the underlying value.
    raw_mode = getattr(prompt.mode, "value", prompt.mode)

    return PromptRead(
        id=prompt.id,
        name=prompt.name,
        prompt=prompt.prompt,
        mode=raw_mode,
        search_space_id=prompt.search_space_id,
        is_public=prompt.is_public,
        created_at=prompt.created_at,
        source="system" if is_system_override else "custom",
        system_prompt_slug=prompt.system_prompt_slug,
        is_modified=is_system_override,
    )
|
|
|
|
|
|
|
|
|
|
|
2026-03-29 00:07:08 +02:00
|
|
|
@router.get("/prompts", response_model=list[PromptRead])
async def list_prompts(
    search_space_id: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List the current user's prompts: all system prompts, then custom prompts.

    System prompts come first, in ``SYSTEM_PROMPT_DEFAULTS`` order. Each entry
    is either the user's stored override for that slug or, when none exists,
    the built-in default (``id=None``, ``is_modified=False``). The user's custom
    prompts follow, newest first.

    Args:
        search_space_id: When given, restricts the *custom* prompts to this
            search space. System-prompt overrides are always considered,
            because ``update_system_prompt`` creates them without a
            search space.
    """
    query = select(Prompt).where(Prompt.user_id == user.id)
    if search_space_id is not None:
        # A plain equality filter would also drop the user's system-prompt
        # overrides, which are created without a search_space_id (and NULL
        # never compares equal). The list would then silently fall back to the
        # unmodified defaults. Keep override rows regardless of search space.
        query = query.where(
            (Prompt.search_space_id == search_space_id)
            | Prompt.system_prompt_slug.is_not(None)
        )
    result = await session.execute(query)
    user_prompts = result.scalars().all()

    # Split the rows into system-prompt overrides (keyed by slug) and the
    # user's own custom prompts.
    overrides = {p.system_prompt_slug: p for p in user_prompts if p.system_prompt_slug}
    custom_prompts = [p for p in user_prompts if not p.system_prompt_slug]

    merged: list[PromptRead] = []
    for default in SYSTEM_PROMPT_DEFAULTS:
        slug = default["slug"]
        override = overrides.get(slug)
        if override is not None:
            merged.append(_prompt_to_read(override))
        else:
            # No override stored: expose the built-in default. It has no DB row,
            # so there is no id.
            merged.append(
                PromptRead(
                    id=None,
                    name=default["name"],
                    prompt=default["prompt"],
                    mode=default["mode"],
                    source="system",
                    system_prompt_slug=slug,
                    is_modified=False,
                )
            )

    merged.extend(
        _prompt_to_read(p)
        for p in sorted(custom_prompts, key=lambda x: x.created_at, reverse=True)
    )
    return merged
|
2026-03-29 00:07:08 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/prompts", response_model=PromptRead)
async def create_prompt(
    body: PromptCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create a custom prompt owned by the current user and return it."""
    new_prompt = Prompt(
        user_id=user.id,
        search_space_id=body.search_space_id,
        name=body.name,
        prompt=body.prompt,
        mode=body.mode,
    )

    session.add(new_prompt)
    await session.commit()
    # Reload the row after commit so database-populated fields are available
    # to the serializer.
    await session.refresh(new_prompt)

    return _prompt_to_read(new_prompt)
|
2026-03-29 00:07:08 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/prompts/{prompt_id}", response_model=PromptRead)
async def update_prompt(
    prompt_id: int,
    body: PromptUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Partially update one of the current user's prompts.

    Only the fields explicitly present in the request body are written.

    Raises:
        HTTPException: 404 if no prompt with ``prompt_id`` belongs to the user.
    """
    # Scoping by user_id means another user's prompt is reported as missing.
    ownership_filter = select(Prompt).where(
        Prompt.id == prompt_id,
        Prompt.user_id == user.id,
    )
    existing = (await session.execute(ownership_filter)).scalar_one_or_none()
    if existing is None:
        raise HTTPException(status_code=404, detail="Prompt not found")

    changes = body.model_dump(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(existing, attr_name, new_value)

    session.add(existing)
    await session.commit()
    await session.refresh(existing)

    return _prompt_to_read(existing)
|
2026-03-29 00:07:08 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/prompts/{prompt_id}")
async def delete_prompt(
    prompt_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete one of the current user's prompts.

    Returns:
        ``{"success": True}`` once the row has been removed.

    Raises:
        HTTPException: 404 if no prompt with ``prompt_id`` belongs to the user.
    """
    ownership_filter = select(Prompt).where(
        Prompt.id == prompt_id,
        Prompt.user_id == user.id,
    )
    target = (await session.execute(ownership_filter)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Prompt not found")

    await session.delete(target)
    await session.commit()

    return {"success": True}
|
2026-03-30 19:36:54 +02:00
|
|
|
|
|
|
|
|
|
2026-03-31 16:47:51 +02:00
|
|
|
@router.put("/prompts/system/{slug}", response_model=PromptRead)
async def update_system_prompt(
    slug: str,
    body: SystemPromptUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create or update the current user's override of a built-in system prompt.

    If the user already has an override for ``slug``, only the fields set in
    the body are applied to it. Otherwise a new override row is created. Any
    field missing from the body is seeded from the built-in default.

    Raises:
        HTTPException: 404 if ``slug`` is not a known system prompt.
    """
    if slug not in SYSTEM_PROMPT_SLUGS:
        raise HTTPException(status_code=404, detail="System prompt not found")

    lookup = select(Prompt).where(
        Prompt.user_id == user.id,
        Prompt.system_prompt_slug == slug,
    )
    override = (await session.execute(lookup)).scalar_one_or_none()

    default = next(d for d in SYSTEM_PROMPT_DEFAULTS if d["slug"] == slug)

    if override is None:
        changes = body.model_dump(exclude_unset=True)
        # No override yet: build one from the default, with the request's
        # fields taking precedence.
        override = Prompt(
            user_id=user.id,
            system_prompt_slug=slug,
            name=changes.get("name", default["name"]),
            prompt=changes.get("prompt", default["prompt"]),
            mode=changes.get("mode", default["mode"]),
        )
    else:
        for attr_name, new_value in body.model_dump(exclude_unset=True).items():
            setattr(override, attr_name, new_value)

    session.add(override)
    await session.commit()
    await session.refresh(override)

    return _prompt_to_read(override)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/prompts/system/{slug}")
async def reset_system_prompt(
    slug: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Discard the current user's override of a system prompt, if one exists.

    Succeeds whether or not an override was stored, so repeated resets return
    the same response.

    Raises:
        HTTPException: 404 if ``slug`` is not a known system prompt.
    """
    if slug not in SYSTEM_PROMPT_SLUGS:
        raise HTTPException(status_code=404, detail="System prompt not found")

    lookup = select(Prompt).where(
        Prompt.user_id == user.id,
        Prompt.system_prompt_slug == slug,
    )
    override = (await session.execute(lookup)).scalar_one_or_none()

    # Only touch the database when there is actually an override to remove.
    if override is not None:
        await session.delete(override)
        await session.commit()

    return {"success": True}
|
|
|
|
|
|
|
|
|
|
|
2026-03-30 19:36:54 +02:00
|
|
|
@router.get("/prompts/public", response_model=list[PublicPromptRead])
async def list_public_prompts(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List every public prompt, newest first, annotated with its author.

    ``user`` is not read here; the dependency only requires the caller to be an
    active, authenticated user. ``author_name`` is the owning user's email, or
    ``None`` if the prompt has no loaded owner.
    """
    # Eager-load the owner relationship so author lookups below need no
    # further queries.
    stmt = (
        select(Prompt)
        .options(selectinload(Prompt.user))
        .where(Prompt.is_public.is_(True))
        .order_by(Prompt.created_at.desc())
    )
    public_prompts = (await session.execute(stmt)).scalars().all()

    response: list[PublicPromptRead] = []
    for item in public_prompts:
        owner = item.user
        response.append(
            PublicPromptRead(
                **_prompt_to_read(item).model_dump(),
                author_name=owner.email if owner else None,
            )
        )
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/prompts/{prompt_id}/copy", response_model=PromptRead)
async def copy_public_prompt(
    prompt_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Copy a public prompt into the current user's own prompts.

    The copy takes the original's name, text and mode. It is private and is
    created without a search space or system-prompt slug.

    Raises:
        HTTPException: 404 if ``prompt_id`` does not refer to a public prompt.
    """
    lookup = select(Prompt).where(
        Prompt.id == prompt_id,
        Prompt.is_public.is_(True),
    )
    original = (await session.execute(lookup)).scalar_one_or_none()
    if original is None:
        raise HTTPException(status_code=404, detail="Prompt not found")

    duplicate = Prompt(
        user_id=user.id,
        name=original.name,
        prompt=original.prompt,
        mode=original.mode,
        is_public=False,
    )

    session.add(duplicate)
    await session.commit()
    await session.refresh(duplicate)

    return _prompt_to_read(duplicate)
|