This commit is contained in:
Manoj Aggarwal 2026-01-22 20:57:48 -08:00
commit 49d51ba569
70 changed files with 4266 additions and 1362 deletions

View file

@ -1,81 +0,0 @@
"""Add COMPOSIO_CONNECTOR to SearchSourceConnectorType and DocumentType enums
Revision ID: 74
Revises: 73
Create Date: 2026-01-21
This migration adds the COMPOSIO_CONNECTOR enum value to both:
- searchsourceconnectortype (for connector type tracking)
- documenttype (for document type tracking)
Composio is a managed OAuth integration service that allows connecting
to various third-party services (Google Drive, Gmail, Calendar, etc.)
without requiring separate OAuth app verification.
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "74"
down_revision: str | None = "73"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
# Define the ENUM type names and the new value
CONNECTOR_ENUM = "searchsourceconnectortype"
CONNECTOR_NEW_VALUE = "COMPOSIO_CONNECTOR"
DOCUMENT_ENUM = "documenttype"
DOCUMENT_NEW_VALUE = "COMPOSIO_CONNECTOR"
def upgrade() -> None:
"""Upgrade schema - add COMPOSIO_CONNECTOR to connector and document enums safely."""
# Add COMPOSIO_CONNECTOR to searchsourceconnectortype only if not exists
op.execute(
f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = '{CONNECTOR_NEW_VALUE}'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = '{CONNECTOR_ENUM}')
) THEN
ALTER TYPE {CONNECTOR_ENUM} ADD VALUE '{CONNECTOR_NEW_VALUE}';
END IF;
END$$;
"""
)
# Add COMPOSIO_CONNECTOR to documenttype only if not exists
op.execute(
f"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_enum
WHERE enumlabel = '{DOCUMENT_NEW_VALUE}'
AND enumtypid = (SELECT oid FROM pg_type WHERE typname = '{DOCUMENT_ENUM}')
) THEN
ALTER TYPE {DOCUMENT_ENUM} ADD VALUE '{DOCUMENT_NEW_VALUE}';
END IF;
END$$;
"""
)
def downgrade() -> None:
"""Downgrade schema - remove COMPOSIO_CONNECTOR from connector and document enums.
Note: PostgreSQL does not support removing enum values directly.
To properly downgrade, you would need to:
1. Delete any rows using the COMPOSIO_CONNECTOR value
2. Create new enums without COMPOSIO_CONNECTOR
3. Alter the columns to use the new enums
4. Drop the old enums
This is left as a no-op since removing enum values is complex
and typically not needed in practice.
"""
pass

View file

@ -0,0 +1,29 @@
"""No-op migration for Composio support
Revision ID: 74
Revises: 73
Create Date: 2026-01-21
NOTE: This migration is a no-op since Composio is not supported yet.
"""
from collections.abc import Sequence
# revision identifiers, used by Alembic.
revision: str = "74"
down_revision: str | None = "73"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""No-op upgrade for Composio support."""
pass
def downgrade() -> None:
"""No-op downgrade for Composio support.
Note: PostgreSQL does not support removing enum values directly.
"""
pass

View file

@ -0,0 +1,75 @@
"""Add chat_session_state table for live collaboration
Revision ID: 75
Revises: 74
Creates chat_session_state table to track AI responding state per thread.
Enables real-time sync via Electric SQL for shared chat collaboration.
"""
from collections.abc import Sequence
from alembic import op
revision: str = "75"
down_revision: str | None = "74"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Create chat_session_state table with Electric SQL replication."""
op.execute(
"""
CREATE TABLE IF NOT EXISTS chat_session_state (
id SERIAL PRIMARY KEY,
thread_id INTEGER NOT NULL REFERENCES new_chat_threads(id) ON DELETE CASCADE,
ai_responding_to_user_id UUID REFERENCES "user"(id) ON DELETE SET NULL,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (thread_id)
)
"""
)
op.execute(
"CREATE INDEX IF NOT EXISTS idx_chat_session_state_thread_id ON chat_session_state(thread_id)"
)
op.execute("ALTER TABLE chat_session_state REPLICA IDENTITY FULL;")
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_publication_tables
WHERE pubname = 'electric_publication_default'
AND tablename = 'chat_session_state'
) THEN
ALTER PUBLICATION electric_publication_default ADD TABLE chat_session_state;
END IF;
END
$$;
"""
)
def downgrade() -> None:
"""Drop chat_session_state table and remove from Electric SQL replication."""
op.execute(
"""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_publication_tables
WHERE pubname = 'electric_publication_default'
AND tablename = 'chat_session_state'
) THEN
ALTER PUBLICATION electric_publication_default DROP TABLE chat_session_state;
END IF;
END
$$;
"""
)
op.execute("DROP TABLE IF EXISTS chat_session_state;")

View file

