mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-19 18:45:15 +02:00
feat: added shared chats
This commit is contained in:
parent
764dd05582
commit
f22d649239
22 changed files with 1881 additions and 506 deletions
|
|
@ -19,12 +19,14 @@ from datetime import UTC, datetime
|
|||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
NewChatMessageRole,
|
||||
NewChatThread,
|
||||
|
|
@ -40,6 +42,7 @@ from app.schemas.new_chat import (
|
|||
NewChatThreadCreate,
|
||||
NewChatThreadRead,
|
||||
NewChatThreadUpdate,
|
||||
NewChatThreadVisibilityUpdate,
|
||||
NewChatThreadWithMessages,
|
||||
ThreadHistoryLoadResponse,
|
||||
ThreadListItem,
|
||||
|
|
@ -52,6 +55,61 @@ from app.utils.rbac import check_permission
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
async def check_thread_access(
|
||||
thread: NewChatThread,
|
||||
user: User,
|
||||
require_ownership: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has access to a thread based on visibility rules.
|
||||
|
||||
Access is granted if:
|
||||
- User is the creator of the thread
|
||||
- Thread visibility is SEARCH_SPACE (and user has permission to read chats)
|
||||
- Thread is a legacy thread (created_by_id is NULL) - visible to all
|
||||
|
||||
Args:
|
||||
thread: The thread to check access for
|
||||
user: The user requesting access
|
||||
require_ownership: If True, only the creator can access (for edit/delete operations)
|
||||
Legacy threads (NULL creator) are treated as accessible by all
|
||||
|
||||
Returns:
|
||||
True if access is granted
|
||||
|
||||
Raises:
|
||||
HTTPException: If access is denied
|
||||
"""
|
||||
is_owner = thread.created_by_id == user.id
|
||||
is_legacy = thread.created_by_id is None
|
||||
|
||||
# Legacy threads are accessible to all users in the search space
|
||||
if is_legacy:
|
||||
return True
|
||||
|
||||
# If ownership is required, only the creator can access
|
||||
if require_ownership:
|
||||
if not is_owner:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator of this chat can perform this action",
|
||||
)
|
||||
return True
|
||||
|
||||
# For read access: owner or shared threads
|
||||
if is_owner:
|
||||
return True
|
||||
|
||||
if thread.visibility == ChatVisibility.SEARCH_SPACE:
|
||||
return True
|
||||
|
||||
# Private thread and user is not the owner
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You don't have access to this private chat",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Thread Endpoints
|
||||
# =============================================================================
|
||||
|
|
@ -65,9 +123,14 @@ async def list_threads(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List all threads for the current user in a search space.
|
||||
List all accessible threads for the current user in a search space.
|
||||
Returns threads and archived_threads for ThreadListPrimitive.
|
||||
|
||||
A user can see threads that are:
|
||||
- Created by them (regardless of visibility)
|
||||
- Shared with the search space (visibility = SEARCH_SPACE)
|
||||
- Legacy threads with no creator (created_by_id is NULL)
|
||||
|
||||
Args:
|
||||
search_space_id: The search space to list threads for
|
||||
limit: Optional limit on number of threads to return (applies to active threads only)
|
||||
|
|
@ -83,10 +146,20 @@ async def list_threads(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Get all threads in this search space
|
||||
# Get threads that are either:
|
||||
# 1. Created by the current user (any visibility)
|
||||
# 2. Shared with the search space (visibility = SEARCH_SPACE)
|
||||
# 3. Legacy threads with no creator (created_by_id is NULL) - visible to all
|
||||
query = (
|
||||
select(NewChatThread)
|
||||
.filter(NewChatThread.search_space_id == search_space_id)
|
||||
.filter(
|
||||
NewChatThread.search_space_id == search_space_id,
|
||||
or_(
|
||||
NewChatThread.created_by_id == user.id,
|
||||
NewChatThread.visibility == ChatVisibility.SEARCH_SPACE,
|
||||
NewChatThread.created_by_id.is_(None), # Legacy threads
|
||||
),
|
||||
)
|
||||
.order_by(NewChatThread.updated_at.desc())
|
||||
)
|
||||
|
||||
|
|
@ -98,10 +171,17 @@ async def list_threads(
|
|||
archived_threads = []
|
||||
|
||||
for thread in all_threads:
|
||||
# Legacy threads (no creator) are treated as own threads for display purposes
|
||||
is_own_thread = (
|
||||
thread.created_by_id == user.id or thread.created_by_id is None
|
||||
)
|
||||
item = ThreadListItem(
|
||||
id=thread.id,
|
||||
title=thread.title,
|
||||
archived=thread.archived,
|
||||
visibility=thread.visibility,
|
||||
created_by_id=thread.created_by_id,
|
||||
is_own_thread=is_own_thread,
|
||||
created_at=thread.created_at,
|
||||
updated_at=thread.updated_at,
|
||||
)
|
||||
|
|
@ -137,7 +217,12 @@ async def search_threads(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Search threads by title in a search space.
|
||||
Search accessible threads by title in a search space.
|
||||
|
||||
A user can search threads that are:
|
||||
- Created by them (regardless of visibility)
|
||||
- Shared with the search space (visibility = SEARCH_SPACE)
|
||||
- Legacy threads with no creator (created_by_id is NULL)
|
||||
|
||||
Args:
|
||||
search_space_id: The search space to search in
|
||||
|
|
@ -154,12 +239,17 @@ async def search_threads(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Search threads by title (case-insensitive)
|
||||
# Search accessible threads by title (case-insensitive)
|
||||
query = (
|
||||
select(NewChatThread)
|
||||
.filter(
|
||||
NewChatThread.search_space_id == search_space_id,
|
||||
NewChatThread.title.ilike(f"%{title}%"),
|
||||
or_(
|
||||
NewChatThread.created_by_id == user.id,
|
||||
NewChatThread.visibility == ChatVisibility.SEARCH_SPACE,
|
||||
NewChatThread.created_by_id.is_(None), # Legacy threads
|
||||
),
|
||||
)
|
||||
.order_by(NewChatThread.updated_at.desc())
|
||||
)
|
||||
|
|
@ -172,6 +262,12 @@ async def search_threads(
|
|||
id=thread.id,
|
||||
title=thread.title,
|
||||
archived=thread.archived,
|
||||
visibility=thread.visibility,
|
||||
created_by_id=thread.created_by_id,
|
||||
# Legacy threads (no creator) are treated as own threads
|
||||
is_own_thread=(
|
||||
thread.created_by_id == user.id or thread.created_by_id is None
|
||||
),
|
||||
created_at=thread.created_at,
|
||||
updated_at=thread.updated_at,
|
||||
)
|
||||
|
|
@ -200,6 +296,9 @@ async def create_thread(
|
|||
"""
|
||||
Create a new chat thread.
|
||||
|
||||
The thread is created with the specified visibility (defaults to PRIVATE).
|
||||
The current user is recorded as the creator of the thread.
|
||||
|
||||
Requires CHATS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -215,7 +314,9 @@ async def create_thread(
|
|||
db_thread = NewChatThread(
|
||||
title=thread.title,
|
||||
archived=thread.archived,
|
||||
visibility=thread.visibility,
|
||||
search_space_id=thread.search_space_id,
|
||||
created_by_id=user.id,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(db_thread)
|
||||
|
|
@ -254,6 +355,10 @@ async def get_thread_messages(
|
|||
Get a thread with all its messages.
|
||||
This is used by ThreadHistoryAdapter.load() to restore conversation.
|
||||
|
||||
Access is granted if:
|
||||
- User is the creator of the thread
|
||||
- Thread visibility is SEARCH_SPACE
|
||||
|
||||
Requires CHATS_READ permission.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -268,7 +373,7 @@ async def get_thread_messages(
|
|||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
# Check permission and ownership
|
||||
# Check permission to read chats in this search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
|
|
@ -277,6 +382,9 @@ async def get_thread_messages(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(thread, user)
|
||||
|
||||
# Return messages in the format expected by assistant-ui
|
||||
messages = [
|
||||
NewChatMessageRead(
|
||||
|
|
@ -313,6 +421,10 @@ async def get_thread_full(
|
|||
"""
|
||||
Get full thread details with all messages.
|
||||
|
||||
Access is granted if:
|
||||
- User is the creator of the thread
|
||||
- Thread visibility is SEARCH_SPACE
|
||||
|
||||
Requires CHATS_READ permission.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -334,6 +446,9 @@ async def get_thread_full(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(thread, user)
|
||||
|
||||
return thread
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -360,6 +475,9 @@ async def update_thread(
|
|||
Update a thread (title, archived status).
|
||||
Used for renaming and archiving threads.
|
||||
|
||||
- PRIVATE threads: Only the creator can update
|
||||
- SEARCH_SPACE threads: Any member with CHATS_UPDATE permission can update
|
||||
|
||||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -379,6 +497,11 @@ async def update_thread(
|
|||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
# For PRIVATE threads, only the creator can update
|
||||
# For SEARCH_SPACE threads, any member with permission can update
|
||||
if db_thread.visibility == ChatVisibility.PRIVATE:
|
||||
await check_thread_access(db_thread, user, require_ownership=True)
|
||||
|
||||
# Update fields
|
||||
update_data = thread_update.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
|
|
@ -420,6 +543,9 @@ async def delete_thread(
|
|||
"""
|
||||
Delete a thread and all its messages.
|
||||
|
||||
- PRIVATE threads: Only the creator can delete
|
||||
- SEARCH_SPACE threads: Any member with CHATS_DELETE permission can delete
|
||||
|
||||
Requires CHATS_DELETE permission.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -439,6 +565,11 @@ async def delete_thread(
|
|||
"You don't have permission to delete chats in this search space",
|
||||
)
|
||||
|
||||
# For PRIVATE threads, only the creator can delete
|
||||
# For SEARCH_SPACE threads, any member with permission can delete
|
||||
if db_thread.visibility == ChatVisibility.PRIVATE:
|
||||
await check_thread_access(db_thread, user, require_ownership=True)
|
||||
|
||||
await session.delete(db_thread)
|
||||
await session.commit()
|
||||
return {"message": "Thread deleted successfully"}
|
||||
|
|
@ -463,6 +594,71 @@ async def delete_thread(
|
|||
) from None
|
||||
|
||||
|
||||
@router.patch("/threads/{thread_id}/visibility", response_model=NewChatThreadRead)
|
||||
async def update_thread_visibility(
|
||||
thread_id: int,
|
||||
visibility_update: NewChatThreadVisibilityUpdate,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Update the visibility/sharing settings of a thread.
|
||||
|
||||
Only the creator of the thread can change its visibility.
|
||||
- PRIVATE: Only the creator can access the thread (default)
|
||||
- SEARCH_SPACE: All members of the search space can access the thread
|
||||
|
||||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
db_thread = result.scalars().first()
|
||||
|
||||
if not db_thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_thread.search_space_id,
|
||||
Permission.CHATS_UPDATE.value,
|
||||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
# Only the creator can change visibility
|
||||
await check_thread_access(db_thread, user, require_ownership=True)
|
||||
|
||||
# Update visibility
|
||||
db_thread.visibility = visibility_update.visibility
|
||||
db_thread.updated_at = datetime.now(UTC)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_thread)
|
||||
return db_thread
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except IntegrityError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Database constraint violation. Please check your input data.",
|
||||
) from None
|
||||
except OperationalError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=503, detail="Database operation failed. Please try again later."
|
||||
) from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while updating thread visibility: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Message Endpoints
|
||||
# =============================================================================
|
||||
|
|
@ -479,6 +675,10 @@ async def append_message(
|
|||
Append a message to a thread.
|
||||
This is used by ThreadHistoryAdapter.append() to persist messages.
|
||||
|
||||
Access is granted if:
|
||||
- User is the creator of the thread
|
||||
- Thread visibility is SEARCH_SPACE
|
||||
|
||||
Requires CHATS_UPDATE permission.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -513,6 +713,9 @@ async def append_message(
|
|||
"You don't have permission to update chats in this search space",
|
||||
)
|
||||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(thread, user)
|
||||
|
||||
# Convert string role to enum
|
||||
role_str = (
|
||||
message.role.lower() if isinstance(message.role, str) else message.role
|
||||
|
|
@ -597,6 +800,10 @@ async def list_messages(
|
|||
"""
|
||||
List messages in a thread with pagination.
|
||||
|
||||
Access is granted if:
|
||||
- User is the creator of the thread
|
||||
- Thread visibility is SEARCH_SPACE
|
||||
|
||||
Requires CHATS_READ permission.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -617,6 +824,9 @@ async def list_messages(
|
|||
"You don't have permission to read chats in this search space",
|
||||
)
|
||||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(thread, user)
|
||||
|
||||
# Get messages
|
||||
query = (
|
||||
select(NewChatMessage)
|
||||
|
|
@ -659,6 +869,10 @@ async def handle_new_chat(
|
|||
This endpoint handles the new chat functionality with streaming responses
|
||||
using Server-Sent Events (SSE) format compatible with Vercel AI SDK.
|
||||
|
||||
Access is granted if:
|
||||
- User is the creator of the thread
|
||||
- Thread visibility is SEARCH_SPACE
|
||||
|
||||
Requires CHATS_CREATE permission.
|
||||
"""
|
||||
try:
|
||||
|
|
@ -679,6 +893,9 @@ async def handle_new_chat(
|
|||
"You don't have permission to chat in this search space",
|
||||
)
|
||||
|
||||
# Check thread-level access based on visibility
|
||||
await check_thread_access(thread, user)
|
||||
|
||||
# Get search space to check LLM config preferences
|
||||
search_space_result = await session.execute(
|
||||
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue