diff --git a/surfsense_backend/app/routes/prompts_routes.py b/surfsense_backend/app/routes/prompts_routes.py index d76c43663..5f9baf067 100644 --- a/surfsense_backend/app/routes/prompts_routes.py +++ b/surfsense_backend/app/routes/prompts_routes.py @@ -4,17 +4,35 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.db import Prompt, User, get_async_session +from app.prompts.system_defaults import SYSTEM_PROMPT_DEFAULTS, SYSTEM_PROMPT_SLUGS from app.schemas.prompts import ( PromptCreate, PromptRead, PromptUpdate, PublicPromptRead, + SystemPromptUpdate, ) from app.users import current_active_user router = APIRouter(tags=["Prompts"]) +def _prompt_to_read(prompt: Prompt) -> PromptRead: + source = "system" if prompt.system_prompt_slug else "custom" + return PromptRead( + id=prompt.id, + name=prompt.name, + prompt=prompt.prompt, + mode=prompt.mode.value if hasattr(prompt.mode, "value") else prompt.mode, + search_space_id=prompt.search_space_id, + is_public=prompt.is_public, + created_at=prompt.created_at, + source=source, + system_prompt_slug=prompt.system_prompt_slug, + is_modified=source == "system", + ) + + @router.get("/prompts", response_model=list[PromptRead]) async def list_prompts( search_space_id: int | None = None, @@ -24,9 +42,35 @@ async def list_prompts( query = select(Prompt).where(Prompt.user_id == user.id) if search_space_id is not None: query = query.where(Prompt.search_space_id == search_space_id) - query = query.order_by(Prompt.created_at.desc()) result = await session.execute(query) - return result.scalars().all() + user_prompts = result.scalars().all() + + overrides = {p.system_prompt_slug: p for p in user_prompts if p.system_prompt_slug} + custom_prompts = [p for p in user_prompts if not p.system_prompt_slug] + + merged: list[PromptRead] = [] + for default in SYSTEM_PROMPT_DEFAULTS: + slug = default["slug"] + override = overrides.get(slug) + if override: + merged.append(_prompt_to_read(override)) + else: + merged.append( + PromptRead( + id=None, + name=default["name"], + prompt=default["prompt"], + mode=default["mode"], + source="system", + system_prompt_slug=slug, + is_modified=False, + ) + ) + + for p in sorted(custom_prompts, key=lambda x: x.created_at, reverse=True): + merged.append(_prompt_to_read(p)) + + return merged @router.post("/prompts", response_model=PromptRead) @@ -45,7 +89,7 @@ async def create_prompt( session.add(prompt) await session.commit() await session.refresh(prompt) - return prompt + return _prompt_to_read(prompt) @router.put("/prompts/{prompt_id}", response_model=PromptRead) @@ -71,7 +115,7 @@ async def update_prompt( session.add(prompt) await session.commit() await session.refresh(prompt) - return prompt + return _prompt_to_read(prompt) @router.delete("/prompts/{prompt_id}") @@ -95,6 +139,69 @@ async def delete_prompt( return {"success": True} +@router.put("/prompts/system/{slug}", response_model=PromptRead) +async def update_system_prompt( + slug: str, + body: SystemPromptUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if slug not in SYSTEM_PROMPT_SLUGS: + raise HTTPException(status_code=404, detail="System prompt not found") + + result = await session.execute( + select(Prompt).where( + Prompt.user_id == user.id, + Prompt.system_prompt_slug == slug, + ) + ) + override = result.scalar_one_or_none() + + default = next(d for d in SYSTEM_PROMPT_DEFAULTS if d["slug"] == slug) + + if override: + for field, value in body.model_dump(exclude_unset=True).items(): + setattr(override, field, value) + else: + updates = body.model_dump(exclude_unset=True) + override = Prompt( + user_id=user.id, + system_prompt_slug=slug, + name=updates.get("name", default["name"]), + prompt=updates.get("prompt", default["prompt"]), + mode=updates.get("mode", default["mode"]), + ) + + session.add(override) + await session.commit() + await session.refresh(override) + return _prompt_to_read(override) + + +@router.delete("/prompts/system/{slug}") +async def reset_system_prompt( + slug: str, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + if slug not in SYSTEM_PROMPT_SLUGS: + raise HTTPException(status_code=404, detail="System prompt not found") + + result = await session.execute( + select(Prompt).where( + Prompt.user_id == user.id, + Prompt.system_prompt_slug == slug, + ) + ) + override = result.scalar_one_or_none() + if not override: + return {"success": True} + + await session.delete(override) + await session.commit() + return {"success": True} + + @router.get("/prompts/public", response_model=list[PublicPromptRead]) async def list_public_prompts( session: AsyncSession = Depends(get_async_session), @@ -109,7 +216,7 @@ async def list_public_prompts( prompts = result.scalars().all() return [ PublicPromptRead( - **PromptRead.model_validate(p).model_dump(), + **_prompt_to_read(p).model_dump(), author_name=p.user.email if p.user else None, ) for p in prompts @@ -142,4 +249,4 @@ async def copy_public_prompt( session.add(copy) await session.commit() await session.refresh(copy) - return copy + return _prompt_to_read(copy)