"""HTTP endpoints for creating, editing, listing, sharing and copying prompts."""

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.db import Prompt, SearchSpaceMembership, User, get_async_session
from app.schemas.prompts import (
    PromptCreate,
    PromptRead,
    PromptUpdate,
    PublicPromptRead,
)
from app.users import current_active_user

router = APIRouter(tags=["Prompts"])

# Updating any of these fields counts as a content change and bumps
# `Prompt.version`; other fields (e.g. `is_public`) do not.
_CONTENT_FIELDS = frozenset({"name", "prompt", "mode"})


async def _ensure_search_space_member(
    session: AsyncSession, user: User, search_space_id: int
) -> None:
    """Raise HTTP 403 unless *user* has a membership row for *search_space_id*.

    Raises:
        HTTPException: 403 when no matching SearchSpaceMembership row exists.
    """
    result = await session.execute(
        select(SearchSpaceMembership).where(
            SearchSpaceMembership.user_id == user.id,
            SearchSpaceMembership.search_space_id == search_space_id,
        )
    )
    if not result.scalar_one_or_none():
        raise HTTPException(
            status_code=403,
            detail="You are not a member of this search space",
        )


async def _get_owned_prompt(
    session: AsyncSession, user: User, prompt_id: int
) -> Prompt:
    """Return prompt *prompt_id* if it belongs to *user*.

    Prompts owned by other users are reported as 404 rather than 403, so the
    endpoint does not reveal whether the id exists.

    Raises:
        HTTPException: 404 when no prompt with that id is owned by *user*.
    """
    result = await session.execute(
        select(Prompt).where(
            Prompt.id == prompt_id,
            Prompt.user_id == user.id,
        )
    )
    prompt = result.scalar_one_or_none()
    if not prompt:
        raise HTTPException(status_code=404, detail="Prompt not found")
    return prompt


@router.get("/prompts", response_model=list[PromptRead])
async def list_prompts(
    search_space_id: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List the current user's prompts, newest first.

    Args:
        search_space_id: When given, only prompts in that search space are
            returned. Results are always restricted to the caller's own prompts.
    """
    query = select(Prompt).where(Prompt.user_id == user.id)
    if search_space_id is not None:
        query = query.where(Prompt.search_space_id == search_space_id)
    query = query.order_by(Prompt.created_at.desc())
    result = await session.execute(query)
    return result.scalars().all()


@router.post("/prompts", response_model=PromptRead)
async def create_prompt(
    body: PromptCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create a prompt owned by the current user.

    Raises:
        HTTPException: 403 if ``body.search_space_id`` is set and the user is
            not a member of that search space.
    """
    if body.search_space_id is not None:
        await _ensure_search_space_member(session, user, body.search_space_id)

    prompt = Prompt(
        user_id=user.id,
        search_space_id=body.search_space_id,
        name=body.name,
        prompt=body.prompt,
        mode=body.mode,
        is_public=body.is_public,
    )
    session.add(prompt)
    await session.commit()
    await session.refresh(prompt)
    return prompt


@router.put("/prompts/{prompt_id}", response_model=PromptRead)
async def update_prompt(
    prompt_id: int,
    body: PromptUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Partially update one of the current user's prompts.

    Only fields explicitly sent in the request body are applied
    (``exclude_unset=True``). The prompt's version is incremented when any of
    ``name``, ``prompt`` or ``mode`` is among them.

    Raises:
        HTTPException: 404 if the prompt does not exist or is not owned by the
            user; 403 if the update moves it into a search space the user is
            not a member of.
    """
    prompt = await _get_owned_prompt(session, user, prompt_id)

    updates = body.model_dump(exclude_unset=True)

    # Enforce the same membership rule as create_prompt. NOTE(review): this
    # only matters if PromptUpdate exposes `search_space_id`; when it does not,
    # the key is never present and this check is a no-op.
    new_space_id = updates.get("search_space_id")
    if new_space_id is not None:
        await _ensure_search_space_member(session, user, new_space_id)

    has_content_change = not _CONTENT_FIELDS.isdisjoint(updates)

    for field, value in updates.items():
        setattr(prompt, field, value)

    if has_content_change:
        # Assign a SQL expression so the increment is evaluated by the
        # database at flush time; refresh() below loads the resulting value.
        prompt.version = Prompt.version + 1

    session.add(prompt)
    await session.commit()
    await session.refresh(prompt)
    return prompt


@router.delete("/prompts/{prompt_id}")
async def delete_prompt(
    prompt_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete one of the current user's prompts.

    Raises:
        HTTPException: 404 if the prompt does not exist or is not owned by the user.
    """
    prompt = await _get_owned_prompt(session, user, prompt_id)
    await session.delete(prompt)
    await session.commit()
    return {"success": True}


@router.get("/prompts/public", response_model=list[PublicPromptRead])
async def list_public_prompts(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List public prompts authored by other users, newest first."""
    result = await session.execute(
        select(Prompt)
        # Eager-load authors so reading `p.user` below does not trigger a
        # lazy load, which is not allowed under AsyncSession.
        .options(selectinload(Prompt.user))
        .where(Prompt.is_public.is_(True), Prompt.user_id != user.id)
        .order_by(Prompt.created_at.desc())
    )
    prompts = result.scalars().all()
    # NOTE(review): author_name is the author's email address, which is thereby
    # shown to every authenticated user — confirm this disclosure is intended.
    return [
        PublicPromptRead(
            **PromptRead.model_validate(p).model_dump(),
            author_name=p.user.email if p.user else None,
        )
        for p in prompts
    ]


@router.post("/prompts/{prompt_id}/copy", response_model=PromptRead)
async def copy_public_prompt(
    prompt_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Copy a public prompt into the current user's own prompts.

    The copy is private (``is_public=False``). ``search_space_id`` is not
    passed, so the copy is not bound to the source prompt's search space.

    Raises:
        HTTPException: 404 if no public prompt with that id exists.
    """
    result = await session.execute(
        select(Prompt).where(
            Prompt.id == prompt_id,
            Prompt.is_public.is_(True),
        )
    )
    source = result.scalar_one_or_none()
    if not source:
        raise HTTPException(status_code=404, detail="Prompt not found")

    new_prompt = Prompt(
        user_id=user.id,
        name=source.name,
        prompt=source.prompt,
        mode=source.mode,
        is_public=False,
    )
    session.add(new_prompt)
    await session.commit()
    await session.refresh(new_prompt)
    return new_prompt