mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-24 21:38:09 +02:00
feat: enforce API access for chat routes
This commit is contained in:
parent
493e8d5a64
commit
70a0828b95
7 changed files with 152 additions and 103 deletions
|
|
@ -28,6 +28,7 @@ from pydantic import BaseModel
|
|||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.multi_agent_chat.shared.feature_flags import get_flags
|
||||
from app.db import (
|
||||
AgentActionLog,
|
||||
|
|
@ -36,7 +37,7 @@ from app.db import (
|
|||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -111,8 +112,9 @@ async def list_thread_actions(
|
|||
page: int = Query(0, ge=0),
|
||||
page_size: int = Query(50, ge=1, le=200),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
) -> AgentActionListResponse:
|
||||
user = auth.user
|
||||
"""List agent actions for a thread, newest first.
|
||||
|
||||
Authorization:
|
||||
|
|
@ -132,7 +134,7 @@ async def list_thread_actions(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to view this thread's action log.",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Routes for chat comments and mentions.
|
|||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, get_async_session
|
||||
from app.schemas.chat_comments import (
|
||||
CommentBatchRequest,
|
||||
|
|
@ -25,7 +26,7 @@ from app.services.chat_comments_service import (
|
|||
get_user_mentions,
|
||||
update_comment,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -34,8 +35,9 @@ router = APIRouter()
|
|||
async def batch_list_comments(
|
||||
request: CommentBatchRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Batch-fetch comments for multiple messages in one request."""
|
||||
return await get_comments_for_messages_batch(session, request.message_ids, user)
|
||||
|
||||
|
|
@ -44,8 +46,9 @@ async def batch_list_comments(
|
|||
async def list_comments(
|
||||
message_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""List all comments for a message with their replies."""
|
||||
return await get_comments_for_message(session, message_id, user)
|
||||
|
||||
|
|
@ -55,8 +58,9 @@ async def add_comment(
|
|||
message_id: int,
|
||||
request: CommentCreateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Create a top-level comment on an AI response."""
|
||||
return await create_comment(session, message_id, request.content, user)
|
||||
|
||||
|
|
@ -66,8 +70,9 @@ async def add_reply(
|
|||
comment_id: int,
|
||||
request: CommentCreateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Reply to an existing comment."""
|
||||
return await create_reply(session, comment_id, request.content, user)
|
||||
|
||||
|
|
@ -77,8 +82,9 @@ async def edit_comment(
|
|||
comment_id: int,
|
||||
request: CommentUpdateRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Update a comment's content (author only)."""
|
||||
return await update_comment(session, comment_id, request.content, user)
|
||||
|
||||
|
|
@ -87,8 +93,9 @@ async def edit_comment(
|
|||
async def remove_comment(
|
||||
comment_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Delete a comment (author or user with COMMENTS_DELETE permission)."""
|
||||
return await delete_comment(session, comment_id, user)
|
||||
|
||||
|
|
@ -102,7 +109,8 @@ async def remove_comment(
|
|||
async def list_mentions(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""List mentions for the current user."""
|
||||
return await get_user_mentions(session, user, search_space_id)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from sqlalchemy import select, update
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
Connection,
|
||||
|
|
@ -14,7 +15,6 @@ from app.db import (
|
|||
NewChatThread,
|
||||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import (
|
||||
|
|
@ -42,7 +42,7 @@ from app.services.model_connection_service import (
|
|||
verify_connection,
|
||||
)
|
||||
from app.services.provider_registry import REGISTRY
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -257,8 +257,8 @@ async def _default_unset_roles(
|
|||
|
||||
|
||||
@router.get("/model-providers", response_model=list[ModelProviderRead])
|
||||
async def list_model_providers(user: User = Depends(current_active_user)):
|
||||
del user
|
||||
async def list_model_providers(auth: AuthContext = Depends(get_auth_context)):
|
||||
del auth
|
||||
local_only = {"ollama_chat", "lm_studio"}
|
||||
return [
|
||||
ModelProviderRead(
|
||||
|
|
@ -298,14 +298,15 @@ async def _load_connection(session: AsyncSession, connection_id: int) -> Connect
|
|||
|
||||
async def _assert_connection_access(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
conn: Connection,
|
||||
permission: str = Permission.LLM_CONFIGS_CREATE.value,
|
||||
) -> None:
|
||||
user = auth.user
|
||||
if conn.search_space_id:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
conn.search_space_id,
|
||||
permission,
|
||||
"You don't have permission to manage model connections in this search space",
|
||||
|
|
@ -318,14 +319,14 @@ async def _assert_connection_access(
|
|||
|
||||
|
||||
@router.get("/global-llm-config-status")
|
||||
async def global_llm_config_status(user: User = Depends(current_active_user)):
|
||||
del user
|
||||
async def global_llm_config_status(auth: AuthContext = Depends(get_auth_context)):
|
||||
del auth
|
||||
return {"exists": config.GLOBAL_LLM_CONFIG_FILE_EXISTS}
|
||||
|
||||
|
||||
@router.get("/global-model-connections", response_model=list[ConnectionRead])
|
||||
async def list_global_connections(user: User = Depends(current_active_user)):
|
||||
del user
|
||||
async def list_global_connections(auth: AuthContext = Depends(get_auth_context)):
|
||||
del auth
|
||||
models_by_connection: dict[int, list[dict]] = {}
|
||||
for model in config.GLOBAL_MODELS:
|
||||
models_by_connection.setdefault(model["connection_id"], []).append(model)
|
||||
|
|
@ -339,13 +340,14 @@ async def list_global_connections(user: User = Depends(current_active_user)):
|
|||
async def list_connections(
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
stmt = select(Connection).options(selectinload(Connection.models))
|
||||
if search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to view model connections in this search space",
|
||||
|
|
@ -363,8 +365,9 @@ async def list_connections(
|
|||
async def create_connection(
|
||||
data: ConnectionCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
if data.scope == ConnectionScope.GLOBAL:
|
||||
raise HTTPException(status_code=400, detail="GLOBAL connections are YAML-only")
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE:
|
||||
|
|
@ -372,7 +375,7 @@ async def create_connection(
|
|||
raise HTTPException(status_code=400, detail="search_space_id is required")
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
|
|
@ -411,12 +414,13 @@ async def create_connection(
|
|||
async def preview_connection_models(
|
||||
data: ConnectionCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
|
|
@ -445,12 +449,13 @@ async def preview_connection_models(
|
|||
async def test_preview_connection_model(
|
||||
data: ModelTestPreview,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
if data.scope == ConnectionScope.SEARCH_SPACE and data.search_space_id is not None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
data.search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to create model connections in this search space",
|
||||
|
|
@ -491,11 +496,11 @@ async def update_connection(
|
|||
connection_id: int,
|
||||
data: ConnectionUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
for key, value in data.model_dump(exclude_unset=True).items():
|
||||
|
|
@ -512,11 +517,11 @@ async def update_connection(
|
|||
async def delete_connection(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_DELETE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_DELETE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
await session.delete(conn)
|
||||
|
|
@ -533,11 +538,11 @@ async def delete_connection(
|
|||
async def verify_model_connection(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
)
|
||||
result = await verify_connection(conn)
|
||||
return VerifyConnectionResponse(
|
||||
|
|
@ -551,11 +556,11 @@ async def verify_model_connection(
|
|||
async def discover_connection_models(
|
||||
connection_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_CREATE.value
|
||||
)
|
||||
try:
|
||||
discovered = await discover_models(conn)
|
||||
|
|
@ -595,11 +600,11 @@ async def add_manual_model(
|
|||
connection_id: int,
|
||||
data: ModelCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
|
||||
model_id = data.model_id.strip()
|
||||
|
|
@ -640,11 +645,11 @@ async def bulk_update_models(
|
|||
connection_id: int,
|
||||
data: ModelsBulkUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
conn = await _load_connection(session, connection_id)
|
||||
await _assert_connection_access(
|
||||
session, user, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, conn, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = conn.search_space_id
|
||||
|
||||
|
|
@ -674,7 +679,7 @@ async def update_model(
|
|||
model_id: int,
|
||||
data: ModelUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
|
|
@ -685,7 +690,7 @@ async def update_model(
|
|||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(
|
||||
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
search_space_id = model.connection.search_space_id
|
||||
update = data.model_dump(exclude_unset=True)
|
||||
|
|
@ -704,7 +709,7 @@ async def update_model(
|
|||
async def test_connection_model(
|
||||
model_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(Model)
|
||||
|
|
@ -715,7 +720,7 @@ async def test_connection_model(
|
|||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
await _assert_connection_access(
|
||||
session, user, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
session, auth, model.connection, Permission.LLM_CONFIGS_UPDATE.value
|
||||
)
|
||||
result = await test_model(model.connection, model)
|
||||
await session.commit()
|
||||
|
|
@ -730,11 +735,11 @@ async def test_connection_model(
|
|||
async def get_model_roles(
|
||||
search_space_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_CREATE.value,
|
||||
"You don't have permission to view model roles in this search space",
|
||||
|
|
@ -756,11 +761,11 @@ async def update_model_roles(
|
|||
search_space_id: int,
|
||||
data: ModelRolesUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.LLM_CONFIGS_UPDATE.value,
|
||||
"You don't have permission to update model roles in this search space",
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.agents.chat.multi_agent_chat.main_agent.middleware.busy_mutex import (
|
||||
get_cancel_state,
|
||||
is_cancel_requested,
|
||||
|
|
@ -75,7 +76,7 @@ from app.tasks.chat.streaming.flows import (
|
|||
stream_new_chat,
|
||||
stream_resume_chat,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
from app.utils.perf import get_perf_logger
|
||||
from app.utils.rbac import check_permission
|
||||
from app.utils.user_message_multimodal import (
|
||||
|
|
@ -595,8 +596,9 @@ async def list_threads(
|
|||
search_space_id: int,
|
||||
limit: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all accessible threads for the current user in a search space.
|
||||
Returns threads and archived_threads for ThreadListPrimitive.
|
||||
|
|
@ -615,7 +617,7 @@ async def list_threads(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -702,8 +704,9 @@ async def search_threads(
|
|||
search_space_id: int,
|
||||
title: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Search accessible threads by title in a search space.
|
||||
|
||||
|
|
@ -721,7 +724,7 @@ async def search_threads(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -794,8 +797,9 @@ async def search_threads(
|
|||
async def create_thread(
|
||||
thread: NewChatThreadCreate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a new chat thread.
|
||||
|
||||
|
|
@ -807,7 +811,7 @@ async def create_thread(
|
|||
try:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to create chats in this search space",
|
||||
|
|
@ -852,8 +856,9 @@ async def create_thread(
|
|||
async def get_thread_messages(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get a thread with all its messages.
|
||||
This is used by ThreadHistoryAdapter.load() to restore conversation.
|
||||
|
|
@ -877,7 +882,7 @@ async def get_thread_messages(
|
|||
# Check permission to read chats in this search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -936,8 +941,9 @@ async def get_thread_messages(
|
|||
async def get_thread_full(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Get full thread details with all messages.
|
||||
|
||||
|
|
@ -964,7 +970,7 @@ async def get_thread_full(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -1005,8 +1011,9 @@ async def update_thread(
|
|||
thread_id: int,
|
||||
thread_update: NewChatThreadUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update a thread (title, archived status).
|
||||
Used for renaming and archiving threads.
|
||||
|
|
@ -1027,7 +1034,7 @@ async def update_thread(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -1074,8 +1081,9 @@ async def update_thread(
|
|||
async def delete_thread(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a thread and all its messages.
|
||||
|
||||
|
|
@ -1095,7 +1103,7 @@ async def delete_thread(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_DELETE.value,
|
||||
"You don't have permission to delete chats in this search space",
|
||||
|
|
@ -1146,8 +1154,9 @@ async def update_thread_visibility(
|
|||
thread_id: int,
|
||||
visibility_update: NewChatThreadVisibilityUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Update the visibility/sharing settings of a thread.
|
||||
|
||||
|
|
@ -1168,7 +1177,7 @@ async def update_thread_visibility(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -1217,8 +1226,9 @@ async def update_thread_visibility(
|
|||
async def create_thread_snapshot(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Create a public snapshot of the thread.
|
||||
|
||||
|
|
@ -1239,8 +1249,9 @@ async def create_thread_snapshot(
|
|||
async def list_thread_snapshots(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List all public snapshots for this thread.
|
||||
|
||||
|
|
@ -1262,8 +1273,9 @@ async def delete_thread_snapshot(
|
|||
thread_id: int,
|
||||
snapshot_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a specific snapshot.
|
||||
|
||||
|
|
@ -1290,8 +1302,9 @@ async def append_message(
|
|||
thread_id: int,
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
.. deprecated:: 2026-05
|
||||
Replaced by the **SSE-based message ID handshake**. The streaming
|
||||
|
|
@ -1370,7 +1383,7 @@ async def append_message(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -1597,8 +1610,9 @@ async def list_messages(
|
|||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
List messages in a thread with pagination.
|
||||
|
||||
|
|
@ -1620,7 +1634,7 @@ async def list_messages(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to read chats in this search space",
|
||||
|
|
@ -1662,7 +1676,7 @@ async def list_messages(
|
|||
|
||||
@router.get("/agent/tools", response_model=list[AgentToolInfo])
|
||||
async def list_agent_tools(
|
||||
_user: User = Depends(current_active_user),
|
||||
_auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
"""Return the list of built-in agent tools with their metadata.
|
||||
|
||||
|
|
@ -1691,8 +1705,9 @@ async def handle_new_chat(
|
|||
request: NewChatRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Stream chat responses from the deep agent.
|
||||
|
||||
|
|
@ -1717,7 +1732,7 @@ async def handle_new_chat(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to chat in this search space",
|
||||
|
|
@ -1821,8 +1836,9 @@ async def cancel_active_turn(
|
|||
thread_id: int,
|
||||
response: Response,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""Signal cancellation for the currently running turn on ``thread_id``."""
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
|
|
@ -1833,7 +1849,7 @@ async def cancel_active_turn(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -1873,8 +1889,9 @@ async def cancel_active_turn(
|
|||
async def get_turn_status(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
|
|
@ -1884,7 +1901,7 @@ async def get_turn_status(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_READ.value,
|
||||
"You don't have permission to view chats in this search space",
|
||||
|
|
@ -1911,8 +1928,9 @@ async def regenerate_response(
|
|||
request: RegenerateRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Regenerate the AI response for a chat thread.
|
||||
|
||||
|
|
@ -1947,7 +1965,7 @@ async def regenerate_response(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
|
|
@ -2356,8 +2374,9 @@ async def resume_chat(
|
|||
request: ResumeRequest,
|
||||
http_request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
|
|
@ -2369,7 +2388,7 @@ async def resume_chat(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.CHATS_CREATE.value,
|
||||
"You don't have permission to chat in this search space",
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import User, get_async_session
|
||||
from app.schemas.new_chat import (
|
||||
CloneResponse,
|
||||
|
|
@ -23,7 +24,7 @@ from app.services.public_chat_service import (
|
|||
get_snapshot_report,
|
||||
get_snapshot_video_presentation,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.users import get_auth_context
|
||||
|
||||
router = APIRouter(prefix="/public", tags=["public"])
|
||||
|
||||
|
|
@ -46,8 +47,9 @@ async def read_public_chat(
|
|||
async def clone_public_chat(
|
||||
share_token: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
auth: AuthContext = Depends(get_auth_context),
|
||||
):
|
||||
user = auth.user
|
||||
"""
|
||||
Clone a public chat snapshot to the user's account.
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from sqlalchemy import delete, or_, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
ChatComment,
|
||||
ChatCommentMention,
|
||||
|
|
@ -138,8 +139,9 @@ async def get_comment_thread_participants(
|
|||
async def get_comments_for_message(
|
||||
session: AsyncSession,
|
||||
message_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentListResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Get all comments for a message with their replies.
|
||||
|
||||
|
|
@ -169,7 +171,7 @@ async def get_comments_for_message(
|
|||
# Check permission to read comments
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.COMMENTS_READ.value,
|
||||
"You don't have permission to read comments in this search space",
|
||||
|
|
@ -268,8 +270,9 @@ async def get_comments_for_message(
|
|||
async def get_comments_for_messages_batch(
|
||||
session: AsyncSession,
|
||||
message_ids: list[int],
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentBatchResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Batch-fetch comments for multiple messages in a single DB round-trip.
|
||||
|
||||
|
|
@ -295,7 +298,7 @@ async def get_comments_for_messages_batch(
|
|||
for ss_id in search_space_ids:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
ss_id,
|
||||
Permission.COMMENTS_READ.value,
|
||||
"You don't have permission to read comments in this search space",
|
||||
|
|
@ -409,8 +412,9 @@ async def create_comment(
|
|||
session: AsyncSession,
|
||||
message_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Create a top-level comment on an AI response.
|
||||
|
||||
|
|
@ -521,8 +525,9 @@ async def create_reply(
|
|||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentReplyResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Create a reply to an existing comment.
|
||||
|
||||
|
|
@ -657,8 +662,9 @@ async def update_comment(
|
|||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
content: str,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> CommentReplyResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Update a comment's content (author only).
|
||||
|
||||
|
|
@ -797,8 +803,9 @@ async def update_comment(
|
|||
async def delete_comment(
|
||||
session: AsyncSession,
|
||||
comment_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> dict:
|
||||
user = auth.user
|
||||
"""
|
||||
Delete a comment (author or user with COMMENTS_DELETE permission).
|
||||
|
||||
|
|
@ -844,9 +851,10 @@ async def delete_comment(
|
|||
|
||||
async def get_user_mentions(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
search_space_id: int | None = None,
|
||||
) -> MentionListResponse:
|
||||
user = auth.user
|
||||
"""
|
||||
Get mentions for the current user, optionally filtered by search space.
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from sqlalchemy import delete, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.auth.context import AuthContext
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
|
|
@ -163,8 +164,9 @@ def compute_content_hash(messages: list[dict]) -> str:
|
|||
async def create_snapshot(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> dict:
|
||||
user = auth.user
|
||||
"""
|
||||
Create a public snapshot of a chat thread.
|
||||
|
||||
|
|
@ -186,7 +188,7 @@ async def create_snapshot(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.PUBLIC_SHARING_CREATE.value,
|
||||
"You don't have permission to create public share links",
|
||||
|
|
@ -431,8 +433,9 @@ async def get_public_chat(
|
|||
async def list_snapshots_for_thread(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> list[dict]:
|
||||
user = auth.user
|
||||
"""List all public snapshots for a thread."""
|
||||
from app.config import config
|
||||
|
||||
|
|
@ -447,7 +450,7 @@ async def list_snapshots_for_thread(
|
|||
# Check permission to view public share links
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
thread.search_space_id,
|
||||
Permission.PUBLIC_SHARING_VIEW.value,
|
||||
"You don't have permission to view public share links",
|
||||
|
|
@ -477,14 +480,15 @@ async def list_snapshots_for_thread(
|
|||
async def list_snapshots_for_search_space(
|
||||
session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> list[dict]:
|
||||
user = auth.user
|
||||
"""List all public snapshots for a search space."""
|
||||
from app.config import config
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
search_space_id,
|
||||
Permission.PUBLIC_SHARING_VIEW.value,
|
||||
"You don't have permission to view public share links",
|
||||
|
|
@ -534,8 +538,9 @@ async def delete_snapshot(
|
|||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
snapshot_id: int,
|
||||
user: User,
|
||||
auth: AuthContext,
|
||||
) -> bool:
|
||||
user = auth.user
|
||||
"""Delete a specific snapshot. Only thread owner can delete."""
|
||||
# Get snapshot with thread
|
||||
result = await session.execute(
|
||||
|
|
@ -553,7 +558,7 @@ async def delete_snapshot(
|
|||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
auth,
|
||||
snapshot.thread.search_space_id,
|
||||
Permission.PUBLIC_SHARING_DELETE.value,
|
||||
"You don't have permission to delete public share links",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue