Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-02 04:12:47 +02:00)

Merge remote-tracking branch 'upstream/dev'

Commit 4e7e8ccd7e: 141 changed files with 5771 additions and 5223 deletions
@@ -0,0 +1,39 @@
"""103_add_last_login_to_user

Revision ID: 103
Revises: 102
Create Date: 2026-03-08

Adds last_login timestamp column to the user table so we can track
when each user last authenticated. The column is nullable — existing
rows will have NULL until the user's next login.
"""

from __future__ import annotations

from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "103"
down_revision: str | None = "102"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    conn = op.get_bind()
    existing_columns = [col["name"] for col in sa.inspect(conn).get_columns("user")]

    if "last_login" not in existing_columns:
        op.add_column(
            "user",
            sa.Column("last_login", sa.TIMESTAMP(timezone=True), nullable=True),
        )


def downgrade() -> None:
    op.drop_column("user", "last_login")
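The migration only adds the column when it is missing, so re-running it is safe. Below is a minimal sketch (not part of the diff) for checking the result locally after `alembic upgrade head`; the connection URL is a placeholder for your own environment, not taken from the repository config.

# Sketch: verify that user.last_login exists after running the migration.
import sqlalchemy as sa

engine = sa.create_engine("postgresql+psycopg2://localhost:5432/surfsense")  # placeholder URL
with engine.connect() as conn:
    columns = {col["name"] for col in sa.inspect(conn).get_columns("user")}
    assert "last_login" in columns, "column missing: run `alembic upgrade head` first"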
@@ -1720,6 +1720,8 @@ if config.AUTH_TYPE == "GOOGLE":
        display_name = Column(String, nullable=True)
        avatar_url = Column(String, nullable=True)

+        last_login = Column(TIMESTAMP(timezone=True), nullable=True)
+
        # Refresh tokens for this user
        refresh_tokens = relationship(
            "RefreshToken",

@@ -1820,6 +1822,8 @@ else:
        display_name = Column(String, nullable=True)
        avatar_url = Column(String, nullable=True)

+        last_login = Column(TIMESTAMP(timezone=True), nullable=True)
+
        # Refresh tokens for this user
        refresh_tokens = relationship(
            "RefreshToken",
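With the column present on both User model variants, recent activity can be queried directly. A small sketch (not in the diff), assuming `User` and an async session come from `app.db`:

# Sketch: users who logged in within the last 7 days, via the new column.
from datetime import UTC, datetime, timedelta

from sqlalchemy import select

cutoff = datetime.now(UTC) - timedelta(days=7)
recently_active_stmt = select(User).where(
    User.last_login.isnot(None),
    User.last_login >= cutoff,
)
# rows = (await session.execute(recently_active_stmt)).scalars().all()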
@@ -109,12 +109,12 @@ SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
# Chat Title Generation Prompt
# =============================================================================

-TITLE_GENERATION_PROMPT = """Generate a concise, descriptive title for the following conversation.
+TITLE_GENERATION_PROMPT = """Generate a concise, descriptive title for the following user query.

<rules>
- The title MUST be between 1 and 6 words
- The title MUST be on a single line
-- Capture the main topic or intent of the conversation
+- Capture the main topic or intent of the query
- Do NOT use quotes, punctuation, or formatting
- Do NOT include words like "Chat about" or "Discussion of"
- Return ONLY the title, nothing else

@@ -124,13 +124,9 @@ TITLE_GENERATION_PROMPT = """Generate a concise, descriptive title for the follo
{user_query}
</user_query>

-<assistant_response>
-{assistant_response}
-</assistant_response>
-
Title:"""

TITLE_GENERATION_PROMPT_TEMPLATE = PromptTemplate(
-    input_variables=["user_query", "assistant_response"],
+    input_variables=["user_query"],
    template=TITLE_GENERATION_PROMPT,
)
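Since the template now takes only the user query, a title can be rendered before the assistant response exists. A quick sketch of the rendered prompt, assuming `PromptTemplate` here is LangChain's (as used elsewhere in this module); the example query is made up.

# Sketch: render the updated title prompt with just the user's query.
rendered = TITLE_GENERATION_PROMPT_TEMPLATE.format(
    user_query="How do I connect my Notion workspace to SurfSense?"
)
# `rendered` ends with "Title:" and no longer references an assistant response.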
@@ -320,6 +320,8 @@ async def read_documents(
    page_size: int = 50,
    search_space_id: int | None = None,
    document_types: str | None = None,
+    sort_by: str = "created_at",
+    sort_order: str = "desc",
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):

@@ -392,6 +394,19 @@ async def read_documents(
    total_result = await session.execute(count_query)
    total = total_result.scalar() or 0

+    # Apply sorting
+    from sqlalchemy import asc as sa_asc, desc as sa_desc
+
+    sort_column_map = {
+        "created_at": Document.created_at,
+        "title": Document.title,
+        "document_type": Document.document_type,
+    }
+    sort_col = sort_column_map.get(sort_by, Document.created_at)
+    query = query.order_by(
+        sa_desc(sort_col) if sort_order == "desc" else sa_asc(sort_col)
+    )
+
    # Calculate offset
    offset = 0
    if skip is not None:
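The new `sort_by` and `sort_order` parameters are plain query parameters, so a client only has to append them to the documents listing request. A sketch with httpx; the base URL and router prefix are assumptions, not taken from the diff.

# Sketch: request documents sorted by title, ascending (hypothetical URL/prefix).
import httpx

resp = httpx.get(
    "http://localhost:8000/documents/",
    params={"page_size": 50, "sort_by": "title", "sort_order": "asc"},
    headers={"Authorization": "Bearer <access-token>"},
)
resp.raise_for_status()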
@@ -10,7 +10,7 @@ from typing import Literal

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
-from sqlalchemy import desc, func, select, update
+from sqlalchemy import desc, func, literal, literal_column, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Notification, User, get_async_session
@@ -23,9 +23,26 @@ SYNC_WINDOW_DAYS = 14

# Valid notification types - must match frontend InboxItemTypeEnum
NotificationType = Literal[
-    "connector_indexing", "document_processing", "new_mention", "page_limit_exceeded"
+    "connector_indexing",
+    "connector_deletion",
+    "document_processing",
+    "new_mention",
+    "comment_reply",
+    "page_limit_exceeded",
]

+# Category-to-types mapping for filtering by tab
+NotificationCategory = Literal["comments", "status"]
+CATEGORY_TYPES: dict[str, tuple[str, ...]] = {
+    "comments": ("new_mention", "comment_reply"),
+    "status": (
+        "connector_indexing",
+        "connector_deletion",
+        "document_processing",
+        "page_limit_exceeded",
+    ),
+}
+

class NotificationResponse(BaseModel):
    """Response model for a single notification."""
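The `CATEGORY_TYPES` mapping is what lets a single `category` query parameter expand into several concrete notification types. A tiny illustrative sketch:

# Sketch: a category resolves to a tuple of notification types.
types_for_status = CATEGORY_TYPES["status"]
# ("connector_indexing", "connector_deletion", "document_processing", "page_limit_exceeded")
# The endpoints below apply it as: Notification.type.in_(types_for_status)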
@@ -69,6 +86,21 @@ class MarkAllReadResponse(BaseModel):
    updated_count: int


+class SourceTypeItem(BaseModel):
+    """A single source type with its category and count."""
+
+    key: str
+    type: str
+    category: str  # "connector" or "document"
+    count: int
+
+
+class SourceTypesResponse(BaseModel):
+    """Response for notification source types used in status tab filter."""
+
+    sources: list[SourceTypeItem]
+
+
class UnreadCountResponse(BaseModel):
    """Response for unread count with split between recent and older items."""
@@ -76,12 +108,86 @@ class UnreadCountResponse(BaseModel):
    recent_unread: int  # Within SYNC_WINDOW_DAYS


+@router.get("/source-types", response_model=SourceTypesResponse)
+async def get_notification_source_types(
+    search_space_id: int | None = Query(None, description="Filter by search space ID"),
+    user: User = Depends(current_active_user),
+    session: AsyncSession = Depends(get_async_session),
+) -> SourceTypesResponse:
+    """
+    Get all distinct connector types and document types from the user's
+    status notifications. Used to populate the filter dropdown in the
+    inbox Status tab so that all types are shown regardless of pagination.
+    """
+    base_filter = [Notification.user_id == user.id]
+
+    if search_space_id is not None:
+        base_filter.append(
+            (Notification.search_space_id == search_space_id)
+            | (Notification.search_space_id.is_(None))
+        )
+
+    connector_type_expr = Notification.notification_metadata["connector_type"].astext
+    connector_query = (
+        select(
+            connector_type_expr.label("source_type"),
+            literal("connector").label("category"),
+            func.count(Notification.id).label("cnt"),
+        )
+        .where(
+            *base_filter,
+            Notification.type.in_(("connector_indexing", "connector_deletion")),
+            connector_type_expr.isnot(None),
+        )
+        .group_by(literal_column("source_type"))
+    )
+
+    document_type_expr = Notification.notification_metadata["document_type"].astext
+    document_query = (
+        select(
+            document_type_expr.label("source_type"),
+            literal("document").label("category"),
+            func.count(Notification.id).label("cnt"),
+        )
+        .where(
+            *base_filter,
+            Notification.type.in_(("document_processing",)),
+            document_type_expr.isnot(None),
+        )
+        .group_by(literal_column("source_type"))
+    )
+
+    connector_result = await session.execute(connector_query)
+    document_result = await session.execute(document_query)
+
+    sources = []
+    for source_type, category, count in [
+        *connector_result.all(),
+        *document_result.all(),
+    ]:
+        if not source_type:
+            continue
+        sources.append(
+            SourceTypeItem(
+                key=f"{category}:{source_type}",
+                type=source_type,
+                category=category,
+                count=count,
+            )
+        )
+
+    return SourceTypesResponse(sources=sources)
+
+
@router.get("/unread-count", response_model=UnreadCountResponse)
async def get_unread_count(
    search_space_id: int | None = Query(None, description="Filter by search space ID"),
    type_filter: NotificationType | None = Query(
        None, alias="type", description="Filter by notification type"
    ),
+    category: NotificationCategory | None = Query(
+        None, description="Filter by category: 'comments' or 'status'"
+    ),
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
) -> UnreadCountResponse:
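The endpoint returns one entry per distinct connector or document type, keyed as `category:type`, which the frontend can drop straight into the Status tab filter dropdown. A sketch of a plausible response body, shaped after the Pydantic models above; the router prefix, counts, and types are illustrative, not from a real deployment.

# Sketch: plausible response shape from GET .../source-types (values made up).
{
    "sources": [
        {"key": "connector:GITHUB_CONNECTOR", "type": "GITHUB_CONNECTOR", "category": "connector", "count": 12},
        {"key": "document:FILE", "type": "FILE", "category": "document", "count": 4},
    ]
}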
@@ -116,6 +222,10 @@ async def get_unread_count(
    if type_filter:
        base_filter.append(Notification.type == type_filter)

+    # Filter by category (maps to multiple types)
+    if category:
+        base_filter.append(Notification.type.in_(CATEGORY_TYPES[category]))
+
    # Total unread count (all time)
    total_query = select(func.count(Notification.id)).where(*base_filter)
    total_result = await session.execute(total_query)
@@ -141,6 +251,17 @@ async def list_notifications(
    type_filter: NotificationType | None = Query(
        None, alias="type", description="Filter by notification type"
    ),
+    category: NotificationCategory | None = Query(
+        None, description="Filter by category: 'comments' or 'status'"
+    ),
+    source_type: str | None = Query(
+        None,
+        description="Filter by source type, e.g. 'connector:GITHUB_CONNECTOR' or 'doctype:FILE'",
+    ),
+    filter: str | None = Query(
+        None,
+        description="Filter preset: 'unread' for unread only, 'errors' for failed/error items only",
+    ),
    before_date: str | None = Query(
        None, description="Get notifications before this ISO date (for pagination)"
    ),
@@ -182,6 +303,45 @@ async def list_notifications(
        query = query.where(Notification.type == type_filter)
        count_query = count_query.where(Notification.type == type_filter)

+    # Filter by category (maps to multiple types)
+    if category:
+        cat_types = CATEGORY_TYPES[category]
+        query = query.where(Notification.type.in_(cat_types))
+        count_query = count_query.where(Notification.type.in_(cat_types))
+
+    # Filter by source type (connector or document type from JSONB metadata)
+    if source_type:
+        if source_type.startswith("connector:"):
+            connector_val = source_type[len("connector:") :]
+            source_filter = Notification.type.in_(
+                ("connector_indexing", "connector_deletion")
+            ) & (
+                Notification.notification_metadata["connector_type"].astext
+                == connector_val
+            )
+            query = query.where(source_filter)
+            count_query = count_query.where(source_filter)
+        elif source_type.startswith("doctype:"):
+            doctype_val = source_type[len("doctype:") :]
+            source_filter = Notification.type.in_(("document_processing",)) & (
+                Notification.notification_metadata["document_type"].astext
+                == doctype_val
+            )
+            query = query.where(source_filter)
+            count_query = count_query.where(source_filter)
+
+    # Filter by preset: 'unread' or 'errors'
+    if filter == "unread":
+        unread_filter = Notification.read == False  # noqa: E712
+        query = query.where(unread_filter)
+        count_query = count_query.where(unread_filter)
+    elif filter == "errors":
+        error_filter = (Notification.type == "page_limit_exceeded") | (
+            Notification.notification_metadata["status"].astext == "failed"
+        )
+        query = query.where(error_filter)
+        count_query = count_query.where(error_filter)
+
    # Filter by date (for efficient pagination of older items)
    if before_date:
        try:
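On the client side the `source_type` filter is composed from a prefix plus the value shown in the dropdown; the server strips the prefix and matches it against the JSONB metadata field. A small sketch of the two supported forms, using the example values from the parameter description:

# Sketch: composing the source_type query parameter the endpoint expects.
connector_value = "GITHUB_CONNECTOR"
params_connector = {"source_type": f"connector:{connector_value}"}
params_doctype = {"source_type": "doctype:FILE", "filter": "unread"}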
@@ -510,6 +510,7 @@ async def list_members(
            "user_email": member_user.email if member_user else None,
            "user_display_name": member_user.display_name if member_user else None,
            "user_avatar_url": member_user.avatar_url if member_user else None,
+            "user_last_login": member_user.last_login if member_user else None,
        }
        response.append(membership_dict)

@@ -602,6 +603,7 @@ async def update_member_role(
            "created_at": db_membership.created_at,
            "role": db_membership.role,
            "user_email": member_user.email if member_user else None,
+            "user_last_login": member_user.last_login if member_user else None,
        }

    except HTTPException:
@@ -77,6 +77,7 @@ class MembershipRead(BaseModel):
    user_email: str | None = None
    user_display_name: str | None = None
    user_avatar_url: str | None = None
+    user_last_login: datetime | None = None

    class Config:
        from_attributes = True
@@ -1366,6 +1366,38 @@ async def stream_new_chat(
    del mentioned_documents, mentioned_surfsense_docs, recent_reports
    del langchain_messages, final_query

+    # Check if this is the first assistant response so we can generate
+    # a title in parallel with the agent stream (better UX than waiting
+    # until after the full response).
+    assistant_count_result = await session.execute(
+        select(func.count(NewChatMessage.id)).filter(
+            NewChatMessage.thread_id == chat_id,
+            NewChatMessage.role == "assistant",
+        )
+    )
+    is_first_response = (assistant_count_result.scalar() or 0) == 0
+
+    title_task: asyncio.Task[str | None] | None = None
+    if is_first_response:
+
+        async def _generate_title() -> str | None:
+            try:
+                title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
+                title_result = await title_chain.ainvoke(
+                    {"user_query": user_query[:500]}
+                )
+                if title_result and hasattr(title_result, "content"):
+                    raw_title = title_result.content.strip()
+                    if raw_title and len(raw_title) <= 100:
+                        return raw_title.strip("\"'")
+            except Exception:
+                pass
+            return None
+
+        title_task = asyncio.create_task(_generate_title())
+
+    title_emitted = False
+
    _t_stream_start = time.perf_counter()
    _first_event_logged = False
    async for sse in _stream_agent_events(
@@ -1390,6 +1422,23 @@ async def stream_new_chat(
            _first_event_logged = True
        yield sse

+        # Inject title update mid-stream as soon as the background task finishes
+        if title_task is not None and title_task.done() and not title_emitted:
+            generated_title = title_task.result()
+            if generated_title:
+                async with shielded_async_session() as title_session:
+                    title_thread_result = await title_session.execute(
+                        select(NewChatThread).filter(NewChatThread.id == chat_id)
+                    )
+                    title_thread = title_thread_result.scalars().first()
+                    if title_thread:
+                        title_thread.title = generated_title
+                        await title_session.commit()
+                yield streaming_service.format_thread_title_update(
+                    chat_id, generated_title
+                )
+            title_emitted = True
+
    _perf_log.info(
        "[stream_new_chat] Agent stream completed in %.3fs (chat_id=%s)",
        time.perf_counter() - _t_stream_start,
@@ -1398,62 +1447,28 @@ async def stream_new_chat(
    log_system_snapshot("stream_new_chat_END")

    if stream_result.is_interrupted:
+        if title_task is not None and not title_task.done():
+            title_task.cancel()
        yield streaming_service.format_finish_step()
        yield streaming_service.format_finish()
        yield streaming_service.format_done()
        return

-    accumulated_text = stream_result.accumulated_text
-
-    assistant_count_result = await session.execute(
-        select(func.count(NewChatMessage.id)).filter(
-            NewChatMessage.thread_id == chat_id,
-            NewChatMessage.role == "assistant",
-        )
-    )
-    assistant_message_count = assistant_count_result.scalar() or 0
-
-    # Only generate title on the first response (no prior assistant messages)
-    if assistant_message_count == 0:
-        generated_title = None
-        try:
-            # Generate title using the same LLM
-            title_chain = TITLE_GENERATION_PROMPT_TEMPLATE | llm
-            # Truncate inputs to avoid context length issues
-            truncated_query = user_query[:500]
-            truncated_response = accumulated_text[:1000]
-            title_result = await title_chain.ainvoke(
-                {
-                    "user_query": truncated_query,
-                    "assistant_response": truncated_response,
-                }
-            )
-
-            # Extract and clean the title
-            if title_result and hasattr(title_result, "content"):
-                raw_title = title_result.content.strip()
-                # Validate the title (reasonable length)
-                if raw_title and len(raw_title) <= 100:
-                    # Remove any quotes or extra formatting
-                    generated_title = raw_title.strip("\"'")
-        except Exception:
-            generated_title = None
-
-        # Only update if LLM succeeded (keep truncated prompt title as fallback)
+    # If the title task didn't finish during streaming, await it now
+    if title_task is not None and not title_emitted:
+        generated_title = await title_task
        if generated_title:
-            # Fetch thread and update title
-            thread_result = await session.execute(
-                select(NewChatThread).filter(NewChatThread.id == chat_id)
-            )
-            thread = thread_result.scalars().first()
-            if thread:
-                thread.title = generated_title
-                await session.commit()
-
-            # Notify frontend of the title update
-            yield streaming_service.format_thread_title_update(
-                chat_id, generated_title
+            async with shielded_async_session() as title_session:
+                title_thread_result = await title_session.execute(
+                    select(NewChatThread).filter(NewChatThread.id == chat_id)
+                )
+                title_thread = title_thread_result.scalars().first()
+                if title_thread:
+                    title_thread.title = generated_title
+                    await title_session.commit()
+            yield streaming_service.format_thread_title_update(
+                chat_id, generated_title
            )

    # Finish the step and message
    yield streaming_service.format_finish_step()
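The restructured flow is the usual "side task next to a stream" pattern: start title generation with `asyncio.create_task`, poll `task.done()` while relaying agent events, and await (or cancel) the task once the stream ends. A stripped-down sketch of that shape, independent of the SurfSense types and written only to illustrate the pattern:

# Sketch: relay a stream while a slow side task runs, injecting its result
# mid-stream if it finishes early, otherwise awaiting it at the end.
import asyncio
from collections.abc import AsyncIterator


async def relay(events: AsyncIterator[str], side_work) -> AsyncIterator[str]:
    task = asyncio.create_task(side_work())
    emitted = False
    async for event in events:
        yield event
        if task.done() and not emitted:
            yield f"side-result: {task.result()}"
            emitted = True
    if not emitted:
        yield f"side-result: {await task}"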
@@ -1,5 +1,6 @@
import logging
import uuid
+from datetime import UTC, datetime

import httpx
from fastapi import Depends, Request, Response
@@ -12,6 +13,7 @@ from fastapi_users.authentication import (
)
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
+from sqlalchemy import update

from app.config import config
from app.db import (
@@ -123,6 +125,23 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

        return user

+    async def on_after_login(
+        self,
+        user: User,
+        request: Request | None = None,
+        response: Response | None = None,
+    ) -> None:
+        try:
+            async with async_session_maker() as session:
+                await session.execute(
+                    update(User)
+                    .where(User.id == user.id)
+                    .values(last_login=datetime.now(UTC))
+                )
+                await session.commit()
+        except Exception as e:
+            logger.warning(f"Failed to update last_login for user {user.id}: {e}")
+
    async def on_after_register(self, user: User, request: Request | None = None):
        """
        Called after a user registers. Creates a default search space for the user