mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-27 17:56:25 +02:00
Merge branch 'dev' of https://github.com/MODSetter/SurfSense into dev
This commit is contained in:
commit
8301e0169c
71 changed files with 2889 additions and 732 deletions
|
|
@ -442,11 +442,24 @@ async def refresh_airtable_token(
|
|||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
error_code = ""
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get("error_description", error_detail)
|
||||
error_code = error_json.get("error", "")
|
||||
except Exception:
|
||||
pass
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = (error_detail + error_code).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Airtable authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -417,6 +417,17 @@ async def refresh_clickup_token(
|
|||
error_detail = error_json.get("error", error_detail)
|
||||
except Exception:
|
||||
pass
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = error_detail.lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="ClickUp authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -428,13 +428,26 @@ async def refresh_confluence_token(
|
|||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
error_code = ""
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get(
|
||||
"error_description", error_json.get("error", error_detail)
|
||||
)
|
||||
error_code = error_json.get("error", "")
|
||||
except Exception:
|
||||
pass
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = (error_detail + error_code).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Confluence authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -46,6 +46,11 @@ SCOPES = [
|
|||
"guilds.members.read", # Read member information
|
||||
]
|
||||
|
||||
# Discord permission bits
|
||||
VIEW_CHANNEL = 1 << 10 # 1024
|
||||
READ_MESSAGE_HISTORY = 1 << 16 # 65536
|
||||
ADMINISTRATOR = 1 << 3 # 8
|
||||
|
||||
# Initialize security utilities
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
|
@ -531,3 +536,296 @@ async def refresh_discord_token(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to refresh Discord tokens: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
def _compute_channel_permissions(
|
||||
base_permissions: int,
|
||||
bot_role_ids: set[str],
|
||||
bot_user_id: str | None,
|
||||
channel_overwrites: list[dict],
|
||||
guild_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
Compute effective permissions for a channel based on role permissions and overwrites.
|
||||
|
||||
Discord permission computation follows this order (per official docs):
|
||||
1. Start with base permissions from roles
|
||||
2. Apply @everyone role overwrites (deny, then allow)
|
||||
3. Apply role-specific overwrites (deny, then allow)
|
||||
4. Apply member-specific overwrites (deny, then allow)
|
||||
|
||||
Args:
|
||||
base_permissions: Combined permissions from all bot roles
|
||||
bot_role_ids: Set of role IDs the bot has
|
||||
bot_user_id: The bot's user ID for member-specific overwrites
|
||||
channel_overwrites: List of permission overwrites for the channel
|
||||
guild_id: Guild ID (same as @everyone role ID)
|
||||
|
||||
Returns:
|
||||
Computed permission integer
|
||||
"""
|
||||
permissions = base_permissions
|
||||
|
||||
# Permission overwrites are applied in order: @everyone, roles, member
|
||||
everyone_allow = 0
|
||||
everyone_deny = 0
|
||||
role_allow = 0
|
||||
role_deny = 0
|
||||
member_allow = 0
|
||||
member_deny = 0
|
||||
|
||||
for overwrite in channel_overwrites:
|
||||
overwrite_id = overwrite.get("id")
|
||||
overwrite_type = overwrite.get("type") # 0 = role, 1 = member
|
||||
allow = int(overwrite.get("allow", 0))
|
||||
deny = int(overwrite.get("deny", 0))
|
||||
|
||||
if overwrite_type == 0: # Role overwrite
|
||||
if overwrite_id == guild_id: # @everyone role
|
||||
everyone_allow = allow
|
||||
everyone_deny = deny
|
||||
elif overwrite_id in bot_role_ids:
|
||||
role_allow |= allow
|
||||
role_deny |= deny
|
||||
elif overwrite_type == 1 and bot_user_id and overwrite_id == bot_user_id:
|
||||
# Member-specific overwrite for the bot
|
||||
member_allow = allow
|
||||
member_deny = deny
|
||||
|
||||
# Apply in order per Discord docs:
|
||||
# 1. @everyone deny, then allow
|
||||
permissions &= ~everyone_deny
|
||||
permissions |= everyone_allow
|
||||
# 2. Role deny, then allow
|
||||
permissions &= ~role_deny
|
||||
permissions |= role_allow
|
||||
# 3. Member deny, then allow (applied LAST, highest priority)
|
||||
permissions &= ~member_deny
|
||||
permissions |= member_allow
|
||||
|
||||
return permissions
|
||||
|
||||
|
||||
@router.get("/discord/connector/{connector_id}/channels", response_model=None)
|
||||
async def get_discord_channels(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get list of Discord text channels for a connector with permission info.
|
||||
|
||||
Uses Discord's HTTP REST API directly instead of WebSocket bot connection.
|
||||
Computes effective permissions to determine if bot can read message history.
|
||||
|
||||
Args:
|
||||
connector_id: The Discord connector ID
|
||||
session: Database session
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
List of channels with id, name, type, position, category_id, and can_index fields
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
try:
|
||||
# Get connector and verify ownership
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Discord connector not found or access denied",
|
||||
)
|
||||
|
||||
# Get credentials and decrypt bot token
|
||||
credentials = DiscordAuthCredentialsBase.from_dict(connector.config)
|
||||
token_encryption = get_token_encryption()
|
||||
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
bot_token = credentials.bot_token
|
||||
if is_encrypted and bot_token:
|
||||
try:
|
||||
bot_token = token_encryption.decrypt_token(bot_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt bot token: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to decrypt stored bot token"
|
||||
) from e
|
||||
|
||||
if not bot_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No bot token available. Please re-authenticate.",
|
||||
)
|
||||
|
||||
# Get guild_id from connector config
|
||||
guild_id = connector.config.get("guild_id")
|
||||
if not guild_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No guild_id associated with this connector. Please reconnect the Discord server.",
|
||||
)
|
||||
|
||||
headers = {"Authorization": f"Bot {bot_token}"}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Fetch bot's user info to get bot user ID
|
||||
bot_user_response = await client.get(
|
||||
"https://discord.com/api/v10/users/@me",
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if bot_user_response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch bot user info: {bot_user_response.text}"
|
||||
)
|
||||
bot_user_id = None
|
||||
else:
|
||||
bot_user_id = bot_user_response.json().get("id")
|
||||
|
||||
# Fetch guild info to get roles
|
||||
guild_response = await client.get(
|
||||
f"https://discord.com/api/v10/guilds/{guild_id}",
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if guild_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=guild_response.status_code,
|
||||
detail="Failed to fetch guild information",
|
||||
)
|
||||
|
||||
guild_data = guild_response.json()
|
||||
guild_roles = {role["id"]: role for role in guild_data.get("roles", [])}
|
||||
|
||||
# Fetch bot's member info to get its roles
|
||||
bot_member_response = await client.get(
|
||||
f"https://discord.com/api/v10/guilds/{guild_id}/members/{bot_user_id}",
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if bot_member_response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch bot member info: {bot_member_response.text}"
|
||||
)
|
||||
bot_role_ids = {guild_id} # At minimum, bot has @everyone role
|
||||
base_permissions = int(
|
||||
guild_roles.get(guild_id, {}).get("permissions", 0)
|
||||
)
|
||||
else:
|
||||
bot_member_data = bot_member_response.json()
|
||||
bot_role_ids = set(bot_member_data.get("roles", []))
|
||||
bot_role_ids.add(guild_id) # @everyone role is always included
|
||||
|
||||
# Compute base permissions from all bot roles
|
||||
base_permissions = 0
|
||||
for role_id in bot_role_ids:
|
||||
if role_id in guild_roles:
|
||||
role_perms = int(guild_roles[role_id].get("permissions", 0))
|
||||
base_permissions |= role_perms
|
||||
|
||||
# Check if bot has administrator permission (bypasses all checks)
|
||||
is_admin = (base_permissions & ADMINISTRATOR) == ADMINISTRATOR
|
||||
|
||||
# Fetch channels
|
||||
channels_response = await client.get(
|
||||
f"https://discord.com/api/v10/guilds/{guild_id}/channels",
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
if channels_response.status_code == 403:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Bot does not have permission to view channels in this server. Please ensure the bot has the 'View Channels' permission.",
|
||||
)
|
||||
elif channels_response.status_code == 404:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Discord server not found. The bot may have been removed from the server.",
|
||||
)
|
||||
elif channels_response.status_code != 200:
|
||||
error_detail = channels_response.text
|
||||
try:
|
||||
error_json = channels_response.json()
|
||||
error_detail = error_json.get("message", error_detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=channels_response.status_code,
|
||||
detail=f"Failed to fetch Discord channels: {error_detail}",
|
||||
)
|
||||
|
||||
channels_data = channels_response.json()
|
||||
|
||||
# Discord channel types:
|
||||
# 0 = GUILD_TEXT, 2 = GUILD_VOICE, 4 = GUILD_CATEGORY, 5 = GUILD_ANNOUNCEMENT
|
||||
# We want text channels (type 0) and announcement channels (type 5)
|
||||
text_channel_types = {0, 5}
|
||||
|
||||
text_channels = []
|
||||
for ch in channels_data:
|
||||
if ch.get("type") in text_channel_types:
|
||||
# Compute effective permissions for this channel
|
||||
if is_admin:
|
||||
# Administrators bypass all permission checks
|
||||
can_index = True
|
||||
else:
|
||||
channel_overwrites = ch.get("permission_overwrites", [])
|
||||
effective_perms = _compute_channel_permissions(
|
||||
base_permissions,
|
||||
bot_role_ids,
|
||||
bot_user_id,
|
||||
channel_overwrites,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
# Bot can index if it has both VIEW_CHANNEL and READ_MESSAGE_HISTORY
|
||||
has_view = (effective_perms & VIEW_CHANNEL) == VIEW_CHANNEL
|
||||
has_read_history = (
|
||||
effective_perms & READ_MESSAGE_HISTORY
|
||||
) == READ_MESSAGE_HISTORY
|
||||
can_index = has_view and has_read_history
|
||||
|
||||
text_channels.append(
|
||||
{
|
||||
"id": ch["id"],
|
||||
"name": ch["name"],
|
||||
"type": "text" if ch["type"] == 0 else "announcement",
|
||||
"position": ch.get("position", 0),
|
||||
"category_id": ch.get("parent_id"),
|
||||
"can_index": can_index,
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by position
|
||||
text_channels.sort(key=lambda x: x["position"])
|
||||
|
||||
logger.info(
|
||||
f"Fetched {len(text_channels)} text channels for Discord connector {connector_id}"
|
||||
)
|
||||
|
||||
return text_channels
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get Discord channels for connector {connector_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get Discord channels: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -446,13 +446,26 @@ async def refresh_jira_token(
|
|||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
error_code = ""
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get(
|
||||
"error_description", error_json.get("error", error_detail)
|
||||
)
|
||||
error_code = error_json.get("error", "")
|
||||
except Exception:
|
||||
pass
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = (error_detail + error_code).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Jira authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -403,11 +403,24 @@ async def refresh_linear_token(
|
|||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
error_code = ""
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get("error_description", error_detail)
|
||||
error_code = error_json.get("error", "")
|
||||
except Exception:
|
||||
pass
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = (error_detail + error_code).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Linear authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.schemas.new_chat import (
|
||||
CompleteCloneResponse,
|
||||
NewChatMessageAppend,
|
||||
NewChatMessageRead,
|
||||
NewChatRequest,
|
||||
|
|
@ -46,14 +45,13 @@ from app.schemas.new_chat import (
|
|||
NewChatThreadUpdate,
|
||||
NewChatThreadVisibilityUpdate,
|
||||
NewChatThreadWithMessages,
|
||||
PublicShareToggleRequest,
|
||||
PublicShareToggleResponse,
|
||||
RegenerateRequest,
|
||||
SnapshotCreateResponse,
|
||||
SnapshotListResponse,
|
||||
ThreadHistoryLoadResponse,
|
||||
ThreadListItem,
|
||||
ThreadListResponse,
|
||||
)
|
||||
from app.services.public_chat_service import toggle_public_share
|
||||
from app.tasks.chat.stream_new_chat import stream_new_chat
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
|
@ -219,7 +217,6 @@ async def list_threads(
|
|||
visibility=thread.visibility,
|
||||
created_by_id=thread.created_by_id,
|
||||
is_own_thread=is_own_thread,
|
||||
public_share_enabled=thread.public_share_enabled,
|
||||
created_at=thread.created_at,
|
||||
updated_at=thread.updated_at,
|
||||
)
|
||||
|
|
@ -321,7 +318,6 @@ async def search_threads(
|
|||
thread.created_by_id == user.id
|
||||
or (thread.created_by_id is None and is_search_space_owner)
|
||||
),
|
||||
public_share_enabled=thread.public_share_enabled,
|
||||
created_at=thread.created_at,
|
||||
updated_at=thread.updated_at,
|
||||
)
|
||||
|
|
@ -670,66 +666,6 @@ async def delete_thread(
|
|||
) from None
|
||||
|
||||
|
||||
@router.post(
|
||||
"/threads/{thread_id}/complete-clone", response_model=CompleteCloneResponse
|
||||
)
|
||||
async def complete_clone(
|
||||
thread_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Complete the cloning process for a thread.
|
||||
|
||||
Copies messages and podcasts from the source thread.
|
||||
Sets clone_pending=False and needs_history_bootstrap=True when done.
|
||||
|
||||
Requires authentication and ownership of the thread.
|
||||
"""
|
||||
from app.services.public_chat_service import complete_clone_content
|
||||
|
||||
try:
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
if thread.created_by_id != user.id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized")
|
||||
|
||||
if not thread.clone_pending:
|
||||
raise HTTPException(status_code=400, detail="Clone already completed")
|
||||
|
||||
if not thread.cloned_from_thread_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No source thread to clone from"
|
||||
)
|
||||
|
||||
message_count = await complete_clone_content(
|
||||
session=session,
|
||||
target_thread=thread,
|
||||
source_thread_id=thread.cloned_from_thread_id,
|
||||
target_search_space_id=thread.search_space_id,
|
||||
)
|
||||
|
||||
return CompleteCloneResponse(
|
||||
status="success",
|
||||
message_count=message_count,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An unexpected error occurred while completing clone: {e!s}",
|
||||
) from None
|
||||
|
||||
|
||||
@router.patch("/threads/{thread_id}/visibility", response_model=NewChatThreadRead)
|
||||
async def update_thread_visibility(
|
||||
thread_id: int,
|
||||
|
|
@ -795,32 +731,83 @@ async def update_thread_visibility(
|
|||
) from None
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/threads/{thread_id}/public-share", response_model=PublicShareToggleResponse
|
||||
)
|
||||
async def update_thread_public_share(
|
||||
# =============================================================================
|
||||
# Snapshot Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/snapshots", response_model=SnapshotCreateResponse)
|
||||
async def create_thread_snapshot(
|
||||
thread_id: int,
|
||||
request: Request,
|
||||
toggle_request: PublicShareToggleRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Enable or disable public sharing for a thread.
|
||||
Create a public snapshot of the thread.
|
||||
|
||||
Only the creator of the thread can manage public sharing.
|
||||
When enabled, returns a public URL that anyone can use to view the chat.
|
||||
Returns existing snapshot URL if content unchanged (deduplication).
|
||||
Only the thread owner can create snapshots.
|
||||
"""
|
||||
from app.services.public_chat_service import create_snapshot
|
||||
|
||||
base_url = str(request.base_url).rstrip("/")
|
||||
return await toggle_public_share(
|
||||
return await create_snapshot(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
enabled=toggle_request.enabled,
|
||||
user=user,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/snapshots", response_model=SnapshotListResponse)
|
||||
async def list_thread_snapshots(
|
||||
thread_id: int,
|
||||
request: Request,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List all public snapshots for this thread.
|
||||
|
||||
Only the thread owner can view snapshots.
|
||||
"""
|
||||
from app.services.public_chat_service import list_snapshots_for_thread
|
||||
|
||||
base_url = str(request.base_url).rstrip("/")
|
||||
return SnapshotListResponse(
|
||||
snapshots=await list_snapshots_for_thread(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
user=user,
|
||||
base_url=base_url,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/threads/{thread_id}/snapshots/{snapshot_id}")
|
||||
async def delete_thread_snapshot(
|
||||
thread_id: int,
|
||||
snapshot_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Delete a specific snapshot.
|
||||
|
||||
Only the thread owner can delete snapshots.
|
||||
"""
|
||||
from app.services.public_chat_service import delete_snapshot
|
||||
|
||||
await delete_snapshot(
|
||||
session=session,
|
||||
thread_id=thread_id,
|
||||
snapshot_id=snapshot_id,
|
||||
user=user,
|
||||
)
|
||||
return {"message": "Snapshot deleted successfully"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Message Endpoints
|
||||
# =============================================================================
|
||||
|
|
@ -1286,6 +1273,8 @@ async def regenerate_response(
|
|||
.limit(2)
|
||||
)
|
||||
messages_to_delete = list(last_messages_result.scalars().all())
|
||||
|
||||
message_ids_to_delete = [msg.id for msg in messages_to_delete]
|
||||
|
||||
# Get search space for LLM config
|
||||
search_space_result = await session.execute(
|
||||
|
|
@ -1329,6 +1318,15 @@ async def regenerate_response(
|
|||
for msg in messages_to_delete:
|
||||
await session.delete(msg)
|
||||
await session.commit()
|
||||
|
||||
# Delete any public snapshots that contain the modified messages
|
||||
from app.services.public_chat_service import (
|
||||
delete_affected_snapshots,
|
||||
)
|
||||
|
||||
await delete_affected_snapshots(
|
||||
session, thread_id, message_ids_to_delete
|
||||
)
|
||||
except Exception as cleanup_error:
|
||||
# Log but don't fail - the new messages are already streamed
|
||||
print(
|
||||
|
|
|
|||
|
|
@ -407,11 +407,24 @@ async def refresh_notion_token(
|
|||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
error_code = ""
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get("error_description", error_detail)
|
||||
error_code = error_json.get("error", "")
|
||||
except Exception:
|
||||
pass
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = (error_detail + error_code).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Notion authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.schemas import PodcastRead
|
||||
from app.users import current_active_user, current_optional_user
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -82,17 +82,14 @@ async def read_podcasts(
|
|||
async def read_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User | None = Depends(current_optional_user),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a specific podcast by ID.
|
||||
|
||||
Access is allowed if:
|
||||
- User is authenticated with PODCASTS_READ permission, OR
|
||||
- Podcast belongs to a publicly shared thread
|
||||
Requires authentication with PODCASTS_READ permission.
|
||||
For public podcast access, use /public/{share_token}/podcasts/{podcast_id}/stream
|
||||
"""
|
||||
from app.services.public_chat_service import is_podcast_publicly_accessible
|
||||
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
podcast = result.scalars().first()
|
||||
|
|
@ -103,18 +100,13 @@ async def read_podcast(
|
|||
detail="Podcast not found",
|
||||
)
|
||||
|
||||
is_public = await is_podcast_publicly_accessible(session, podcast_id)
|
||||
|
||||
if not is_public:
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to read podcasts in this search space",
|
||||
)
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to read podcasts in this search space",
|
||||
)
|
||||
|
||||
return PodcastRead.from_orm_with_entries(podcast)
|
||||
except HTTPException as he:
|
||||
|
|
@ -168,19 +160,16 @@ async def delete_podcast(
|
|||
async def stream_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User | None = Depends(current_optional_user),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Stream a podcast audio file.
|
||||
|
||||
Access is allowed if:
|
||||
- User is authenticated with PODCASTS_READ permission, OR
|
||||
- Podcast belongs to a publicly shared thread
|
||||
Requires authentication with PODCASTS_READ permission.
|
||||
For public podcast access, use /public/{share_token}/podcasts/{podcast_id}/stream
|
||||
|
||||
Note: Both /stream and /audio endpoints are supported for compatibility.
|
||||
"""
|
||||
from app.services.public_chat_service import is_podcast_publicly_accessible
|
||||
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
podcast = result.scalars().first()
|
||||
|
|
@ -188,19 +177,13 @@ async def stream_podcast(
|
|||
if not podcast:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
|
||||
is_public = await is_podcast_publicly_accessible(session, podcast_id)
|
||||
|
||||
if not is_public:
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to access podcasts in this search space",
|
||||
)
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to access podcasts in this search space",
|
||||
)
|
||||
|
||||
file_path = podcast.file_location
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,25 @@
|
|||
"""
|
||||
Routes for public chat access (unauthenticated and mixed-auth endpoints).
|
||||
Routes for public chat access via immutable snapshots.
|
||||
|
||||
All public endpoints use share_token for access - no authentication required
|
||||
for read operations. Clone requires authentication.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import ChatVisibility, NewChatThread, User, get_async_session
|
||||
from app.db import User, get_async_session
|
||||
from app.schemas.new_chat import (
|
||||
CloneInitResponse,
|
||||
CloneResponse,
|
||||
PublicChatResponse,
|
||||
)
|
||||
from app.services.public_chat_service import (
|
||||
clone_from_snapshot,
|
||||
get_public_chat,
|
||||
get_thread_by_share_token,
|
||||
get_user_default_search_space,
|
||||
get_snapshot_podcast,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
|
||||
|
|
@ -28,57 +32,85 @@ async def read_public_chat(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Get a public chat by share token.
|
||||
Get a public chat snapshot by share token.
|
||||
|
||||
No authentication required.
|
||||
Returns sanitized content (citations stripped).
|
||||
Returns immutable snapshot data (sanitized, citations stripped).
|
||||
"""
|
||||
return await get_public_chat(session, share_token)
|
||||
|
||||
|
||||
@router.post("/{share_token}/clone", response_model=CloneInitResponse)
|
||||
async def clone_public_chat_endpoint(
|
||||
@router.post("/{share_token}/clone", response_model=CloneResponse)
|
||||
async def clone_public_chat(
|
||||
share_token: str,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Initialize cloning a public chat to the user's account.
|
||||
|
||||
Creates an empty thread with clone_pending=True.
|
||||
Frontend should redirect to the new thread and call /complete-clone.
|
||||
Clone a public chat snapshot to the user's account.
|
||||
|
||||
Creates thread and copies messages.
|
||||
Requires authentication.
|
||||
"""
|
||||
source_thread = await get_thread_by_share_token(session, share_token)
|
||||
return await clone_from_snapshot(session, share_token, user)
|
||||
|
||||
if not source_thread:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat not found or no longer public"
|
||||
)
|
||||
|
||||
target_search_space_id = await get_user_default_search_space(session, user.id)
|
||||
@router.get("/{share_token}/podcasts/{podcast_id}")
|
||||
async def get_public_podcast(
|
||||
share_token: str,
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Get podcast details from a public chat snapshot.
|
||||
|
||||
if target_search_space_id is None:
|
||||
raise HTTPException(status_code=400, detail="No search space found for user")
|
||||
No authentication required - the share_token provides access.
|
||||
Returns podcast info including transcript.
|
||||
"""
|
||||
podcast_info = await get_snapshot_podcast(session, share_token, podcast_id)
|
||||
|
||||
new_thread = NewChatThread(
|
||||
title=source_thread.title,
|
||||
archived=False,
|
||||
visibility=ChatVisibility.PRIVATE,
|
||||
search_space_id=target_search_space_id,
|
||||
created_by_id=user.id,
|
||||
public_share_enabled=False,
|
||||
cloned_from_thread_id=source_thread.id,
|
||||
cloned_at=datetime.now(UTC),
|
||||
clone_pending=True,
|
||||
)
|
||||
session.add(new_thread)
|
||||
await session.commit()
|
||||
await session.refresh(new_thread)
|
||||
|
||||
return CloneInitResponse(
|
||||
thread_id=new_thread.id,
|
||||
search_space_id=target_search_space_id,
|
||||
share_token=share_token,
|
||||
if not podcast_info:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
|
||||
return {
|
||||
"id": podcast_info.get("original_id"),
|
||||
"title": podcast_info.get("title"),
|
||||
"status": "ready",
|
||||
"podcast_transcript": podcast_info.get("transcript"),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{share_token}/podcasts/{podcast_id}/stream")
|
||||
async def stream_public_podcast(
|
||||
share_token: str,
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Stream a podcast from a public chat snapshot.
|
||||
|
||||
No authentication required - the share_token provides access.
|
||||
Looks up podcast by original_id in the snapshot's podcasts array.
|
||||
"""
|
||||
podcast_info = await get_snapshot_podcast(session, share_token, podcast_id)
|
||||
|
||||
if not podcast_info:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
|
||||
file_path = podcast_info.get("file_path")
|
||||
|
||||
if not file_path or not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Podcast audio file not found")
|
||||
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
|
||||
return StreamingResponse(
|
||||
iterfile(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Disposition": f"inline; filename={os.path.basename(file_path)}",
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,12 @@ Non-OAuth connectors (BookStack, GitHub, etc.) are limited to one per search spa
|
|||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import pytz
|
||||
import redis
|
||||
from dateutil.parser import isoparse
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
|
@ -78,6 +80,27 @@ from app.utils.rbac import check_permission
|
|||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis client for heartbeat tracking
|
||||
_heartbeat_redis_client: redis.Redis | None = None
|
||||
|
||||
# Redis key TTL - notification is stale if no heartbeat in this time
|
||||
HEARTBEAT_TTL_SECONDS = 120 # 2 minutes
|
||||
|
||||
|
||||
def get_heartbeat_redis_client() -> redis.Redis:
|
||||
"""Get or create Redis client for heartbeat tracking."""
|
||||
global _heartbeat_redis_client
|
||||
if _heartbeat_redis_client is None:
|
||||
redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
_heartbeat_redis_client = redis.from_url(redis_url, decode_responses=True)
|
||||
return _heartbeat_redis_client
|
||||
|
||||
|
||||
def _get_heartbeat_key(notification_id: int) -> str:
|
||||
"""Generate Redis key for notification heartbeat."""
|
||||
return f"indexing:heartbeat:{notification_id}"
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
|
@ -1137,6 +1160,7 @@ async def run_slack_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_slack_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1150,6 +1174,7 @@ async def _run_indexing_with_notifications(
|
|||
indexing_function,
|
||||
update_timestamp_func=None,
|
||||
supports_retry_callback: bool = False,
|
||||
supports_heartbeat_callback: bool = False,
|
||||
):
|
||||
"""
|
||||
Generic helper to run indexing with real-time notifications.
|
||||
|
|
@ -1164,11 +1189,14 @@ async def _run_indexing_with_notifications(
|
|||
indexing_function: Async function that performs the indexing
|
||||
update_timestamp_func: Optional function to update connector timestamp
|
||||
supports_retry_callback: Whether the indexing function supports on_retry_callback
|
||||
supports_heartbeat_callback: Whether the indexing function supports on_heartbeat_callback
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
|
||||
notification = None
|
||||
# Track indexed count for retry notifications
|
||||
# Track indexed count for retry notifications and heartbeat
|
||||
current_indexed_count = 0
|
||||
|
||||
try:
|
||||
|
|
@ -1195,6 +1223,16 @@ async def _run_indexing_with_notifications(
|
|||
)
|
||||
)
|
||||
|
||||
# Set initial Redis heartbeat for stale detection
|
||||
if notification:
|
||||
try:
|
||||
heartbeat_key = _get_heartbeat_key(notification.id)
|
||||
get_heartbeat_redis_client().setex(
|
||||
heartbeat_key, HEARTBEAT_TTL_SECONDS, "0"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to set initial Redis heartbeat: {e}")
|
||||
|
||||
# Update notification to fetching stage
|
||||
if notification:
|
||||
await NotificationService.connector_indexing.notify_indexing_progress(
|
||||
|
|
@ -1227,6 +1265,40 @@ async def _run_indexing_with_notifications(
|
|||
# Don't let notification errors break the indexing
|
||||
logger.warning(f"Failed to update retry notification: {e}")
|
||||
|
||||
# Create heartbeat callback for connectors that support it
|
||||
# This updates the notification periodically during long-running indexing loops
|
||||
# to prevent the task from appearing stuck if the worker crashes
|
||||
async def on_heartbeat_callback(indexed_count: int) -> None:
|
||||
"""Callback to update notification during indexing (heartbeat)."""
|
||||
nonlocal notification, current_indexed_count
|
||||
current_indexed_count = indexed_count
|
||||
if notification:
|
||||
try:
|
||||
# Set Redis heartbeat key with TTL (fast, for stale detection)
|
||||
heartbeat_key = _get_heartbeat_key(notification.id)
|
||||
get_heartbeat_redis_client().setex(
|
||||
heartbeat_key, HEARTBEAT_TTL_SECONDS, str(indexed_count)
|
||||
)
|
||||
except Exception as e:
|
||||
# Don't let Redis errors break the indexing
|
||||
logger.warning(f"Failed to set Redis heartbeat: {e}")
|
||||
|
||||
try:
|
||||
# Still update DB notification for progress display
|
||||
await session.refresh(notification)
|
||||
await (
|
||||
NotificationService.connector_indexing.notify_indexing_progress(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=indexed_count,
|
||||
stage="processing",
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
# Don't let notification errors break the indexing
|
||||
logger.warning(f"Failed to update heartbeat notification: {e}")
|
||||
|
||||
# Build kwargs for indexing function
|
||||
indexing_kwargs = {
|
||||
"session": session,
|
||||
|
|
@ -1242,6 +1314,10 @@ async def _run_indexing_with_notifications(
|
|||
if supports_retry_callback:
|
||||
indexing_kwargs["on_retry_callback"] = on_retry_callback
|
||||
|
||||
# Add heartbeat callback for connectors that support it
|
||||
if supports_heartbeat_callback:
|
||||
indexing_kwargs["on_heartbeat_callback"] = on_heartbeat_callback
|
||||
|
||||
# Run the indexing function
|
||||
# Some indexers return (indexed, error), others return (indexed, skipped, error)
|
||||
result = await indexing_function(**indexing_kwargs)
|
||||
|
|
@ -1398,6 +1474,32 @@ async def _run_indexing_with_notifications(
|
|||
await (
|
||||
session.commit()
|
||||
) # Commit to ensure Electric SQL syncs the notification update
|
||||
except SoftTimeLimitExceeded:
|
||||
# Celery soft time limit was reached - task is about to be killed
|
||||
# Gracefully save progress and mark as interrupted
|
||||
logger.warning(
|
||||
f"Soft time limit reached for connector {connector_id}. "
|
||||
f"Saving partial progress: {current_indexed_count} items indexed."
|
||||
)
|
||||
|
||||
if notification:
|
||||
try:
|
||||
await session.refresh(notification)
|
||||
await NotificationService.connector_indexing.notify_indexing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=current_indexed_count,
|
||||
error_message="Time limit reached. Partial sync completed. Please run again for remaining items.",
|
||||
is_warning=True, # Mark as warning since partial data was indexed
|
||||
)
|
||||
await session.commit()
|
||||
except Exception as notif_error:
|
||||
logger.error(
|
||||
f"Failed to update notification on soft timeout: {notif_error!s}"
|
||||
)
|
||||
|
||||
# Re-raise so Celery knows the task was terminated
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in indexing task: {e!s}", exc_info=True)
|
||||
|
||||
|
|
@ -1409,12 +1511,20 @@ async def _run_indexing_with_notifications(
|
|||
await NotificationService.connector_indexing.notify_indexing_completed(
|
||||
session=session,
|
||||
notification=notification,
|
||||
indexed_count=0,
|
||||
indexed_count=current_indexed_count, # Use tracked count, not 0
|
||||
error_message=str(e),
|
||||
skipped_count=None, # Unknown on exception
|
||||
)
|
||||
except Exception as notif_error:
|
||||
logger.error(f"Failed to update notification: {notif_error!s}")
|
||||
finally:
|
||||
# Clean up Redis heartbeat key when task completes (success or failure)
|
||||
if notification:
|
||||
try:
|
||||
heartbeat_key = _get_heartbeat_key(notification.id)
|
||||
get_heartbeat_redis_client().delete(heartbeat_key)
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors - key will expire anyway
|
||||
|
||||
|
||||
async def run_notion_indexing_with_new_session(
|
||||
|
|
@ -1439,6 +1549,7 @@ async def run_notion_indexing_with_new_session(
|
|||
indexing_function=index_notion_pages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_retry_callback=True, # Notion connector supports retry notifications
|
||||
supports_heartbeat_callback=True, # Notion connector supports heartbeat notifications
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1471,6 +1582,7 @@ async def run_notion_indexing(
|
|||
indexing_function=index_notion_pages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_retry_callback=True, # Notion connector supports retry notifications
|
||||
supports_heartbeat_callback=True, # Notion connector supports heartbeat notifications
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1521,6 +1633,7 @@ async def run_github_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_github_repos,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1571,6 +1684,7 @@ async def run_linear_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_linear_issues,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1620,6 +1734,7 @@ async def run_discord_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_discord_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1670,6 +1785,7 @@ async def run_teams_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_teams_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1720,6 +1836,7 @@ async def run_jira_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_jira_issues,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1772,6 +1889,7 @@ async def run_confluence_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_confluence_pages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1822,6 +1940,7 @@ async def run_clickup_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_clickup_tasks,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1872,6 +1991,7 @@ async def run_airtable_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_airtable_records,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1924,6 +2044,7 @@ async def run_google_calendar_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_google_calendar_events,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1998,6 +2119,7 @@ async def run_google_gmail_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=gmail_indexing_wrapper,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2206,6 +2328,7 @@ async def run_luma_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_luma_events,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2257,6 +2380,7 @@ async def run_elasticsearch_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_elasticsearch_documents,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2306,6 +2430,7 @@ async def run_web_page_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_crawled_urls,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2360,6 +2485,7 @@ async def run_bookstack_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_bookstack_pages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2412,6 +2538,7 @@ async def run_obsidian_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_obsidian_vault,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2465,6 +2592,7 @@ async def run_composio_indexing(
|
|||
end_date=end_date,
|
||||
indexing_function=index_composio_connector,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Handles OAuth 2.0 authentication flow for Slack connector.
|
|||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
|
|
@ -14,6 +15,7 @@ from fastapi.responses import RedirectResponse
|
|||
from pydantic import ValidationError
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
|
|
@ -418,6 +420,19 @@ async def refresh_slack_token(
|
|||
error_detail = error_json.get("error", error_detail)
|
||||
except Exception:
|
||||
pass
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = error_detail.lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "invalid_auth" in error_lower
|
||||
or "token_revoked" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Slack authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
|
@ -427,6 +442,20 @@ async def refresh_slack_token(
|
|||
# Slack OAuth v2 returns success status in the JSON
|
||||
if not token_json.get("ok", False):
|
||||
error_msg = token_json.get("error", "Unknown error")
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = error_msg.lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "invalid_auth" in error_lower
|
||||
or "invalid_refresh_token" in error_lower
|
||||
or "token_revoked" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Slack authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Slack OAuth refresh error: {error_msg}"
|
||||
)
|
||||
|
|
@ -490,3 +519,88 @@ async def refresh_slack_token(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to refresh Slack token: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/slack/connector/{connector_id}/channels")
|
||||
async def get_slack_channels(
|
||||
connector_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get list of Slack channels with bot membership status.
|
||||
|
||||
This endpoint fetches all channels the bot can see and indicates
|
||||
whether the bot is a member of each channel (required for accessing messages).
|
||||
|
||||
Args:
|
||||
connector_id: The Slack connector ID
|
||||
session: Database session
|
||||
user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
List of channels with id, name, is_private, and is_member fields
|
||||
"""
|
||||
try:
|
||||
# Get the connector and verify ownership
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.SLACK_CONNECTOR,
|
||||
)
|
||||
)
|
||||
connector = result.scalar_one_or_none()
|
||||
|
||||
if not connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Slack connector not found or access denied",
|
||||
)
|
||||
|
||||
# Get credentials and decrypt bot token
|
||||
credentials = SlackAuthCredentialsBase.from_dict(connector.config)
|
||||
token_encryption = get_token_encryption()
|
||||
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
bot_token = credentials.bot_token
|
||||
if is_encrypted and bot_token:
|
||||
try:
|
||||
bot_token = token_encryption.decrypt_token(bot_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt bot token: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to decrypt stored bot token"
|
||||
) from e
|
||||
|
||||
if not bot_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No bot token available. Please re-authenticate.",
|
||||
)
|
||||
|
||||
# Import SlackHistory here to avoid circular imports
|
||||
from app.connectors.slack_history import SlackHistory
|
||||
|
||||
# Create Slack client with direct token (simple pattern for quick operations)
|
||||
slack_client = SlackHistory(token=bot_token)
|
||||
|
||||
channels = await slack_client.get_all_channels(include_private=True)
|
||||
|
||||
logger.info(
|
||||
f"Fetched {len(channels)} channels for Slack connector {connector_id}"
|
||||
)
|
||||
|
||||
return channels
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get Slack channels for connector {connector_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to get Slack channels: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -420,11 +420,24 @@ async def refresh_teams_token(
|
|||
|
||||
if token_response.status_code != 200:
|
||||
error_detail = token_response.text
|
||||
error_code = ""
|
||||
try:
|
||||
error_json = token_response.json()
|
||||
error_detail = error_json.get("error_description", error_detail)
|
||||
error_code = error_json.get("error", "")
|
||||
except Exception:
|
||||
pass
|
||||
# Check if this is a token expiration/revocation error
|
||||
error_lower = (error_detail + error_code).lower()
|
||||
if (
|
||||
"invalid_grant" in error_lower
|
||||
or "expired" in error_lower
|
||||
or "revoked" in error_lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Microsoft Teams authentication failed. Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue