# SurfSense/surfsense_backend/app/routes/prompts_routes.py

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import Prompt, SearchSpaceMembership, User, get_async_session
from app.schemas.prompts import (
PromptCreate,
PromptRead,
PromptUpdate,
PublicPromptRead,
)
from app.users import current_active_user
# Router for all prompt endpoints; tagged "Prompts" so they are grouped
# together in the generated OpenAPI docs.
router = APIRouter(
    tags=["Prompts"],
)
@router.get("/prompts", response_model=list[PromptRead])
async def list_prompts(
    search_space_id: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Return the current user's own prompts, newest first.

    Args:
        search_space_id: When given, restrict results to prompts attached to
            this search space.
    """
    # Always scope to the caller; optionally narrow to one search space.
    filters = [Prompt.user_id == user.id]
    if search_space_id is not None:
        filters.append(Prompt.search_space_id == search_space_id)

    stmt = select(Prompt).where(*filters).order_by(Prompt.created_at.desc())
    rows = await session.execute(stmt)
    return rows.scalars().all()
@router.post("/prompts", response_model=PromptRead)
async def create_prompt(
    body: PromptCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Create a new prompt owned by the current user.

    If ``body.search_space_id`` is set, the user must hold a membership in
    that search space. Otherwise nothing is written.

    Raises:
        HTTPException(403): the user is not a member of the target search space.
    """
    space_id = body.search_space_id
    if space_id is not None:
        membership_stmt = select(SearchSpaceMembership).where(
            SearchSpaceMembership.user_id == user.id,
            SearchSpaceMembership.search_space_id == space_id,
        )
        is_member = (await session.execute(membership_stmt)).scalar_one_or_none()
        if not is_member:
            raise HTTPException(
                status_code=403,
                detail="You are not a member of this search space",
            )

    new_prompt = Prompt(
        user_id=user.id,
        search_space_id=space_id,
        name=body.name,
        prompt=body.prompt,
        mode=body.mode,
        is_public=body.is_public,
    )
    session.add(new_prompt)
    await session.commit()
    # Reload server-generated columns (id, timestamps, ...) for the response.
    await session.refresh(new_prompt)
    return new_prompt
@router.put("/prompts/{prompt_id}", response_model=PromptRead)
async def update_prompt(
    prompt_id: int,
    body: PromptUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Apply a partial update to one of the current user's prompts.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``). ``version`` is incremented when at least one
    content field (``name``, ``prompt``, ``mode``) receives a value that
    differs from the stored one.

    Raises:
        HTTPException(404): the prompt does not exist or belongs to another user.
        HTTPException(403): the update moves the prompt into a search space
            the user is not a member of.
    """
    result = await session.execute(
        select(Prompt).where(
            Prompt.id == prompt_id,
            Prompt.user_id == user.id,
        )
    )
    prompt = result.scalar_one_or_none()
    if not prompt:
        raise HTTPException(status_code=404, detail="Prompt not found")

    updates = body.model_dump(exclude_unset=True)

    # Enforce the same membership rule as prompt creation when the update
    # re-targets the prompt to a different search space. This is a no-op if
    # the update schema has no ``search_space_id`` field.
    new_space_id = updates.get("search_space_id")
    if new_space_id is not None and new_space_id != prompt.search_space_id:
        membership = await session.execute(
            select(SearchSpaceMembership).where(
                SearchSpaceMembership.user_id == user.id,
                SearchSpaceMembership.search_space_id == new_space_id,
            )
        )
        if not membership.scalar_one_or_none():
            raise HTTPException(
                status_code=403,
                detail="You are not a member of this search space",
            )

    content_fields = {"name", "prompt", "mode"}
    # Compare against the stored values *before* mutating, so re-sending
    # identical content does not bump the version.
    has_content_change = any(
        getattr(prompt, field) != updates[field]
        for field in updates.keys() & content_fields
    )
    for field, value in updates.items():
        setattr(prompt, field, value)
    if has_content_change:
        # Assigning a SQL expression makes the database compute the increment
        # in the UPDATE statement. The refresh below loads the new value.
        prompt.version = Prompt.version + 1
    session.add(prompt)
    await session.commit()
    await session.refresh(prompt)
    return prompt
@router.delete("/prompts/{prompt_id}")
async def delete_prompt(
    prompt_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Delete one of the current user's prompts.

    Returns ``{"success": True}`` once the row is removed.

    Raises:
        HTTPException(404): the prompt does not exist or belongs to another user.
    """
    # Ownership is part of the lookup, so other users' prompts read as missing.
    lookup = select(Prompt).where(
        Prompt.id == prompt_id,
        Prompt.user_id == user.id,
    )
    target = (await session.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Prompt not found")

    await session.delete(target)
    await session.commit()
    return {"success": True}
@router.get("/prompts/public", response_model=list[PublicPromptRead])
async def list_public_prompts(
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List prompts that other users have made public, newest first.

    The current user's own public prompts are excluded. Each entry is the
    ``PromptRead`` view of the prompt plus ``author_name``, which is taken
    from the author's email (``None`` if the prompt has no related user).
    """
    # Eager-load the author relationship so ``item.user`` below does not
    # trigger a lazy load on the async session.
    stmt = (
        select(Prompt)
        .options(selectinload(Prompt.user))
        .where(Prompt.is_public.is_(True), Prompt.user_id != user.id)
        .order_by(Prompt.created_at.desc())
    )
    public_prompts = (await session.execute(stmt)).scalars().all()

    payload: list[PublicPromptRead] = []
    for item in public_prompts:
        base_fields = PromptRead.model_validate(item).model_dump()
        # NOTE(review): this exposes the author's email address to every
        # authenticated user; confirm that is intended.
        author = item.user.email if item.user else None
        payload.append(PublicPromptRead(**base_fields, author_name=author))
    return payload
@router.post("/prompts/{prompt_id}/copy", response_model=PromptRead)
async def copy_public_prompt(
    prompt_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Duplicate a public prompt into the current user's own prompts.

    The duplicate keeps the source's name, prompt text and mode, and is
    created private (``is_public=False``). No ``search_space_id`` is passed
    to the new row.

    Raises:
        HTTPException(404): the prompt does not exist or is not public.
    """
    lookup = select(Prompt).where(
        Prompt.id == prompt_id,
        Prompt.is_public.is_(True),
    )
    original = (await session.execute(lookup)).scalar_one_or_none()
    if not original:
        raise HTTPException(status_code=404, detail="Prompt not found")

    duplicate = Prompt(
        user_id=user.id,
        name=original.name,
        prompt=original.prompt,
        mode=original.mode,
        is_public=False,
    )
    session.add(duplicate)
    await session.commit()
    await session.refresh(duplicate)
    return duplicate