mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
Merge remote-tracking branch 'upstream/dev' into fix/documents
This commit is contained in:
commit
c132e5ddb0
49 changed files with 1625 additions and 354 deletions
|
|
@ -12,6 +12,7 @@ from app.agents.new_chat.checkpointer import (
|
|||
from app.config import config, initialize_llm_router
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.routes import router as crud_router
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.tasks.surfsense_docs_indexer import seed_surfsense_docs
|
||||
from app.users import SECRET, auth_backend, current_active_user, fastapi_users
|
||||
|
|
@ -111,6 +112,9 @@ app.include_router(
|
|||
tags=["users"],
|
||||
)
|
||||
|
||||
# Include custom auth routes (refresh token, logout)
|
||||
app.include_router(auth_router)
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
|
|
|
|||
|
|
@ -255,6 +255,14 @@ class Config:
|
|||
# OAuth JWT
|
||||
SECRET_KEY = os.getenv("SECRET_KEY")
|
||||
|
||||
# JWT Token Lifetimes
|
||||
ACCESS_TOKEN_LIFETIME_SECONDS = int(
|
||||
os.getenv("ACCESS_TOKEN_LIFETIME_SECONDS", str(24 * 60 * 60)) # 1 day
|
||||
)
|
||||
REFRESH_TOKEN_LIFETIME_SECONDS = int(
|
||||
os.getenv("REFRESH_TOKEN_LIFETIME_SECONDS", str(14 * 24 * 60 * 60)) # 2 weeks
|
||||
)
|
||||
|
||||
# ETL Service
|
||||
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
||||
|
||||
|
|
|
|||
|
|
@ -71,6 +71,14 @@ class AirtableHistoryConnector:
|
|||
|
||||
config_data = connector.config.copy()
|
||||
|
||||
# Check if access_token exists before processing
|
||||
raw_access_token = config_data.get("access_token")
|
||||
if not raw_access_token:
|
||||
raise ValueError(
|
||||
"Airtable access token not found. "
|
||||
"Please reconnect your Airtable account."
|
||||
)
|
||||
|
||||
# Decrypt credentials if they are encrypted
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
|
|
@ -98,6 +106,14 @@ class AirtableHistoryConnector:
|
|||
f"Failed to decrypt Airtable credentials: {e!s}"
|
||||
) from e
|
||||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
raise ValueError(
|
||||
"Airtable access token is invalid or empty. "
|
||||
"Please reconnect your Airtable account."
|
||||
)
|
||||
|
||||
try:
|
||||
self._credentials = AirtableAuthCredentialsBase.from_dict(config_data)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -87,6 +87,14 @@ class ConfluenceHistoryConnector:
|
|||
|
||||
if is_oauth:
|
||||
# OAuth 2.0 authentication
|
||||
# Check if access_token exists before processing
|
||||
raw_access_token = config_data.get("access_token")
|
||||
if not raw_access_token:
|
||||
raise ValueError(
|
||||
"Confluence access token not found. "
|
||||
"Please reconnect your Confluence account."
|
||||
)
|
||||
|
||||
# Decrypt credentials if they are encrypted
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
|
|
@ -118,6 +126,14 @@ class ConfluenceHistoryConnector:
|
|||
f"Failed to decrypt Confluence credentials: {e!s}"
|
||||
) from e
|
||||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
raise ValueError(
|
||||
"Confluence access token is invalid or empty. "
|
||||
"Please reconnect your Confluence account."
|
||||
)
|
||||
|
||||
try:
|
||||
self._credentials = AtlassianAuthCredentialsBase.from_dict(
|
||||
config_data
|
||||
|
|
|
|||
|
|
@ -86,6 +86,14 @@ class JiraHistoryConnector:
|
|||
|
||||
if is_oauth:
|
||||
# OAuth 2.0 authentication
|
||||
# Check if access_token exists before processing
|
||||
raw_access_token = config_data.get("access_token")
|
||||
if not raw_access_token:
|
||||
raise ValueError(
|
||||
"Jira access token not found. "
|
||||
"Please reconnect your Jira account."
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError(
|
||||
"SECRET_KEY not configured but tokens are marked as encrypted"
|
||||
|
|
@ -119,6 +127,14 @@ class JiraHistoryConnector:
|
|||
f"Failed to decrypt Jira credentials: {e!s}"
|
||||
) from e
|
||||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
raise ValueError(
|
||||
"Jira access token is invalid or empty. "
|
||||
"Please reconnect your Jira account."
|
||||
)
|
||||
|
||||
try:
|
||||
self._credentials = AtlassianAuthCredentialsBase.from_dict(
|
||||
config_data
|
||||
|
|
|
|||
|
|
@ -116,6 +116,14 @@ class LinearConnector:
|
|||
|
||||
config_data = connector.config.copy()
|
||||
|
||||
# Check if access_token exists before processing
|
||||
raw_access_token = config_data.get("access_token")
|
||||
if not raw_access_token:
|
||||
raise ValueError(
|
||||
"Linear access token not found. "
|
||||
"Please reconnect your Linear account."
|
||||
)
|
||||
|
||||
# Decrypt credentials if they are encrypted
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
|
|
@ -143,6 +151,14 @@ class LinearConnector:
|
|||
f"Failed to decrypt Linear credentials: {e!s}"
|
||||
) from e
|
||||
|
||||
# Final validation after decryption
|
||||
final_token = config_data.get("access_token")
|
||||
if not final_token or (isinstance(final_token, str) and not final_token.strip()):
|
||||
raise ValueError(
|
||||
"Linear access token is invalid or empty. "
|
||||
"Please reconnect your Linear account."
|
||||
)
|
||||
|
||||
try:
|
||||
self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1449,6 +1449,13 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
display_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
||||
# Refresh tokens for this user
|
||||
refresh_tokens = relationship(
|
||||
"RefreshToken",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
|
@ -1514,6 +1521,43 @@ else:
|
|||
display_name = Column(String, nullable=True)
|
||||
avatar_url = Column(String, nullable=True)
|
||||
|
||||
# Refresh tokens for this user
|
||||
refresh_tokens = relationship(
|
||||
"RefreshToken",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(Base, TimestampMixin):
|
||||
"""
|
||||
Stores refresh tokens for user session management.
|
||||
Each row represents one device/session.
|
||||
"""
|
||||
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
user = relationship("User", back_populates="refresh_tokens")
|
||||
token_hash = Column(String(256), unique=True, nullable=False, index=True)
|
||||
expires_at = Column(TIMESTAMP(timezone=True), nullable=False, index=True)
|
||||
is_revoked = Column(Boolean, default=False, nullable=False)
|
||||
family_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.now(UTC) >= self.expires_at
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return not self.is_expired and not self.is_revoked
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
|
|
|||
|
|
@ -104,3 +104,33 @@ SUMMARY_PROMPT = (
|
|||
SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
|
||||
input_variables=["document"], template=SUMMARY_PROMPT
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Chat Title Generation Prompt
|
||||
# =============================================================================
|
||||
|
||||
TITLE_GENERATION_PROMPT = """Generate a concise, descriptive title for the following conversation.
|
||||
|
||||
<rules>
|
||||
- The title MUST be between 1 and 6 words
|
||||
- The title MUST be on a single line
|
||||
- Capture the main topic or intent of the conversation
|
||||
- Do NOT use quotes, punctuation, or formatting
|
||||
- Do NOT include words like "Chat about" or "Discussion of"
|
||||
- Return ONLY the title, nothing else
|
||||
</rules>
|
||||
|
||||
<user_query>
|
||||
{user_query}
|
||||
</user_query>
|
||||
|
||||
<assistant_response>
|
||||
{assistant_response}
|
||||
</assistant_response>
|
||||
|
||||
Title:"""
|
||||
|
||||
TITLE_GENERATION_PROMPT_TEMPLATE = PromptTemplate(
|
||||
input_variables=["user_query", "assistant_response"],
|
||||
template=TITLE_GENERATION_PROMPT,
|
||||
)
|
||||
|
|
|
|||
93
surfsense_backend/app/routes/auth_routes.py
Normal file
93
surfsense_backend/app/routes/auth_routes.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Authentication routes for refresh token management."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import User, async_session_maker
|
||||
from app.schemas.auth import (
|
||||
LogoutAllResponse,
|
||||
LogoutRequest,
|
||||
LogoutResponse,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
)
|
||||
from app.users import current_active_user, get_jwt_strategy
|
||||
from app.utils.refresh_tokens import (
|
||||
revoke_all_user_tokens,
|
||||
revoke_refresh_token,
|
||||
rotate_refresh_token,
|
||||
validate_refresh_token,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/auth/jwt", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
async def refresh_access_token(request: RefreshTokenRequest):
|
||||
"""
|
||||
Exchange a valid refresh token for a new access token and refresh token.
|
||||
Implements token rotation for security.
|
||||
"""
|
||||
token_record = await validate_refresh_token(request.refresh_token)
|
||||
|
||||
if not token_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
)
|
||||
|
||||
# Get user from token record
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == token_record.user_id)
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
# Generate new access token
|
||||
strategy = get_jwt_strategy()
|
||||
access_token = await strategy.write_token(user)
|
||||
|
||||
# Rotate refresh token
|
||||
new_refresh_token = await rotate_refresh_token(token_record)
|
||||
|
||||
logger.info(f"Refreshed token for user {user.id}")
|
||||
|
||||
return RefreshTokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/revoke", response_model=LogoutResponse)
|
||||
async def revoke_token(request: LogoutRequest):
|
||||
"""
|
||||
Logout current device by revoking the provided refresh token.
|
||||
Does not require authentication - just the refresh token.
|
||||
"""
|
||||
revoked = await revoke_refresh_token(request.refresh_token)
|
||||
if revoked:
|
||||
logger.info("User logged out from current device - token revoked")
|
||||
else:
|
||||
logger.warning("Logout called but no matching token found to revoke")
|
||||
return LogoutResponse()
|
||||
|
||||
|
||||
@router.post("/logout-all", response_model=LogoutAllResponse)
|
||||
async def logout_all_devices(user: User = Depends(current_active_user)):
|
||||
"""
|
||||
Logout from all devices by revoking all refresh tokens for the user.
|
||||
Requires valid access token.
|
||||
"""
|
||||
await revoke_all_user_tokens(user.id)
|
||||
logger.info(f"User {user.id} logged out from all devices")
|
||||
return LogoutAllResponse()
|
||||
|
|
@ -886,30 +886,8 @@ async def append_message(
|
|||
# Update thread's updated_at timestamp
|
||||
thread.updated_at = datetime.now(UTC)
|
||||
|
||||
# Auto-generate title from first user message if title is still default
|
||||
if thread.title == "New Chat" and role_str == "user":
|
||||
# Extract text content for title
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
title_text = content
|
||||
elif isinstance(content, list):
|
||||
# Find first text content
|
||||
title_text = ""
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
title_text = part.get("text", "")
|
||||
break
|
||||
elif isinstance(part, str):
|
||||
title_text = part
|
||||
break
|
||||
else:
|
||||
title_text = str(content)
|
||||
|
||||
# Truncate title
|
||||
if title_text:
|
||||
thread.title = title_text[:100] + (
|
||||
"..." if len(title_text) > 100 else ""
|
||||
)
|
||||
# Note: Title generation now happens in stream_new_chat.py after the first response
|
||||
# using LLM to generate a descriptive title (with truncation as fallback)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(db_message)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,10 @@
|
|||
from .auth import (
|
||||
LogoutAllResponse,
|
||||
LogoutRequest,
|
||||
LogoutResponse,
|
||||
RefreshTokenRequest,
|
||||
RefreshTokenResponse,
|
||||
)
|
||||
from .base import IDModel, TimestampModel
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate
|
||||
from .documents import (
|
||||
|
|
@ -119,6 +126,10 @@ __all__ = [
|
|||
"LogFilter",
|
||||
"LogRead",
|
||||
"LogUpdate",
|
||||
# Auth schemas
|
||||
"LogoutAllResponse",
|
||||
"LogoutRequest",
|
||||
"LogoutResponse",
|
||||
# Search source connector schemas
|
||||
"MCPConnectorCreate",
|
||||
"MCPConnectorRead",
|
||||
|
|
@ -148,6 +159,8 @@ __all__ = [
|
|||
"PodcastCreate",
|
||||
"PodcastRead",
|
||||
"PodcastUpdate",
|
||||
"RefreshTokenRequest",
|
||||
"RefreshTokenResponse",
|
||||
"RoleCreate",
|
||||
"RoleRead",
|
||||
"RoleUpdate",
|
||||
|
|
|
|||
35
surfsense_backend/app/schemas/auth.py
Normal file
35
surfsense_backend/app/schemas/auth.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""Authentication schemas for refresh token endpoints."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""Request body for token refresh endpoint."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class RefreshTokenResponse(BaseModel):
|
||||
"""Response from token refresh endpoint."""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class LogoutRequest(BaseModel):
|
||||
"""Request body for logout endpoint (current device)."""
|
||||
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class LogoutResponse(BaseModel):
|
||||
"""Response from logout endpoint (current device)."""
|
||||
|
||||
detail: str = "Successfully logged out"
|
||||
|
||||
|
||||
class LogoutAllResponse(BaseModel):
|
||||
"""Response from logout all devices endpoint."""
|
||||
|
||||
detail: str = "Successfully logged out from all devices"
|
||||
|
|
@ -5,7 +5,7 @@ Service layer for chat comments and mentions.
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import delete, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
|
@ -103,6 +103,37 @@ async def process_mentions(
|
|||
return mentions_map
|
||||
|
||||
|
||||
async def get_comment_thread_participants(
|
||||
session: AsyncSession,
|
||||
parent_comment_id: int,
|
||||
exclude_user_ids: set[UUID],
|
||||
) -> list[UUID]:
|
||||
"""
|
||||
Get all unique authors in a comment thread (parent + replies), excluding specified users.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
parent_comment_id: ID of the parent comment
|
||||
exclude_user_ids: Set of user IDs to exclude (e.g., replier, mentioned users)
|
||||
|
||||
Returns:
|
||||
List of user UUIDs who have participated in the thread
|
||||
"""
|
||||
query = select(ChatComment.author_id).where(
|
||||
or_(
|
||||
ChatComment.id == parent_comment_id,
|
||||
ChatComment.parent_id == parent_comment_id,
|
||||
),
|
||||
ChatComment.author_id.isnot(None),
|
||||
)
|
||||
|
||||
if exclude_user_ids:
|
||||
query = query.where(ChatComment.author_id.notin_(list(exclude_user_ids)))
|
||||
|
||||
result = await session.execute(query.distinct())
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
|
||||
async def get_comments_for_message(
|
||||
session: AsyncSession,
|
||||
message_id: int,
|
||||
|
|
@ -436,6 +467,31 @@ async def create_reply(
|
|||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
# Notify thread participants (excluding replier and mentioned users)
|
||||
mentioned_user_ids = set(mentions_map.keys())
|
||||
exclude_ids = {user.id} | mentioned_user_ids
|
||||
participants = await get_comment_thread_participants(
|
||||
session, comment_id, exclude_ids
|
||||
)
|
||||
for participant_id in participants:
|
||||
if participant_id in mentioned_user_ids:
|
||||
continue
|
||||
await NotificationService.comment_reply.notify_comment_reply(
|
||||
session=session,
|
||||
user_id=participant_id,
|
||||
reply_id=reply.id,
|
||||
parent_comment_id=comment_id,
|
||||
message_id=parent_comment.message_id,
|
||||
thread_id=thread.id,
|
||||
thread_title=thread.title or "Untitled thread",
|
||||
author_id=str(user.id),
|
||||
author_name=author_name,
|
||||
author_avatar_url=user.avatar_url,
|
||||
author_email=user.email,
|
||||
content_preview=content_preview[:200],
|
||||
search_space_id=search_space_id,
|
||||
)
|
||||
|
||||
author = AuthorResponse(
|
||||
id=user.id,
|
||||
display_name=user.display_name,
|
||||
|
|
|
|||
|
|
@ -479,6 +479,31 @@ class VercelStreamingService:
|
|||
},
|
||||
)
|
||||
|
||||
def format_thread_title_update(self, thread_id: int, title: str) -> str:
|
||||
"""
|
||||
Format a thread title update notification (SurfSense specific).
|
||||
|
||||
This is sent after the first response in a thread to update the
|
||||
auto-generated title based on the conversation content.
|
||||
|
||||
Args:
|
||||
thread_id: The ID of the thread being updated
|
||||
title: The new title for the thread
|
||||
|
||||
Returns:
|
||||
str: SSE formatted thread title update data part
|
||||
|
||||
Example output:
|
||||
data: {"type":"data-thread-title-update","data":{"threadId":123,"title":"New Title"}}
|
||||
"""
|
||||
return self.format_data(
|
||||
"thread-title-update",
|
||||
{
|
||||
"threadId": thread_id,
|
||||
"title": title,
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Error Part
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -861,6 +861,98 @@ class MentionNotificationHandler(BaseNotificationHandler):
|
|||
raise
|
||||
|
||||
|
||||
class CommentReplyNotificationHandler(BaseNotificationHandler):
|
||||
"""Handler for comment reply notifications."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("comment_reply")
|
||||
|
||||
async def find_notification_by_reply(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
reply_id: int,
|
||||
user_id: UUID,
|
||||
) -> Notification | None:
|
||||
query = select(Notification).where(
|
||||
Notification.type == self.notification_type,
|
||||
Notification.user_id == user_id,
|
||||
Notification.notification_metadata["reply_id"].astext == str(reply_id),
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def notify_comment_reply(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
reply_id: int,
|
||||
parent_comment_id: int,
|
||||
message_id: int,
|
||||
thread_id: int,
|
||||
thread_title: str,
|
||||
author_id: str,
|
||||
author_name: str,
|
||||
author_avatar_url: str | None,
|
||||
author_email: str,
|
||||
content_preview: str,
|
||||
search_space_id: int,
|
||||
) -> Notification:
|
||||
existing = await self.find_notification_by_reply(session, reply_id, user_id)
|
||||
if existing:
|
||||
logger.info(
|
||||
f"Notification already exists for reply {reply_id} to user {user_id}"
|
||||
)
|
||||
return existing
|
||||
|
||||
title = f"{author_name} replied in a thread"
|
||||
message = content_preview[:100] + ("..." if len(content_preview) > 100 else "")
|
||||
|
||||
metadata = {
|
||||
"reply_id": reply_id,
|
||||
"parent_comment_id": parent_comment_id,
|
||||
"message_id": message_id,
|
||||
"thread_id": thread_id,
|
||||
"thread_title": thread_title,
|
||||
"author_id": author_id,
|
||||
"author_name": author_name,
|
||||
"author_avatar_url": author_avatar_url,
|
||||
"author_email": author_email,
|
||||
"content_preview": content_preview[:200],
|
||||
}
|
||||
|
||||
try:
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
type=self.notification_type,
|
||||
title=title,
|
||||
message=message,
|
||||
notification_metadata=metadata,
|
||||
)
|
||||
session.add(notification)
|
||||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
logger.info(
|
||||
f"Created comment_reply notification {notification.id} for user {user_id}"
|
||||
)
|
||||
return notification
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
if (
|
||||
"duplicate key" in str(e).lower()
|
||||
or "unique constraint" in str(e).lower()
|
||||
):
|
||||
logger.warning(
|
||||
f"Duplicate notification for reply {reply_id} to user {user_id}"
|
||||
)
|
||||
existing = await self.find_notification_by_reply(
|
||||
session, reply_id, user_id
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
raise
|
||||
|
||||
|
||||
class PageLimitNotificationHandler(BaseNotificationHandler):
|
||||
"""Handler for page limit exceeded notifications."""
|
||||
|
||||
|
|
@ -959,6 +1051,7 @@ class NotificationService:
|
|||
connector_indexing = ConnectorIndexingNotificationHandler()
|
||||
document_processing = DocumentProcessingNotificationHandler()
|
||||
mention = MentionNotificationHandler()
|
||||
comment_reply = CommentReplyNotificationHandler()
|
||||
page_limit = PageLimitNotificationHandler()
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -366,11 +366,14 @@ async def list_snapshots_for_thread(
|
|||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
if thread.created_by_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator can view snapshots",
|
||||
)
|
||||
# Check permission to view public share links
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
thread.search_space_id,
|
||||
Permission.PUBLIC_SHARING_VIEW.value,
|
||||
"You don't have permission to view public share links",
|
||||
)
|
||||
|
||||
result = await session.execute(
|
||||
select(PublicChatSnapshot)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from app.services.chat_session_state_service import (
|
|||
clear_ai_responding,
|
||||
set_ai_responding,
|
||||
)
|
||||
from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.services.new_streaming_service import VercelStreamingService
|
||||
from app.utils.content_utils import bootstrap_history_from_db
|
||||
|
|
@ -1208,6 +1209,59 @@ async def stream_new_chat(
|
|||
if completion_event:
|
||||
yield completion_event
|
||||
|
||||
# Generate LLM title for new chats after first response
|
||||
# Check if this is the first assistant response by counting existing assistant messages
|
||||
from app.db import NewChatMessage, NewChatThread
|
||||
from sqlalchemy import func
|
||||
|
||||
assistant_count_result = await session.execute(
|
||||
select(func.count(NewChatMessage.id)).filter(
|
||||
NewChatMessage.thread_id == chat_id,
|
||||
NewChatMessage.role == "assistant",
|
||||
)
|
||||
)
|
||||
assistant_message_count = assistant_count_result.scalar() or 0
|
||||
|
||||
# Only generate title on the first response (no prior assistant messages)
|
||||
if assistant_message_count == 0:
|
||||
generated_title = None
|
||||
try:
|
||||
# Generate title using the same LLM
|
||||
title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
|
||||
# Truncate inputs to avoid context length issues
|
||||
truncated_query = user_query[:500]
|
||||
truncated_response = accumulated_text[:1000]
|
||||
title_result = await title_chain.ainvoke({
|
||||
"user_query": truncated_query,
|
||||
"assistant_response": truncated_response,
|
||||
})
|
||||
|
||||
# Extract and clean the title
|
||||
if title_result and hasattr(title_result, "content"):
|
||||
raw_title = title_result.content.strip()
|
||||
# Validate the title (reasonable length)
|
||||
if raw_title and len(raw_title) <= 100:
|
||||
# Remove any quotes or extra formatting
|
||||
generated_title = raw_title.strip('"\'')
|
||||
except Exception:
|
||||
generated_title = None
|
||||
|
||||
# Only update if LLM succeeded (keep truncated prompt title as fallback)
|
||||
if generated_title:
|
||||
# Fetch thread and update title
|
||||
thread_result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == chat_id)
|
||||
)
|
||||
thread = thread_result.scalars().first()
|
||||
if thread:
|
||||
thread.title = generated_title
|
||||
await session.commit()
|
||||
|
||||
# Notify frontend of the title update
|
||||
yield streaming_service.format_thread_title_update(
|
||||
chat_id, generated_title
|
||||
)
|
||||
|
||||
# Finish the step and message
|
||||
yield streaming_service.format_finish_step()
|
||||
yield streaming_service.format_finish()
|
||||
|
|
|
|||
|
|
@ -23,17 +23,20 @@ from app.db import (
|
|||
get_default_roles_config,
|
||||
get_user_db,
|
||||
)
|
||||
from app.utils.refresh_tokens import create_refresh_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BearerResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
SECRET = config.SECRET_KEY
|
||||
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
|
|
@ -183,7 +186,10 @@ async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db
|
|||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
||||
return JWTStrategy(secret=SECRET, lifetime_seconds=3600 * 24)
|
||||
return JWTStrategy(
|
||||
secret=SECRET,
|
||||
lifetime_seconds=config.ACCESS_TOKEN_LIFETIME_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
|
||||
|
|
@ -209,9 +215,30 @@ def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
|
|||
# BEARER AUTH CODE.
|
||||
class CustomBearerTransport(BearerTransport):
|
||||
async def get_login_response(self, token: str) -> Response:
|
||||
bearer_response = BearerResponse(access_token=token, token_type="bearer")
|
||||
redirect_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={bearer_response.access_token}"
|
||||
import jwt
|
||||
|
||||
# Decode JWT to get user_id for refresh token creation
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET, algorithms=["HS256"], options={"verify_aud": False})
|
||||
user_id = uuid.UUID(payload.get("sub"))
|
||||
refresh_token = await create_refresh_token(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create refresh token: {e}")
|
||||
# Fall back to response without refresh token
|
||||
refresh_token = ""
|
||||
|
||||
bearer_response = BearerResponse(
|
||||
access_token=token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="bearer",
|
||||
)
|
||||
|
||||
if config.AUTH_TYPE == "GOOGLE":
|
||||
redirect_url = (
|
||||
f"{config.NEXT_FRONTEND_URL}/auth/callback"
|
||||
f"?token={bearer_response.access_token}"
|
||||
f"&refresh_token={bearer_response.refresh_token}"
|
||||
)
|
||||
return RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
return JSONResponse(bearer_response.model_dump())
|
||||
|
|
|
|||
153
surfsense_backend/app/utils/refresh_tokens.py
Normal file
153
surfsense_backend/app/utils/refresh_tokens.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
"""Utilities for managing refresh tokens."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.config import config
|
||||
from app.db import RefreshToken, async_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_refresh_token() -> str:
|
||||
"""Generate a cryptographically secure refresh token."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token for secure storage."""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
async def create_refresh_token(
|
||||
user_id: uuid.UUID,
|
||||
family_id: uuid.UUID | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create and store a new refresh token for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
family_id: Optional family ID for token rotation
|
||||
|
||||
Returns:
|
||||
The plaintext refresh token
|
||||
"""
|
||||
token = generate_refresh_token()
|
||||
token_hash = hash_token(token)
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=config.REFRESH_TOKEN_LIFETIME_SECONDS
|
||||
)
|
||||
|
||||
if family_id is None:
|
||||
family_id = uuid.uuid4()
|
||||
|
||||
async with async_session_maker() as session:
|
||||
refresh_token = RefreshToken(
|
||||
user_id=user_id,
|
||||
token_hash=token_hash,
|
||||
expires_at=expires_at,
|
||||
family_id=family_id,
|
||||
)
|
||||
session.add(refresh_token)
|
||||
await session.commit()
|
||||
|
||||
return token
|
||||
|
||||
|
||||
async def validate_refresh_token(token: str) -> RefreshToken | None:
|
||||
"""
|
||||
Validate a refresh token. Handles reuse detection.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token
|
||||
|
||||
Returns:
|
||||
RefreshToken if valid, None otherwise
|
||||
"""
|
||||
token_hash = hash_token(token)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
refresh_token = result.scalars().first()
|
||||
|
||||
if not refresh_token:
|
||||
return None
|
||||
|
||||
# Reuse detection: revoked token used while family has active tokens
|
||||
if refresh_token.is_revoked:
|
||||
active = await session.execute(
|
||||
select(RefreshToken).where(
|
||||
RefreshToken.family_id == refresh_token.family_id,
|
||||
RefreshToken.is_revoked == False, # noqa: E712
|
||||
RefreshToken.expires_at > datetime.now(UTC),
|
||||
)
|
||||
)
|
||||
if active.scalars().first():
|
||||
# Revoke entire family
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.family_id == refresh_token.family_id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
logger.warning(f"Token reuse detected for user {refresh_token.user_id}")
|
||||
return None
|
||||
|
||||
if refresh_token.is_expired:
|
||||
return None
|
||||
|
||||
return refresh_token
|
||||
|
||||
|
||||
async def rotate_refresh_token(old_token: RefreshToken) -> str:
|
||||
"""Revoke old token and create new one in same family."""
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.id == old_token.id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return await create_refresh_token(old_token.user_id, old_token.family_id)
|
||||
|
||||
|
||||
async def revoke_refresh_token(token: str) -> bool:
|
||||
"""
|
||||
Revoke a single refresh token by its plaintext value.
|
||||
|
||||
Args:
|
||||
token: The plaintext refresh token
|
||||
|
||||
Returns:
|
||||
True if token was found and revoked, False otherwise
|
||||
"""
|
||||
token_hash = hash_token(token)
|
||||
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.token_hash == token_hash)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
|
||||
async def revoke_all_user_tokens(user_id: uuid.UUID) -> None:
|
||||
"""Revoke all refresh tokens for a user (logout all devices)."""
|
||||
async with async_session_maker() as session:
|
||||
await session.execute(
|
||||
update(RefreshToken)
|
||||
.where(RefreshToken.user_id == user_id)
|
||||
.values(is_revoked=True)
|
||||
)
|
||||
await session.commit()
|
||||
Loading…
Add table
Add a link
Reference in a new issue