@ -1,33 +0,0 @@
"""Add Obsidian connector enums
Revision ID: 75
Revises: 74
Create Date: 2026-01-21
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "75"
down_revision: str | None = "74"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# Add OBSIDIAN_CONNECTOR to documenttype enum
op.execute("ALTER TYPE documenttype ADD VALUE IF NOT EXISTS 'OBSIDIAN_CONNECTOR'")
# Add OBSIDIAN_CONNECTOR to searchsourceconnectortype enum
op.execute(
"ALTER TYPE searchsourceconnectortype ADD VALUE IF NOT EXISTS 'OBSIDIAN_CONNECTOR'"
)
def downgrade() -> None:
# Note: PostgreSQL doesn't support removing enum values directly.
# The values will remain in the enum type but won't be used.
pass

View file

@ -0,0 +1,99 @@
"""Add live collaboration tables to Electric SQL publication
Revision ID: 76
Revises: 75
Enables real-time sync for live collaboration features:
- new_chat_messages: Live message sync between users
- chat_comments: Live comment updates
Note: User/member info is fetched via API (membersAtom) for client-side joins,
not via Electric SQL, to keep where clauses optimized and reduce complexity.
"""
from collections.abc import Sequence
from alembic import op
revision: str = "76"
down_revision: str | None = "75"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add live collaboration tables to Electric SQL replication."""
# Set REPLICA IDENTITY FULL for Electric SQL sync
op.execute("ALTER TABLE new_chat_messages REPLICA IDENTITY FULL;")
op.execute("ALTER TABLE chat_comments REPLICA IDENTITY FULL;")
# Add new_chat_messages to Electric publication
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_publication_tables
WHERE pubname = 'electric_publication_default'
AND tablename = 'new_chat_messages'
) THEN
ALTER PUBLICATION electric_publication_default ADD TABLE new_chat_messages;
END IF;
END
$$;
"""
)
# Add chat_comments to Electric publication
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM pg_publication_tables
WHERE pubname = 'electric_publication_default'
AND tablename = 'chat_comments'
) THEN
ALTER PUBLICATION electric_publication_default ADD TABLE chat_comments;
END IF;
END
$$;
"""
)
def downgrade() -> None:
"""Remove live collaboration tables from Electric SQL replication."""
op.execute(
"""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_publication_tables
WHERE pubname = 'electric_publication_default'
AND tablename = 'new_chat_messages'
) THEN
ALTER PUBLICATION electric_publication_default DROP TABLE new_chat_messages;
END IF;
END
$$;
"""
)
op.execute(
"""
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM pg_publication_tables
WHERE pubname = 'electric_publication_default'
AND tablename = 'chat_comments'
) THEN
ALTER PUBLICATION electric_publication_default DROP TABLE chat_comments;
END IF;
END
$$;
"""
)
# Note: Not reverting REPLICA IDENTITY as it doesn't harm normal operations

View file

