mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
feat: SurfSense v0.0.6 init
This commit is contained in:
parent
18fc19e8d9
commit
da23012970
58 changed files with 8284 additions and 2076 deletions
0
surfsense_backend/app/tasks/__init__.py
Normal file
0
surfsense_backend/app/tasks/__init__.py
Normal file
246
surfsense_backend/app/tasks/background_tasks.py
Normal file
246
surfsense_backend/app/tasks/background_tasks.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
from typing import Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from app.db import Document, DocumentType, Chunk
|
||||
from app.schemas import ExtensionDocumentContent
|
||||
from app.config import config
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from datetime import datetime
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
|
||||
from langchain_community.document_transformers import MarkdownifyTransformer
|
||||
import validators
|
||||
|
||||
md = MarkdownifyTransformer()
|
||||
|
||||
|
||||
async def add_crawled_url_document(
|
||||
session: AsyncSession,
|
||||
url: str,
|
||||
search_space_id: int
|
||||
) -> Optional[Document]:
|
||||
try:
|
||||
|
||||
if not validators.url(url):
|
||||
raise ValueError(f"Url {url} is not a valid URL address")
|
||||
|
||||
if config.FIRECRAWL_API_KEY:
|
||||
crawl_loader = FireCrawlLoader(
|
||||
url=url,
|
||||
api_key=config.FIRECRAWL_API_KEY,
|
||||
mode="scrape",
|
||||
params={
|
||||
"formats": ["markdown"],
|
||||
"excludeTags": ["a"],
|
||||
}
|
||||
)
|
||||
else:
|
||||
crawl_loader = AsyncChromiumLoader(urls=[url], headless=True)
|
||||
|
||||
url_crawled = await crawl_loader.aload()
|
||||
|
||||
if type(crawl_loader) == FireCrawlLoader:
|
||||
content_in_markdown = url_crawled[0].page_content
|
||||
elif type(crawl_loader) == AsyncChromiumLoader:
|
||||
content_in_markdown = md.transform_documents(url_crawled)[
|
||||
0].page_content
|
||||
|
||||
# Format document metadata in a more maintainable way
|
||||
metadata_sections = [
|
||||
("METADATA", [
|
||||
f"{key.upper()}: {value}" for key, value in url_crawled[0].metadata.items()
|
||||
]),
|
||||
("CONTENT", [
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
content_in_markdown,
|
||||
"TEXT_END"
|
||||
])
|
||||
]
|
||||
|
||||
# Build the document string more efficiently
|
||||
document_parts = []
|
||||
document_parts.append("<DOCUMENT>")
|
||||
|
||||
for section_title, section_content in metadata_sections:
|
||||
document_parts.append(f"<{section_title}>")
|
||||
document_parts.extend(section_content)
|
||||
document_parts.append(f"</{section_title}>")
|
||||
|
||||
document_parts.append("</DOCUMENT>")
|
||||
combined_document_string = '\n'.join(document_parts)
|
||||
|
||||
# Generate summary
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(content_in_markdown)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=url_crawled[0].metadata['title'] if type(
|
||||
crawl_loader) == FireCrawlLoader else url_crawled[0].metadata['source'],
|
||||
document_type=DocumentType.CRAWLED_URL,
|
||||
document_metadata=url_crawled[0].metadata,
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
return document
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to crawl URL: {str(e)}")
|
||||
|
||||
|
||||
async def add_extension_received_document(
|
||||
session: AsyncSession,
|
||||
content: ExtensionDocumentContent,
|
||||
search_space_id: int
|
||||
) -> Optional[Document]:
|
||||
"""
|
||||
Process and store document content received from the SurfSense Extension.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
content: Document content from extension
|
||||
search_space_id: ID of the search space
|
||||
|
||||
Returns:
|
||||
Document object if successful, None if failed
|
||||
"""
|
||||
try:
|
||||
# Format document metadata in a more maintainable way
|
||||
metadata_sections = [
|
||||
("METADATA", [
|
||||
f"SESSION_ID: {content.metadata.BrowsingSessionId}",
|
||||
f"URL: {content.metadata.VisitedWebPageURL}",
|
||||
f"TITLE: {content.metadata.VisitedWebPageTitle}",
|
||||
f"REFERRER: {content.metadata.VisitedWebPageReffererURL}",
|
||||
f"TIMESTAMP: {content.metadata.VisitedWebPageDateWithTimeInISOString}",
|
||||
f"DURATION_MS: {content.metadata.VisitedWebPageVisitDurationInMilliseconds}"
|
||||
]),
|
||||
("CONTENT", [
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
content.pageContent,
|
||||
"TEXT_END"
|
||||
])
|
||||
]
|
||||
|
||||
# Build the document string more efficiently
|
||||
document_parts = []
|
||||
document_parts.append("<DOCUMENT>")
|
||||
|
||||
for section_title, section_content in metadata_sections:
|
||||
document_parts.append(f"<{section_title}>")
|
||||
document_parts.extend(section_content)
|
||||
document_parts.append(f"</{section_title}>")
|
||||
|
||||
document_parts.append("</DOCUMENT>")
|
||||
combined_document_string = '\n'.join(document_parts)
|
||||
|
||||
# Generate summary
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(content.pageContent)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=content.metadata.VisitedWebPageTitle,
|
||||
document_type=DocumentType.EXTENSION,
|
||||
document_metadata=content.metadata.model_dump(),
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
return document
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to process extension document: {str(e)}")
|
||||
|
||||
|
||||
async def add_received_file_document(
|
||||
session: AsyncSession,
|
||||
file_name: str,
|
||||
unstructured_processed_elements: List[LangChainDocument],
|
||||
search_space_id: int
|
||||
) -> Optional[Document]:
|
||||
try:
|
||||
file_in_markdown = await convert_document_to_markdown(unstructured_processed_elements)
|
||||
|
||||
# TODO: Check if file_markdown exceeds token limit of embedding model
|
||||
|
||||
# Generate summary
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(
|
||||
summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(file_in_markdown)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=file_name,
|
||||
document_type=DocumentType.FILE,
|
||||
document_metadata={
|
||||
"FILE_NAME": file_name,
|
||||
"SAVED_AT": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
},
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
await session.commit()
|
||||
await session.refresh(document)
|
||||
|
||||
return document
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
raise db_error
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise RuntimeError(f"Failed to process file document: {str(e)}")
|
||||
486
surfsense_backend/app/tasks/connectors_indexing_tasks.py
Normal file
486
surfsense_backend/app/tasks/connectors_indexing_tasks.py
Normal file
|
|
@ -0,0 +1,486 @@
|
|||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.future import select
|
||||
from datetime import datetime, timedelta
|
||||
from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.config import config
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from app.connectors.slack_history import SlackHistory
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from slack_sdk.errors import SlackApiError
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def index_slack_messages(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
update_last_indexed: bool = True
|
||||
) -> Tuple[int, Optional[str]]:
|
||||
"""
|
||||
Index Slack messages from all accessible channels.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Slack connector
|
||||
search_space_id: ID of the search space to store documents in
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
"""
|
||||
try:
|
||||
# Get the connector
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
return 0, f"Connector with ID {connector_id} not found or is not a Slack connector"
|
||||
|
||||
# Get the Slack token from the connector config
|
||||
slack_token = connector.config.get("SLACK_BOT_TOKEN")
|
||||
if not slack_token:
|
||||
return 0, "Slack token not found in connector config"
|
||||
|
||||
# Initialize Slack client
|
||||
slack_client = SlackHistory(token=slack_token)
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.now()
|
||||
|
||||
# Use last_indexed_at as start date if available, otherwise use 365 days ago
|
||||
if connector.last_indexed_at:
|
||||
# Check if last_indexed_at is today
|
||||
today = datetime.now().date()
|
||||
if connector.last_indexed_at.date() == today:
|
||||
# If last indexed today, go back 1 day to ensure we don't miss anything
|
||||
start_date = end_date - timedelta(days=7)
|
||||
else:
|
||||
start_date = connector.last_indexed_at
|
||||
else:
|
||||
start_date = end_date - timedelta(days=365)
|
||||
|
||||
# Format dates for Slack API
|
||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Get all channels
|
||||
try:
|
||||
channels = slack_client.get_all_channels()
|
||||
except Exception as e:
|
||||
return 0, f"Failed to get Slack channels: {str(e)}"
|
||||
|
||||
if not channels:
|
||||
return 0, "No Slack channels found"
|
||||
|
||||
# Track the number of documents indexed
|
||||
documents_indexed = 0
|
||||
skipped_channels = []
|
||||
|
||||
# Process each channel
|
||||
for channel_name, channel_id in channels.items():
|
||||
try:
|
||||
# Check if the bot is a member of the channel
|
||||
try:
|
||||
# First try to get channel info to check if bot is a member
|
||||
channel_info = slack_client.client.conversations_info(channel=channel_id)
|
||||
|
||||
# For private channels, the bot needs to be a member
|
||||
if channel_info.get("channel", {}).get("is_private", False):
|
||||
# Check if bot is a member
|
||||
is_member = channel_info.get("channel", {}).get("is_member", False)
|
||||
if not is_member:
|
||||
logger.warning(f"Bot is not a member of private channel {channel_name} ({channel_id}). Skipping.")
|
||||
skipped_channels.append(f"{channel_name} (private, bot not a member)")
|
||||
continue
|
||||
except SlackApiError as e:
|
||||
if "not_in_channel" in str(e) or "channel_not_found" in str(e):
|
||||
logger.warning(f"Bot cannot access channel {channel_name} ({channel_id}). Skipping.")
|
||||
skipped_channels.append(f"{channel_name} (access error)")
|
||||
continue
|
||||
else:
|
||||
# Re-raise if it's a different error
|
||||
raise
|
||||
|
||||
# Get messages for this channel
|
||||
messages, error = slack_client.get_history_by_date_range(
|
||||
channel_id=channel_id,
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
limit=1000 # Limit to 1000 messages per channel
|
||||
)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting messages from channel {channel_name}: {error}")
|
||||
skipped_channels.append(f"{channel_name} (error: {error})")
|
||||
continue # Skip this channel if there's an error
|
||||
|
||||
if not messages:
|
||||
logger.info(f"No messages found in channel {channel_name} for the specified date range.")
|
||||
continue # Skip if no messages
|
||||
|
||||
# Format messages with user info
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
# Skip bot messages and system messages
|
||||
if msg.get("subtype") in ["bot_message", "channel_join", "channel_leave"]:
|
||||
continue
|
||||
|
||||
formatted_msg = slack_client.format_message(msg, include_user_info=True)
|
||||
formatted_messages.append(formatted_msg)
|
||||
|
||||
if not formatted_messages:
|
||||
logger.info(f"No valid messages found in channel {channel_name} after filtering.")
|
||||
continue # Skip if no valid messages after filtering
|
||||
|
||||
# Convert messages to markdown format
|
||||
channel_content = f"# Slack Channel: {channel_name}\n\n"
|
||||
|
||||
for msg in formatted_messages:
|
||||
user_name = msg.get("user_name", "Unknown User")
|
||||
timestamp = msg.get("datetime", "Unknown Time")
|
||||
text = msg.get("text", "")
|
||||
|
||||
channel_content += f"## {user_name} ({timestamp})\n\n{text}\n\n---\n\n"
|
||||
|
||||
# Format document metadata
|
||||
metadata_sections = [
|
||||
("METADATA", [
|
||||
f"CHANNEL_NAME: {channel_name}",
|
||||
f"CHANNEL_ID: {channel_id}",
|
||||
f"START_DATE: {start_date_str}",
|
||||
f"END_DATE: {end_date_str}",
|
||||
f"MESSAGE_COUNT: {len(formatted_messages)}",
|
||||
f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
]),
|
||||
("CONTENT", [
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
channel_content,
|
||||
"TEXT_END"
|
||||
])
|
||||
]
|
||||
|
||||
# Build the document string
|
||||
document_parts = []
|
||||
document_parts.append("<DOCUMENT>")
|
||||
|
||||
for section_title, section_content in metadata_sections:
|
||||
document_parts.append(f"<{section_title}>")
|
||||
document_parts.extend(section_content)
|
||||
document_parts.append(f"</{section_title}>")
|
||||
|
||||
document_parts.append("</DOCUMENT>")
|
||||
combined_document_string = '\n'.join(document_parts)
|
||||
|
||||
# Generate summary
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
||||
|
||||
# Process chunks
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(channel_content)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=f"Slack - {channel_name}",
|
||||
document_type=DocumentType.SLACK_CONNECTOR,
|
||||
document_metadata={
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
"start_date": start_date_str,
|
||||
"end_date": end_date_str,
|
||||
"message_count": len(formatted_messages),
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
},
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
documents_indexed += 1
|
||||
logger.info(f"Successfully indexed channel {channel_name} with {len(formatted_messages)} messages")
|
||||
|
||||
except SlackApiError as slack_error:
|
||||
logger.error(f"Slack API error for channel {channel_name}: {str(slack_error)}")
|
||||
skipped_channels.append(f"{channel_name} (Slack API error)")
|
||||
continue # Skip this channel and continue with others
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing channel {channel_name}: {str(e)}")
|
||||
skipped_channels.append(f"{channel_name} (processing error)")
|
||||
continue # Skip this channel and continue with others
|
||||
|
||||
# Update the last_indexed_at timestamp for the connector only if requested
|
||||
# and if we successfully indexed at least one channel
|
||||
if update_last_indexed and documents_indexed > 0:
|
||||
connector.last_indexed_at = datetime.now()
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
|
||||
# Prepare result message
|
||||
result_message = None
|
||||
if skipped_channels:
|
||||
result_message = f"Indexed {documents_indexed} channels. Skipped {len(skipped_channels)} channels: {', '.join(skipped_channels)}"
|
||||
|
||||
return documents_indexed, result_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
logger.error(f"Database error: {str(db_error)}")
|
||||
return 0, f"Database error: {str(db_error)}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to index Slack messages: {str(e)}")
|
||||
return 0, f"Failed to index Slack messages: {str(e)}"
|
||||
|
||||
async def index_notion_pages(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
update_last_indexed: bool = True
|
||||
) -> Tuple[int, Optional[str]]:
|
||||
"""
|
||||
Index Notion pages from all accessible pages.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Notion connector
|
||||
search_space_id: ID of the search space to store documents in
|
||||
update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)
|
||||
|
||||
Returns:
|
||||
Tuple containing (number of documents indexed, error message or None)
|
||||
"""
|
||||
try:
|
||||
# Get the connector
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector)
|
||||
.filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
|
||||
if not connector:
|
||||
return 0, f"Connector with ID {connector_id} not found or is not a Notion connector"
|
||||
|
||||
# Get the Notion token from the connector config
|
||||
notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN")
|
||||
if not notion_token:
|
||||
return 0, "Notion integration token not found in connector config"
|
||||
|
||||
# Initialize Notion client
|
||||
logger.info(f"Initializing Notion client for connector {connector_id}")
|
||||
notion_client = NotionHistoryConnector(token=notion_token)
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.now()
|
||||
|
||||
# Use last_indexed_at as start date if available, otherwise use 365 days ago
|
||||
if connector.last_indexed_at:
|
||||
# Check if last_indexed_at is today
|
||||
today = datetime.now().date()
|
||||
if connector.last_indexed_at.date() == today:
|
||||
# If last indexed today, go back 1 day to ensure we don't miss anything
|
||||
start_date = end_date - timedelta(days=1)
|
||||
else:
|
||||
start_date = connector.last_indexed_at
|
||||
else:
|
||||
start_date = end_date - timedelta(days=365)
|
||||
|
||||
# Format dates for Notion API (ISO format)
|
||||
start_date_str = start_date.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
end_date_str = end_date.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
logger.info(f"Fetching Notion pages from {start_date_str} to {end_date_str}")
|
||||
|
||||
# Get all pages
|
||||
try:
|
||||
pages = notion_client.get_all_pages(start_date=start_date_str, end_date=end_date_str)
|
||||
logger.info(f"Found {len(pages)} Notion pages")
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Notion pages: {str(e)}", exc_info=True)
|
||||
return 0, f"Failed to get Notion pages: {str(e)}"
|
||||
|
||||
if not pages:
|
||||
logger.info("No Notion pages found to index")
|
||||
return 0, "No Notion pages found"
|
||||
|
||||
# Track the number of documents indexed
|
||||
documents_indexed = 0
|
||||
skipped_pages = []
|
||||
|
||||
# Process each page
|
||||
for page in pages:
|
||||
try:
|
||||
page_id = page.get("page_id")
|
||||
page_title = page.get("title", f"Untitled page ({page_id})")
|
||||
page_content = page.get("content", [])
|
||||
|
||||
logger.info(f"Processing Notion page: {page_title} ({page_id})")
|
||||
|
||||
if not page_content:
|
||||
logger.info(f"No content found in page {page_title}. Skipping.")
|
||||
skipped_pages.append(f"{page_title} (no content)")
|
||||
continue
|
||||
|
||||
# Convert page content to markdown format
|
||||
markdown_content = f"# Notion Page: {page_title}\n\n"
|
||||
|
||||
# Process blocks recursively
|
||||
def process_blocks(blocks, level=0):
|
||||
result = ""
|
||||
for block in blocks:
|
||||
block_type = block.get("type")
|
||||
block_content = block.get("content", "")
|
||||
children = block.get("children", [])
|
||||
|
||||
# Add indentation based on level
|
||||
indent = " " * level
|
||||
|
||||
# Format based on block type
|
||||
if block_type in ["paragraph", "text"]:
|
||||
result += f"{indent}{block_content}\n\n"
|
||||
elif block_type in ["heading_1", "header"]:
|
||||
result += f"{indent}# {block_content}\n\n"
|
||||
elif block_type == "heading_2":
|
||||
result += f"{indent}## {block_content}\n\n"
|
||||
elif block_type == "heading_3":
|
||||
result += f"{indent}### {block_content}\n\n"
|
||||
elif block_type == "bulleted_list_item":
|
||||
result += f"{indent}* {block_content}\n"
|
||||
elif block_type == "numbered_list_item":
|
||||
result += f"{indent}1. {block_content}\n"
|
||||
elif block_type == "to_do":
|
||||
result += f"{indent}- [ ] {block_content}\n"
|
||||
elif block_type == "toggle":
|
||||
result += f"{indent}> {block_content}\n"
|
||||
elif block_type == "code":
|
||||
result += f"{indent}```\n{block_content}\n```\n\n"
|
||||
elif block_type == "quote":
|
||||
result += f"{indent}> {block_content}\n\n"
|
||||
elif block_type == "callout":
|
||||
result += f"{indent}> **Note:** {block_content}\n\n"
|
||||
elif block_type == "image":
|
||||
result += f"{indent}\n\n"
|
||||
else:
|
||||
# Default for other block types
|
||||
if block_content:
|
||||
result += f"{indent}{block_content}\n\n"
|
||||
|
||||
# Process children recursively
|
||||
if children:
|
||||
result += process_blocks(children, level + 1)
|
||||
|
||||
return result
|
||||
|
||||
logger.debug(f"Converting {len(page_content)} blocks to markdown for page {page_title}")
|
||||
markdown_content += process_blocks(page_content)
|
||||
|
||||
# Format document metadata
|
||||
metadata_sections = [
|
||||
("METADATA", [
|
||||
f"PAGE_TITLE: {page_title}",
|
||||
f"PAGE_ID: {page_id}",
|
||||
f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
]),
|
||||
("CONTENT", [
|
||||
"FORMAT: markdown",
|
||||
"TEXT_START",
|
||||
markdown_content,
|
||||
"TEXT_END"
|
||||
])
|
||||
]
|
||||
|
||||
# Build the document string
|
||||
document_parts = []
|
||||
document_parts.append("<DOCUMENT>")
|
||||
|
||||
for section_title, section_content in metadata_sections:
|
||||
document_parts.append(f"<{section_title}>")
|
||||
document_parts.extend(section_content)
|
||||
document_parts.append(f"</{section_title}>")
|
||||
|
||||
document_parts.append("</DOCUMENT>")
|
||||
combined_document_string = '\n'.join(document_parts)
|
||||
|
||||
# Generate summary
|
||||
logger.debug(f"Generating summary for page {page_title}")
|
||||
summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
|
||||
summary_result = await summary_chain.ainvoke({"document": combined_document_string})
|
||||
summary_content = summary_result.content
|
||||
summary_embedding = config.embedding_model_instance.embed(summary_content)
|
||||
|
||||
# Process chunks
|
||||
logger.debug(f"Chunking content for page {page_title}")
|
||||
chunks = [
|
||||
Chunk(content=chunk.text, embedding=chunk.embedding)
|
||||
for chunk in config.chunker_instance.chunk(markdown_content)
|
||||
]
|
||||
|
||||
# Create and store document
|
||||
document = Document(
|
||||
search_space_id=search_space_id,
|
||||
title=f"Notion - {page_title}",
|
||||
document_type=DocumentType.NOTION_CONNECTOR,
|
||||
document_metadata={
|
||||
"page_title": page_title,
|
||||
"page_id": page_id,
|
||||
"indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
},
|
||||
content=summary_content,
|
||||
embedding=summary_embedding,
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
session.add(document)
|
||||
documents_indexed += 1
|
||||
logger.info(f"Successfully indexed Notion page: {page_title}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Notion page {page.get('title', 'Unknown')}: {str(e)}", exc_info=True)
|
||||
skipped_pages.append(f"{page.get('title', 'Unknown')} (processing error)")
|
||||
continue # Skip this page and continue with others
|
||||
|
||||
# Update the last_indexed_at timestamp for the connector only if requested
|
||||
# and if we successfully indexed at least one page
|
||||
if update_last_indexed and documents_indexed > 0:
|
||||
connector.last_indexed_at = datetime.now()
|
||||
logger.info(f"Updated last_indexed_at for connector {connector_id}")
|
||||
|
||||
# Commit all changes
|
||||
await session.commit()
|
||||
|
||||
# Prepare result message
|
||||
result_message = None
|
||||
if skipped_pages:
|
||||
result_message = f"Indexed {documents_indexed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"
|
||||
|
||||
logger.info(f"Notion indexing completed: {documents_indexed} pages indexed, {len(skipped_pages)} pages skipped")
|
||||
return documents_indexed, result_message
|
||||
|
||||
except SQLAlchemyError as db_error:
|
||||
await session.rollback()
|
||||
logger.error(f"Database error during Notion indexing: {str(db_error)}", exc_info=True)
|
||||
return 0, f"Database error: {str(db_error)}"
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to index Notion pages: {str(e)}", exc_info=True)
|
||||
return 0, f"Failed to index Notion pages: {str(e)}"
|
||||
340
surfsense_backend/app/tasks/stream_connector_search_results.py
Normal file
340
surfsense_backend/app/tasks/stream_connector_search_results.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
import json
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, AsyncGenerator, Dict, Any
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from app.utils.connector_service import ConnectorService
|
||||
from app.utils.research_service import ResearchService
|
||||
from app.utils.streaming_service import StreamingService
|
||||
from app.utils.reranker_service import RerankerService
|
||||
from app.config import config
|
||||
from app.utils.document_converters import convert_chunks_to_langchain_documents
|
||||
|
||||
async def stream_connector_search_results(
|
||||
user_query: str,
|
||||
user_id: int,
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
research_mode: str,
|
||||
selected_connectors: List[str]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream connector search results to the client
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
user_id: The user's ID
|
||||
search_space_id: The search space ID
|
||||
session: The database session
|
||||
research_mode: The research mode
|
||||
selected_connectors: List of selected connectors
|
||||
|
||||
Yields:
|
||||
str: Formatted response strings
|
||||
"""
|
||||
# Initialize services
|
||||
connector_service = ConnectorService(session)
|
||||
streaming_service = StreamingService()
|
||||
|
||||
|
||||
reranker_service = RerankerService.get_reranker_instance(config)
|
||||
|
||||
all_raw_documents = [] # Store all raw documents before reranking
|
||||
all_sources = []
|
||||
TOP_K = 20
|
||||
|
||||
if research_mode == "GENERAL":
|
||||
TOP_K = 20
|
||||
elif research_mode == "DEEP":
|
||||
TOP_K = 40
|
||||
elif research_mode == "DEEPER":
|
||||
TOP_K = 60
|
||||
|
||||
|
||||
# Process each selected connector
|
||||
for connector in selected_connectors:
|
||||
# Crawled URLs
|
||||
if connector == "CRAWLED_URL":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search for crawled URLs...")
|
||||
|
||||
# Search for crawled URLs
|
||||
result_object, crawled_urls_chunks = await connector_service.search_crawled_urls(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant crawled URLs",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(crawled_urls_chunks)
|
||||
|
||||
|
||||
# Files
|
||||
if connector == "FILE":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search for files...")
|
||||
|
||||
# Search for files
|
||||
result_object, files_chunks = await connector_service.search_files(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant files",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(files_chunks)
|
||||
|
||||
# Tavily Connector
|
||||
if connector == "TAVILY_API":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search with Tavily API...")
|
||||
|
||||
# Search using Tavily API
|
||||
result_object, tavily_chunks = await connector_service.search_tavily(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant results from Tavily",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(tavily_chunks)
|
||||
|
||||
# Slack Connector
|
||||
if connector == "SLACK_CONNECTOR":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search for slack connector...")
|
||||
|
||||
# Search using Slack API
|
||||
result_object, slack_chunks = await connector_service.search_slack(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant results from Slack",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(slack_chunks)
|
||||
|
||||
|
||||
# Notion Connector
|
||||
if connector == "NOTION_CONNECTOR":
|
||||
# Send terminal message about starting search
|
||||
yield streaming_service.add_terminal_message("Starting to search for notion connector...")
|
||||
|
||||
# Search using Notion API
|
||||
result_object, notion_chunks = await connector_service.search_notion(
|
||||
user_query=user_query,
|
||||
user_id=user_id,
|
||||
search_space_id=search_space_id,
|
||||
top_k=TOP_K
|
||||
)
|
||||
|
||||
# Send terminal message about search results
|
||||
yield streaming_service.add_terminal_message(
|
||||
f"Found {len(result_object['sources'])} relevant results from Notion",
|
||||
"success"
|
||||
)
|
||||
|
||||
# Update sources
|
||||
all_sources.append(result_object)
|
||||
yield streaming_service.update_sources(all_sources)
|
||||
|
||||
# Add documents to collection
|
||||
all_raw_documents.extend(notion_chunks)
|
||||
|
||||
|
||||
|
||||
|
||||
# If we have documents to research
|
||||
if all_raw_documents:
|
||||
# Rerank all documents if reranker is available
|
||||
if reranker_service:
|
||||
yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")
|
||||
|
||||
# Convert documents to format expected by reranker
|
||||
reranker_input_docs = [
|
||||
{
|
||||
"chunk_id": doc.get("chunk_id", f"chunk_{i}"),
|
||||
"content": doc.get("content", ""),
|
||||
"score": doc.get("score", 0.0),
|
||||
"document": {
|
||||
"id": doc.get("document", {}).get("id", ""),
|
||||
"title": doc.get("document", {}).get("title", ""),
|
||||
"document_type": doc.get("document", {}).get("document_type", ""),
|
||||
"metadata": doc.get("document", {}).get("metadata", {})
|
||||
}
|
||||
} for i, doc in enumerate(all_raw_documents)
|
||||
]
|
||||
|
||||
# Rerank documents
|
||||
reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)
|
||||
|
||||
# Sort by score in descending order
|
||||
reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||
|
||||
|
||||
|
||||
# Convert back to langchain documents format
|
||||
from langchain.schema import Document as LangchainDocument
|
||||
all_langchain_documents_to_research = [
|
||||
LangchainDocument(
|
||||
page_content= f"""<document><metadata><source_id>{doc.get("document", {}).get("id", "")}</source_id></metadata><content>{doc.get("content", "")}</content></document>""",
|
||||
metadata={
|
||||
# **doc.get("document", {}).get("metadata", {}),
|
||||
# "score": doc.get("score", 0.0),
|
||||
# "rank": doc.get("rank", 0),
|
||||
# "document_id": doc.get("document", {}).get("id", ""),
|
||||
# "document_title": doc.get("document", {}).get("title", ""),
|
||||
# "document_type": doc.get("document", {}).get("document_type", ""),
|
||||
# # Explicitly set source_id for citation purposes
|
||||
"source_id": str(doc.get("document", {}).get("id", ""))
|
||||
}
|
||||
) for doc in reranked_docs
|
||||
]
|
||||
|
||||
yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
|
||||
else:
|
||||
# Use raw documents if no reranker is available
|
||||
all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)
|
||||
|
||||
# Send terminal message about starting research
|
||||
yield streaming_service.add_terminal_message("Starting to research...", "info")
|
||||
|
||||
# Create a buffer to collect report content
|
||||
report_buffer = []
|
||||
|
||||
|
||||
# Use the streaming research method
|
||||
yield streaming_service.add_terminal_message("Generating report...", "info")
|
||||
|
||||
# Create a wrapper to handle the streaming
|
||||
class StreamHandler:
|
||||
def __init__(self):
|
||||
self.queue = asyncio.Queue()
|
||||
|
||||
async def handle_progress(self, data):
|
||||
result = None
|
||||
if data.get("type") == "logs":
|
||||
# Handle log messages
|
||||
result = streaming_service.add_terminal_message(data.get("output", ""), "info")
|
||||
elif data.get("type") == "report":
|
||||
# Handle report content
|
||||
content = data.get("output", "")
|
||||
|
||||
# Fix incorrect citation formats using regex
|
||||
|
||||
# More specific pattern to match only numeric citations in markdown-style links
|
||||
# This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
|
||||
pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'
|
||||
|
||||
# Replace with just [X] where X is the number
|
||||
content = re.sub(pattern, r'[\1]', content)
|
||||
|
||||
# Also match other incorrect formats like ([1]) and convert to [1]
|
||||
# Only match if the content inside brackets is a number
|
||||
content = re.sub(r'\(\[(\d+)\]\)', r'[\1]', content)
|
||||
|
||||
report_buffer.append(content)
|
||||
# Update the answer with the accumulated content
|
||||
result = streaming_service.update_answer(report_buffer)
|
||||
|
||||
if result:
|
||||
await self.queue.put(result)
|
||||
return result
|
||||
|
||||
async def get_next(self):
|
||||
try:
|
||||
return await self.queue.get()
|
||||
except Exception as e:
|
||||
print(f"Error getting next item from queue: {e}")
|
||||
return None
|
||||
|
||||
def task_done(self):
|
||||
self.queue.task_done()
|
||||
|
||||
# Create the stream handler
|
||||
stream_handler = StreamHandler()
|
||||
|
||||
# Start the research process in a separate task
|
||||
research_task = asyncio.create_task(
|
||||
ResearchService.stream_research(
|
||||
user_query=user_query,
|
||||
documents=all_langchain_documents_to_research,
|
||||
on_progress=stream_handler.handle_progress,
|
||||
research_mode=research_mode
|
||||
)
|
||||
)
|
||||
|
||||
# Stream results as they become available
|
||||
while not research_task.done() or not stream_handler.queue.empty():
|
||||
try:
|
||||
# Get the next result with a timeout
|
||||
result = await asyncio.wait_for(stream_handler.get_next(), timeout=0.1)
|
||||
stream_handler.task_done()
|
||||
yield result
|
||||
except asyncio.TimeoutError:
|
||||
# No result available yet, check if the research task is done
|
||||
if research_task.done():
|
||||
# If the queue is empty and the task is done, we're finished
|
||||
if stream_handler.queue.empty():
|
||||
break
|
||||
|
||||
# Get the final report
|
||||
try:
|
||||
final_report = await research_task
|
||||
|
||||
# Send terminal message about research completion
|
||||
yield streaming_service.add_terminal_message("Research completed", "success")
|
||||
|
||||
# Update the answer with the final report
|
||||
final_report_lines = final_report.split('\n')
|
||||
yield streaming_service.update_answer(final_report_lines)
|
||||
except Exception as e:
|
||||
# Handle any exceptions
|
||||
yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")
|
||||
|
||||
# Send completion message
|
||||
yield streaming_service.format_completion()
|
||||
Loading…
Add table
Add a link
Reference in a new issue