@ -0,0 +1,70 @@
"""Add thread_id to chat_comments for denormalized Electric subscriptions
This denormalization allows a single Electric SQL subscription per thread
instead of one per message, significantly reducing connection overhead.
Revision ID: 77
Revises: 76
"""
from collections.abc import Sequence
from alembic import op
revision: str = "77"
down_revision: str | None = "76"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add thread_id column to chat_comments and backfill from messages."""
# Add the column (nullable initially for backfill)
op.execute(
"""
ALTER TABLE chat_comments
ADD COLUMN IF NOT EXISTS thread_id INTEGER;
"""
)
# Backfill thread_id from the related message
op.execute(
"""
UPDATE chat_comments c
SET thread_id = m.thread_id
FROM new_chat_messages m
WHERE c.message_id = m.id
AND c.thread_id IS NULL;
"""
)
# Make it NOT NULL after backfill
op.execute(
"""
ALTER TABLE chat_comments
ALTER COLUMN thread_id SET NOT NULL;
"""
)
# Add FK constraint
op.execute(
"""
ALTER TABLE chat_comments
ADD CONSTRAINT fk_chat_comments_thread_id
FOREIGN KEY (thread_id) REFERENCES new_chat_threads(id) ON DELETE CASCADE;
"""
)
# Add index for efficient Electric subscriptions by thread
op.execute(
"CREATE INDEX IF NOT EXISTS idx_chat_comments_thread_id ON chat_comments(thread_id)"
)
def downgrade() -> None:
"""Remove thread_id column from chat_comments."""
op.execute("DROP INDEX IF EXISTS idx_chat_comments_thread_id")
op.execute(
"ALTER TABLE chat_comments DROP CONSTRAINT IF EXISTS fk_chat_comments_thread_id"
)
op.execute("ALTER TABLE chat_comments DROP COLUMN IF EXISTS thread_id")

View file

@ -268,7 +268,9 @@ class ComposioConnector:
from_email = header_dict.get("from", "Unknown Sender")
to_email = header_dict.get("to", "Unknown Recipient")
# Composio provides messageTimestamp directly
date_str = message.get("messageTimestamp", "") or header_dict.get("date", "Unknown Date")
date_str = message.get("messageTimestamp", "") or header_dict.get(
"date", "Unknown Date"
)
# Build markdown content
markdown_content = f"# {subject}\n\n"

View file

@ -58,7 +58,9 @@ class GitHubConnector:
if self.token:
logger.info("GitHub connector initialized with authentication token.")
else:
logger.info("GitHub connector initialized without token (public repos only).")
logger.info(
"GitHub connector initialized without token (public repos only)."
)
def ingest_repository(
self,
@ -95,17 +97,27 @@ class GitHubConnector:
cmd = [
"gitingest",
repo_url,
"--output", output_path,
"--max-size", str(max_file_size),
"--output",
output_path,
"--max-size",
str(max_file_size),
# Common exclude patterns
"-e", "node_modules/*",
"-e", "vendor/*",
"-e", ".git/*",
"-e", "__pycache__/*",
"-e", "dist/*",
"-e", "build/*",
"-e", "*.lock",
"-e", "package-lock.json",
"-e",
"node_modules/*",
"-e",
"vendor/*",
"-e",
".git/*",
"-e",
"__pycache__/*",
"-e",
"dist/*",
"-e",
"build/*",
"-e",
"*.lock",
"-e",
"package-lock.json",
]
# Add branch if specified
@ -147,7 +159,9 @@ class GitHubConnector:
os.unlink(output_path)
if not full_content or not full_content.strip():
logger.warning(f"No content retrieved from repository: {repo_full_name}")
logger.warning(
f"No content retrieved from repository: {repo_full_name}"
)
return None
# Parse the gitingest output
@ -171,11 +185,11 @@ class GitHubConnector:
logger.error(f"gitingest timed out for repository: {repo_full_name}")
return None
except FileNotFoundError:
logger.error(
"gitingest CLI not found. Falling back to Python library."
)
logger.error("gitingest CLI not found. Falling back to Python library.")
# Fall back to Python library
return self._ingest_with_python_library(repo_full_name, branch, max_file_size)
return self._ingest_with_python_library(
repo_full_name, branch, max_file_size
)
except Exception as e:
logger.error(f"Failed to ingest repository {repo_full_name}: {e}")
return None

View file

@ -84,7 +84,9 @@ class SearchSourceConnectorType(str, Enum):
CIRCLEBACK_CONNECTOR = "CIRCLEBACK_CONNECTOR"
OBSIDIAN_CONNECTOR = "OBSIDIAN_CONNECTOR" # Self-hosted only - Local Obsidian vault indexing
MCP_CONNECTOR = "MCP_CONNECTOR" # Model Context Protocol - User-defined API tools
COMPOSIO_CONNECTOR = "COMPOSIO_CONNECTOR" # Generic Composio integration (Google, Slack, etc.)
COMPOSIO_CONNECTOR = (
"COMPOSIO_CONNECTOR" # Generic Composio integration (Google, Slack, etc.)
)
class LiteLLMProvider(str, Enum):
@ -417,6 +419,13 @@ class ChatComment(BaseModel, TimestampMixin):
nullable=False,
index=True,
)
# Denormalized thread_id for efficient Electric SQL subscriptions (one per thread)
thread_id = Column(
Integer,
ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
parent_id = Column(
Integer,
ForeignKey("chat_comments.id", ondelete="CASCADE"),
@ -440,6 +449,7 @@ class ChatComment(BaseModel, TimestampMixin):
# Relationships
message = relationship("NewChatMessage", back_populates="comments")
thread = relationship("NewChatThread")
author = relationship("User")
parent = relationship(
"ChatComment", remote_side="ChatComment.id", backref="replies"
@ -476,6 +486,38 @@ class ChatCommentMention(BaseModel, TimestampMixin):
mentioned_user = relationship("User")
class ChatSessionState(BaseModel):
"""
Tracks real-time session state for shared chat collaboration.
One record per thread, synced via Electric SQL.
"""
__tablename__ = "chat_session_state"
thread_id = Column(
Integer,
ForeignKey("new_chat_threads.id", ondelete="CASCADE"),
nullable=False,
unique=True,
index=True,
)
ai_responding_to_user_id = Column(
UUID(as_uuid=True),
ForeignKey("user.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
updated_at = Column(
TIMESTAMP(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
thread = relationship("NewChatThread")
ai_responding_to_user = relationship("User")
class MemoryCategory(str, Enum):
"""Categories for user memories."""

View file

@ -84,7 +84,9 @@ async def list_composio_toolkits(user: User = Depends(current_active_user)):
@router.get("/auth/composio/connector/add")
async def initiate_composio_auth(
space_id: int,
toolkit_id: str = Query(..., description="Composio toolkit ID (e.g., 'googledrive', 'gmail')"),
toolkit_id: str = Query(
..., description="Composio toolkit ID (e.g., 'googledrive', 'gmail')"
),
user: User = Depends(current_active_user),
):
"""
@ -165,7 +167,9 @@ async def initiate_composio_auth(
@router.get("/auth/composio/connector/callback")
async def composio_callback(
state: str | None = None,
connectedAccountId: str | None = None, # Composio sends camelCase
composio_connected_account_id: str | None = Query(
None, alias="connectedAccountId"
), # Composio sends camelCase
connected_account_id: str | None = None, # Fallback snake_case
error: str | None = None,
session: AsyncSession = Depends(get_async_session),
@ -232,15 +236,18 @@ async def composio_callback(
)
# Initialize Composio service
service = ComposioService()
entity_id = f"surfsense_{user_id}"
ComposioService()
# Use camelCase param if provided (Composio's format), fallback to snake_case
final_connected_account_id = connectedAccountId or connected_account_id
final_connected_account_id = (
composio_connected_account_id or connected_account_id
)
# DEBUG: Log all query parameters received
logger.info(f"DEBUG: Callback received - connectedAccountId: {connectedAccountId}, connected_account_id: {connected_account_id}, using: {final_connected_account_id}")
logger.info(
f"DEBUG: Callback received - connectedAccountId: {composio_connected_account_id}, connected_account_id: {connected_account_id}, using: {final_connected_account_id}"
)
# If we still don't have a connected_account_id, warn but continue
# (the connector will be created but indexing won't work until updated)
if not final_connected_account_id:
@ -249,7 +256,9 @@ async def composio_callback(
"The connector will be created but indexing may not work."
)
else:
logger.info(f"Successfully got connected_account_id: {final_connected_account_id}")
logger.info(
f"Successfully got connected_account_id: {final_connected_account_id}"
)
# Build connector config
connector_config = {

View file

@ -990,7 +990,7 @@ async def handle_new_chat(
search_space_id=request.search_space_id,
chat_id=request.chat_id,
session=session,
user_id=str(user.id), # Pass user ID for memory tools
user_id=str(user.id), # Pass user ID for memory tools and session state
llm_config_id=llm_config_id,
attachments=request.attachments,
mentioned_document_ids=request.mentioned_document_ids,

View file

@ -1,12 +1,15 @@
"""
Notifications API routes.
These endpoints allow marking notifications as read.
Electric SQL automatically syncs the changes to all connected clients.
These endpoints allow marking notifications as read and fetching older notifications.
Electric SQL automatically syncs the changes to all connected clients for recent items.
For older items (beyond the sync window), use the list endpoint.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from datetime import UTC, datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy import desc, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Notification, User, get_async_session
@ -14,6 +17,36 @@ from app.users import current_active_user
router = APIRouter(prefix="/notifications", tags=["notifications"])
# Must match frontend SYNC_WINDOW_DAYS in use-inbox.ts
SYNC_WINDOW_DAYS = 14
class NotificationResponse(BaseModel):
"""Response model for a single notification."""
id: int
user_id: str
search_space_id: int | None
type: str
title: str
message: str
read: bool
metadata: dict
created_at: str
updated_at: str | None
class Config:
from_attributes = True
class NotificationListResponse(BaseModel):
"""Response for listing notifications with pagination."""
items: list[NotificationResponse]
total: int
has_more: bool
next_offset: int | None
class MarkReadResponse(BaseModel):
"""Response for mark as read operations."""
@ -30,6 +63,169 @@ class MarkAllReadResponse(BaseModel):
updated_count: int
class UnreadCountResponse(BaseModel):
"""Response for unread count with split between recent and older items."""
total_unread: int
recent_unread: int # Within SYNC_WINDOW_DAYS
@router.get("/unread-count", response_model=UnreadCountResponse)
async def get_unread_count(
search_space_id: int | None = Query(None, description="Filter by search space ID"),
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
) -> UnreadCountResponse:
"""
Get the total unread notification count for the current user.
Returns both:
- total_unread: All unread notifications (for accurate badge count)
- recent_unread: Unread notifications within the sync window (last 14 days)
This allows the frontend to calculate:
- older_unread = total_unread - recent_unread (static until reconciliation)
- Display count = older_unread + live_recent_count (from Electric SQL)
"""
# Calculate cutoff date for sync window
cutoff_date = datetime.now(UTC) - timedelta(days=SYNC_WINDOW_DAYS)
# Base filter for user's unread notifications
base_filter = [
Notification.user_id == user.id,
Notification.read == False, # noqa: E712
]
# Add search space filter if provided (include null for global notifications)
if search_space_id is not None:
base_filter.append(
(Notification.search_space_id == search_space_id)
| (Notification.search_space_id.is_(None))
)
# Total unread count (all time)
total_query = select(func.count(Notification.id)).where(*base_filter)
total_result = await session.execute(total_query)
total_unread = total_result.scalar() or 0
# Recent unread count (within sync window)
recent_query = select(func.count(Notification.id)).where(
*base_filter,
Notification.created_at > cutoff_date,
)
recent_result = await session.execute(recent_query)
recent_unread = recent_result.scalar() or 0
return UnreadCountResponse(
total_unread=total_unread,
recent_unread=recent_unread,
)
@router.get("", response_model=NotificationListResponse)
async def list_notifications(
search_space_id: int | None = Query(None, description="Filter by search space ID"),
type_filter: str | None = Query(
None, alias="type", description="Filter by notification type"
),
before_date: str | None = Query(
None, description="Get notifications before this ISO date (for pagination)"
),
limit: int = Query(50, ge=1, le=100, description="Number of items to return"),
offset: int = Query(0, ge=0, description="Number of items to skip"),
user: User = Depends(current_active_user),
session: AsyncSession = Depends(get_async_session),
) -> NotificationListResponse:
"""
List notifications for the current user with pagination.
This endpoint is used as a fallback for older notifications that are
outside the Electric SQL sync window (2 weeks).
Use `before_date` to paginate through older notifications efficiently.
"""
# Build base query
query = select(Notification).where(Notification.user_id == user.id)
count_query = select(func.count(Notification.id)).where(
Notification.user_id == user.id
)
# Filter by search space (include null search_space_id for global notifications)
if search_space_id is not None:
query = query.where(
(Notification.search_space_id == search_space_id)
| (Notification.search_space_id.is_(None))
)
count_query = count_query.where(
(Notification.search_space_id == search_space_id)
| (Notification.search_space_id.is_(None))
)
# Filter by type
if type_filter:
query = query.where(Notification.type == type_filter)
count_query = count_query.where(Notification.type == type_filter)
# Filter by date (for efficient pagination of older items)
if before_date:
try:
before_datetime = datetime.fromisoformat(before_date.replace("Z", "+00:00"))
query = query.where(Notification.created_at < before_datetime)
count_query = count_query.where(Notification.created_at < before_datetime)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid date format. Use ISO format (e.g., 2024-01-15T00:00:00Z)",
) from None
# Get total count
total_result = await session.execute(count_query)
total = total_result.scalar() or 0
# Apply ordering and pagination
query = (
query.order_by(desc(Notification.created_at)).offset(offset).limit(limit + 1)
)
# Execute query
result = await session.execute(query)
notifications = result.scalars().all()
# Check if there are more items
has_more = len(notifications) > limit
if has_more:
notifications = notifications[:limit]
# Convert to response format
items = []
for notification in notifications:
items.append(
NotificationResponse(
id=notification.id,
user_id=str(notification.user_id),
search_space_id=notification.search_space_id,
type=notification.type,
title=notification.title,
message=notification.message,
read=notification.read,
metadata=notification.notification_metadata or {},
created_at=notification.created_at.isoformat()
if notification.created_at
else "",
updated_at=notification.updated_at.isoformat()
if notification.updated_at
else None,
)
)
return NotificationListResponse(
items=items,
total=total,
has_more=has_more,
next_offset=offset + limit if has_more else None,
)
@router.patch("/{notification_id}/read", response_model=MarkReadResponse)
async def mark_notification_as_read(
notification_id: int,

View file

@ -0,0 +1,29 @@
"""
Pydantic schemas for chat session state (live collaboration).
"""
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel, ConfigDict
class RespondingUser(BaseModel):
"""The user that the AI is currently responding to."""
id: UUID
display_name: str | None = None
email: str
model_config = ConfigDict(from_attributes=True)
class ChatSessionStateResponse(BaseModel):
"""Current session state for a chat thread."""
id: int
thread_id: int
responding_to: RespondingUser | None = None
updated_at: datetime
model_config = ConfigDict(from_attributes=True)

View file

@ -281,8 +281,10 @@ async def create_comment(
detail="You don't have permission to create comments in this search space",
)
thread = message.thread
comment = ChatComment(
message_id=message_id,
thread_id=thread.id, # Denormalized for efficient Electric subscriptions
author_id=user.id,
content=content,
)
@ -299,7 +301,6 @@ async def create_comment(
user_names = await get_user_names_for_mentions(session, set(mentions_map.keys()))
# Create notifications for mentioned users (excluding author)
thread = message.thread
author_name = user.display_name or user.email
content_preview = render_mentions(content, user_names)
for mentioned_user_id, mention_id in mentions_map.items():
@ -393,8 +394,10 @@ async def create_reply(
detail="You don't have permission to create comments in this search space",
)
thread = parent_comment.message.thread
reply = ChatComment(
message_id=parent_comment.message_id,
thread_id=thread.id, # Denormalized for efficient Electric subscriptions
parent_id=comment_id,
author_id=user.id,
content=content,
@ -412,7 +415,6 @@ async def create_reply(
user_names = await get_user_names_for_mentions(session, set(mentions_map.keys()))
# Create notifications for mentioned users (excluding author)
thread = parent_comment.message.thread
author_name = user.display_name or user.email
content_preview = render_mentions(content, user_names)
for mentioned_user_id, mention_id in mentions_map.items():

View file

@ -0,0 +1,65 @@
"""
Service layer for chat session state (live collaboration).
"""
from datetime import UTC, datetime
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import ChatSessionState
async def get_session_state(
session: AsyncSession,
thread_id: int,
) -> ChatSessionState | None:
"""Get the current session state for a thread."""
result = await session.execute(
select(ChatSessionState)
.options(selectinload(ChatSessionState.ai_responding_to_user))
.filter(ChatSessionState.thread_id == thread_id)
)
return result.scalar_one_or_none()
async def set_ai_responding(
session: AsyncSession,
thread_id: int,
user_id: UUID,
) -> ChatSessionState:
"""Mark AI as responding to a specific user. Uses upsert for atomicity."""
now = datetime.now(UTC)
upsert_query = insert(ChatSessionState).values(
thread_id=thread_id,
ai_responding_to_user_id=user_id,
updated_at=now,
)
upsert_query = upsert_query.on_conflict_do_update(
index_elements=["thread_id"],
set_={
"ai_responding_to_user_id": user_id,
"updated_at": now,
},
)
await session.execute(upsert_query)
await session.commit()
return await get_session_state(session, thread_id)
async def clear_ai_responding(
session: AsyncSession,
thread_id: int,
) -> ChatSessionState | None:
"""Clear AI responding state when response is complete."""
state = await get_session_state(session, thread_id)
if state:
state.ai_responding_to_user_id = None
state.updated_at = datetime.now(UTC)
await session.commit()
await session.refresh(state)
return state

View file

@ -97,7 +97,7 @@ class ComposioService:
config_toolkit = getattr(auth_config, "toolkit", None)
if config_toolkit is None:
continue
# Extract toolkit name/slug from the object
toolkit_name = None
if isinstance(config_toolkit, str):
@ -108,18 +108,22 @@ class ComposioService:
toolkit_name = config_toolkit.name
elif hasattr(config_toolkit, "id"):
toolkit_name = config_toolkit.id
# Compare case-insensitively
if toolkit_name and toolkit_name.lower() == toolkit_id.lower():
logger.info(f"Found auth config {auth_config.id} for toolkit {toolkit_id}")
logger.info(
f"Found auth config {auth_config.id} for toolkit {toolkit_id}"
)
return auth_config.id
# Log available auth configs for debugging
logger.warning(f"No auth config found for toolkit '{toolkit_id}'. Available auth configs:")
logger.warning(
f"No auth config found for toolkit '{toolkit_id}'. Available auth configs:"
)
for auth_config in auth_configs.items:
config_toolkit = getattr(auth_config, "toolkit", None)
logger.warning(f" - {auth_config.id}: toolkit={config_toolkit}")
return None
except Exception as e:
logger.error(f"Failed to list auth configs: {e!s}")
@ -148,7 +152,7 @@ class ComposioService:
try:
# First, get the auth_config_id for this toolkit
auth_config_id = self._get_auth_config_for_toolkit(toolkit_id)
if not auth_config_id:
raise ValueError(
f"No auth config found for toolkit '{toolkit_id}'. "
@ -200,7 +204,9 @@ class ComposioService:
"user_id": getattr(account, "user_id", None),
}
except Exception as e:
logger.error(f"Failed to get connected account {connected_account_id}: {e!s}")
logger.error(
f"Failed to get connected account {connected_account_id}: {e!s}"
)
return None
async def list_all_connections(self) -> list[dict[str, Any]]:
@ -212,15 +218,17 @@ class ComposioService:
"""
try:
accounts_response = self.client.connected_accounts.list()
if hasattr(accounts_response, "items"):
accounts = accounts_response.items
elif hasattr(accounts_response, "__iter__"):
accounts = accounts_response
else:
logger.warning(f"Unexpected accounts response type: {type(accounts_response)}")
logger.warning(
f"Unexpected accounts response type: {type(accounts_response)}"
)
return []
result = []
for acc in accounts:
toolkit_raw = getattr(acc, "toolkit", None)
@ -234,14 +242,16 @@ class ComposioService:
toolkit_info = toolkit_raw.name
else:
toolkit_info = str(toolkit_raw)
result.append({
"id": acc.id,
"status": getattr(acc, "status", None),
"toolkit": toolkit_info,
"user_id": getattr(acc, "user_id", None),
})
result.append(
{
"id": acc.id,
"status": getattr(acc, "status", None),
"toolkit": toolkit_info,
"user_id": getattr(acc, "user_id", None),
}
)
logger.info(f"DEBUG: Found {len(result)} TOTAL connections in Composio")
return result
except Exception as e:
@ -261,16 +271,18 @@ class ComposioService:
try:
logger.info(f"DEBUG: Calling connected_accounts.list(user_id='{user_id}')")
accounts_response = self.client.connected_accounts.list(user_id=user_id)
# Handle paginated response (may have .items attribute) or direct list
if hasattr(accounts_response, "items"):
accounts = accounts_response.items
elif hasattr(accounts_response, "__iter__"):
accounts = accounts_response
else:
logger.warning(f"Unexpected accounts response type: {type(accounts_response)}")
logger.warning(
f"Unexpected accounts response type: {type(accounts_response)}"
)
return []
result = []
for acc in accounts:
# Extract toolkit info - might be string or object
@ -285,13 +297,15 @@ class ComposioService:
toolkit_info = toolkit_raw.name
else:
toolkit_info = toolkit_raw
result.append({
"id": acc.id,
"status": getattr(acc, "status", None),
"toolkit": toolkit_info,
})
result.append(
{
"id": acc.id,
"status": getattr(acc, "status", None),
"toolkit": toolkit_info,
}
)
logger.info(f"Found {len(result)} connections for user {user_id}: {result}")
return result
except Exception as e:
@ -383,18 +397,24 @@ class ComposioService:
return [], None, result.get("error", "Unknown error")
data = result.get("data", {})
logger.info(f"DEBUG: Drive data type: {type(data)}, keys: {data.keys() if isinstance(data, dict) else 'N/A'}")
logger.info(
f"DEBUG: Drive data type: {type(data)}, keys: {data.keys() if isinstance(data, dict) else 'N/A'}"
)
# Handle nested response structure from Composio
files = []
next_token = None
if isinstance(data, dict):
# Try direct access first, then nested
files = data.get("files", []) or data.get("data", {}).get("files", [])
next_token = data.get("nextPageToken") or data.get("next_page_token") or data.get("data", {}).get("nextPageToken")
next_token = (
data.get("nextPageToken")
or data.get("next_page_token")
or data.get("data", {}).get("nextPageToken")
)
elif isinstance(data, list):
files = data
logger.info(f"DEBUG: Extracted {len(files)} drive files")
return files, next_token, None
@ -475,16 +495,22 @@ class ComposioService:
return [], result.get("error", "Unknown error")
data = result.get("data", {})
logger.info(f"DEBUG: Gmail data type: {type(data)}, keys: {data.keys() if isinstance(data, dict) else 'N/A'}")
logger.info(
f"DEBUG: Gmail data type: {type(data)}, keys: {data.keys() if isinstance(data, dict) else 'N/A'}"
)
logger.info(f"DEBUG: Gmail full data: {data}")
# Try different possible response structures
messages = []
if isinstance(data, dict):
messages = data.get("messages", []) or data.get("data", {}).get("messages", []) or data.get("emails", [])
messages = (
data.get("messages", [])
or data.get("data", {}).get("messages", [])
or data.get("emails", [])
)
elif isinstance(data, list):
messages = data
logger.info(f"DEBUG: Extracted {len(messages)} messages")
return messages, None
@ -569,16 +595,22 @@ class ComposioService:
return [], result.get("error", "Unknown error")
data = result.get("data", {})
logger.info(f"DEBUG: Calendar data type: {type(data)}, keys: {data.keys() if isinstance(data, dict) else 'N/A'}")
logger.info(
f"DEBUG: Calendar data type: {type(data)}, keys: {data.keys() if isinstance(data, dict) else 'N/A'}"
)
logger.info(f"DEBUG: Calendar full data: {data}")
# Try different possible response structures
events = []
if isinstance(data, dict):
events = data.get("items", []) or data.get("data", {}).get("items", []) or data.get("events", [])
events = (
data.get("items", [])
or data.get("data", {}).get("items", [])
or data.get("events", [])
)
elif isinstance(data, list):
events = data
logger.info(f"DEBUG: Extracted {len(events)} calendar events")
return events, None

View file

@ -623,6 +623,28 @@ class MentionNotificationHandler(BaseNotificationHandler):
def __init__(self):
super().__init__("new_mention")
async def find_notification_by_mention(
self,
session: AsyncSession,
mention_id: int,
) -> Notification | None:
"""
Find an existing notification by mention ID.
Args:
session: Database session
mention_id: The mention ID to search for
Returns:
Notification if found, None otherwise
"""
query = select(Notification).where(
Notification.type == self.notification_type,
Notification.notification_metadata["mention_id"].astext == str(mention_id),
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def notify_new_mention(
self,
session: AsyncSession,
@ -641,11 +663,12 @@ class MentionNotificationHandler(BaseNotificationHandler):
) -> Notification:
"""
Create notification when a user is @mentioned in a comment.
Uses mention_id for idempotency to prevent duplicate notifications.
Args:
session: Database session
mentioned_user_id: User who was mentioned
mention_id: ID of the mention record
mention_id: ID of the mention record (used for idempotency)
comment_id: ID of the comment containing the mention
message_id: ID of the message being commented on
thread_id: ID of the chat thread
@ -658,8 +681,16 @@ class MentionNotificationHandler(BaseNotificationHandler):
search_space_id: Search space ID
Returns:
Notification: The created notification
Notification: The created or existing notification
"""
# Check if notification already exists for this mention (idempotency)
existing = await self.find_notification_by_mention(session, mention_id)
if existing:
logger.info(
f"Notification already exists for mention {mention_id}, returning existing"
)
return existing
title = f"{author_name} mentioned you"
message = content_preview[:100] + ("..." if len(content_preview) > 100 else "")
@ -676,21 +707,37 @@ class MentionNotificationHandler(BaseNotificationHandler):
"content_preview": content_preview[:200],
}
notification = Notification(
user_id=mentioned_user_id,
search_space_id=search_space_id,
type=self.notification_type,
title=title,
message=message,
notification_metadata=metadata,
)
session.add(notification)
await session.commit()
await session.refresh(notification)
logger.info(
f"Created new_mention notification {notification.id} for user {mentioned_user_id}"
)
return notification
try:
notification = Notification(
user_id=mentioned_user_id,
search_space_id=search_space_id,
type=self.notification_type,
title=title,
message=message,
notification_metadata=metadata,
)
session.add(notification)
await session.commit()
await session.refresh(notification)
logger.info(
f"Created new_mention notification {notification.id} for user {mentioned_user_id}"
)
return notification
except Exception as e:
# Handle race condition - if duplicate key error, try to fetch existing
await session.rollback()
if (
"duplicate key" in str(e).lower()
or "unique constraint" in str(e).lower()
):
logger.warning(
f"Duplicate notification detected for mention {mention_id}, fetching existing"
)
existing = await self.find_notification_by_mention(session, mention_id)
if existing:
return existing
# Re-raise if not a duplicate key error or couldn't find existing
raise
class NotificationService:

View file

@ -11,6 +11,7 @@ Supports loading LLM configurations from:
import json
from collections.abc import AsyncGenerator
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy.ext.asyncio import AsyncSession
@ -27,6 +28,10 @@ from app.agents.new_chat.llm_config import (
)
from app.db import Document, SurfsenseDocsDocument
from app.schemas.new_chat import ChatAttachment
from app.services.chat_session_state_service import (
clear_ai_responding,
set_ai_responding,
)
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
@ -167,9 +172,8 @@ async def stream_new_chat(
search_space_id: The search space ID
chat_id: The chat ID (used as LangGraph thread_id for memory)
session: The database session
user_id: The current user's UUID string (for memory tools)
user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config)
messages: Optional chat history from frontend (list of ChatMessage)
attachments: Optional attachments with extracted content
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
@ -183,6 +187,9 @@ async def stream_new_chat(
current_text_id: str | None = None
try:
# Mark AI as responding to this user for live collaboration
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
# Load LLM config - supports both YAML (negative IDs) and database (positive IDs)
agent_config: AgentConfig | None = None
@ -1147,3 +1154,7 @@ async def stream_new_chat(
yield streaming_service.format_finish_step()
yield streaming_service.format_finish()
yield streaming_service.format_done()
finally:
# Clear AI responding state for live collaboration
await clear_ai_responding(session, chat_id)

View file

@ -144,7 +144,9 @@ async def index_composio_connector(
# Get toolkit ID from config
toolkit_id = connector.config.get("toolkit_id")
if not toolkit_id:
error_msg = f"Composio connector {connector_id} has no toolkit_id configured"
error_msg = (
f"Composio connector {connector_id} has no toolkit_id configured"
)
await task_logger.log_task_failure(
log_entry, error_msg, {"error_type": "MissingToolkitId"}
)
@ -287,8 +289,14 @@ async def _index_composio_google_drive(
try:
# Handle both standard Google API and potential Composio variations
file_id = file_info.get("id", "") or file_info.get("fileId", "")
file_name = file_info.get("name", "") or file_info.get("fileName", "") or "Untitled"
mime_type = file_info.get("mimeType", "") or file_info.get("mime_type", "")
file_name = (
file_info.get("name", "")
or file_info.get("fileName", "")
or "Untitled"
)
mime_type = file_info.get("mimeType", "") or file_info.get(
"mime_type", ""
)
if not file_id:
documents_skipped += 1
@ -309,12 +317,15 @@ async def _index_composio_google_drive(
)
# Get file content
content, content_error = await composio_connector.get_drive_file_content(
file_id
)
(
content,
content_error,
) = await composio_connector.get_drive_file_content(file_id)
if content_error or not content:
logger.warning(f"Could not get content for file {file_name}: {content_error}")
logger.warning(
f"Could not get content for file {file_name}: {content_error}"
)
# Use metadata as content fallback
markdown_content = f"# {file_name}\n\n"
markdown_content += f"**File ID:** {file_id}\n"
@ -344,12 +355,19 @@ async def _index_composio_google_drive(
"mime_type": mime_type,
"document_type": "Google Drive File (Composio)",
}
summary_content, summary_embedding = await generate_document_summary(
(
summary_content,
summary_embedding,
) = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
summary_content = f"Google Drive File: {file_name}\n\nType: {mime_type}"
summary_embedding = config.embedding_model_instance.embed(summary_content)
summary_content = (
f"Google Drive File: {file_name}\n\nType: {mime_type}"
)
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
chunks = await create_document_chunks(markdown_content)
@ -382,12 +400,19 @@ async def _index_composio_google_drive(
"mime_type": mime_type,
"document_type": "Google Drive File (Composio)",
}
summary_content, summary_embedding = await generate_document_summary(
(
summary_content,
summary_embedding,
) = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
summary_content = f"Google Drive File: {file_name}\n\nType: {mime_type}"
summary_embedding = config.embedding_model_instance.embed(summary_content)
summary_content = (
f"Google Drive File: {file_name}\n\nType: {mime_type}"
)
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
chunks = await create_document_chunks(markdown_content)
@ -527,11 +552,15 @@ async def _index_composio_gmail(
date_str = value
# Format to markdown using the full message data
markdown_content = composio_connector.format_gmail_message_to_markdown(message)
markdown_content = composio_connector.format_gmail_message_to_markdown(
message
)
# Generate unique identifier
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.COMPOSIO_CONNECTOR, f"gmail_{message_id}", search_space_id
DocumentType.COMPOSIO_CONNECTOR,
f"gmail_{message_id}",
search_space_id,
)
content_hash = generate_content_hash(markdown_content, search_space_id)
@ -560,12 +589,19 @@ async def _index_composio_gmail(
"sender": sender,
"document_type": "Gmail Message (Composio)",
}
summary_content, summary_embedding = await generate_document_summary(
(
summary_content,
summary_embedding,
) = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
summary_content = f"Gmail: {subject}\n\nFrom: {sender}\nDate: {date_str}"
summary_embedding = config.embedding_model_instance.embed(summary_content)
summary_content = (
f"Gmail: {subject}\n\nFrom: {sender}\nDate: {date_str}"
)
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
chunks = await create_document_chunks(markdown_content)
@ -600,12 +636,19 @@ async def _index_composio_gmail(
"sender": sender,
"document_type": "Gmail Message (Composio)",
}
summary_content, summary_embedding = await generate_document_summary(
(
summary_content,
summary_embedding,
) = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
summary_content = f"Gmail: {subject}\n\nFrom: {sender}\nDate: {date_str}"
summary_embedding = config.embedding_model_instance.embed(summary_content)
summary_content = (
f"Gmail: {subject}\n\nFrom: {sender}\nDate: {date_str}"
)
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
chunks = await create_document_chunks(markdown_content)
@ -728,18 +771,24 @@ async def _index_composio_google_calendar(
try:
# Handle both standard Google API and potential Composio variations
event_id = event.get("id", "") or event.get("eventId", "")
summary = event.get("summary", "") or event.get("title", "") or "No Title"
summary = (
event.get("summary", "") or event.get("title", "") or "No Title"
)
if not event_id:
documents_skipped += 1
continue
# Format to markdown
markdown_content = composio_connector.format_calendar_event_to_markdown(event)
markdown_content = composio_connector.format_calendar_event_to_markdown(
event
)
# Generate unique identifier
unique_identifier_hash = generate_unique_identifier_hash(
DocumentType.COMPOSIO_CONNECTOR, f"calendar_{event_id}", search_space_id
DocumentType.COMPOSIO_CONNECTOR,
f"calendar_{event_id}",
search_space_id,
)
content_hash = generate_content_hash(markdown_content, search_space_id)
@ -772,14 +821,19 @@ async def _index_composio_google_calendar(
"start_time": start_time,
"document_type": "Google Calendar Event (Composio)",
}
summary_content, summary_embedding = await generate_document_summary(
(
summary_content,
summary_embedding,
) = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
summary_content = f"Calendar: {summary}\n\nStart: {start_time}\nEnd: {end_time}"
if location:
summary_content += f"\nLocation: {location}"
summary_embedding = config.embedding_model_instance.embed(summary_content)
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
chunks = await create_document_chunks(markdown_content)
@ -814,14 +868,21 @@ async def _index_composio_google_calendar(
"start_time": start_time,
"document_type": "Google Calendar Event (Composio)",
}
summary_content, summary_embedding = await generate_document_summary(
(
summary_content,
summary_embedding,
) = await generate_document_summary(
markdown_content, user_llm, document_metadata
)
else:
summary_content = f"Calendar: {summary}\n\nStart: {start_time}\nEnd: {end_time}"
summary_content = (
f"Calendar: {summary}\n\nStart: {start_time}\nEnd: {end_time}"
)
if location:
summary_content += f"\nLocation: {location}"
summary_embedding = config.embedding_model_instance.embed(summary_content)
summary_embedding = config.embedding_model_instance.embed(
summary_content
)
chunks = await create_document_chunks(markdown_content)
@ -874,5 +935,7 @@ async def _index_composio_google_calendar(
return documents_indexed, None
except Exception as e:
logger.error(f"Failed to index Google Calendar via Composio: {e!s}", exc_info=True)
logger.error(
f"Failed to index Google Calendar via Composio: {e!s}", exc_info=True
)
return 0, f"Failed to index Google Calendar via Composio: {e!s}"

View file

@ -128,7 +128,9 @@ async def index_github_repos(
if github_pat:
logger.info("Using GitHub PAT for authentication (private repos supported)")
else:
logger.info("No GitHub PAT provided - only public repositories can be indexed")
logger.info(
"No GitHub PAT provided - only public repositories can be indexed"
)
# 3. Initialize GitHub connector with gitingest backend
await task_logger.log_task_progress(
@ -308,9 +310,7 @@ async def _process_repository_digest(
if existing_document:
# Document exists - check if content has changed
if existing_document.content_hash == content_hash:
logger.info(
f"Repository {repo_full_name} unchanged. Skipping."
)
logger.info(f"Repository {repo_full_name} unchanged. Skipping.")
return 0
else:
logger.info(
@ -341,7 +341,7 @@ async def _process_repository_digest(
summary_content = (
f"# Repository: {repo_full_name}\n\n"
f"## File Structure\n\n{digest.tree}\n\n"
f"## File Contents (truncated)\n\n{digest.content[:MAX_DIGEST_CHARS - len(digest.tree) - 200]}..."
f"## File Contents (truncated)\n\n{digest.content[: MAX_DIGEST_CHARS - len(digest.tree) - 200]}..."
)
summary_text, summary_embedding = await generate_document_summary(
@ -362,9 +362,7 @@ async def _process_repository_digest(
# This preserves file-level granularity in search
chunks_data = await create_document_chunks(digest.content)
except Exception as chunk_err:
logger.error(
f"Failed to chunk repository {repo_full_name}: {chunk_err}"
)
logger.error(f"Failed to chunk repository {repo_full_name}: {chunk_err}")
# Fall back to a simpler chunking approach
chunks_data = await _simple_chunk_content(digest.content)