feat: SurfSense v0.0.6 init

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-03-14 18:53:14 -07:00
parent 18fc19e8d9
commit da23012970
58 changed files with 8284 additions and 2076 deletions

View file

View file

@ -0,0 +1,246 @@
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from app.db import Document, DocumentType, Chunk
from app.schemas import ExtensionDocumentContent
from app.config import config
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from datetime import datetime
from app.utils.document_converters import convert_document_to_markdown
from langchain_core.documents import Document as LangChainDocument
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
from langchain_community.document_transformers import MarkdownifyTransformer
import validators
md = MarkdownifyTransformer()
async def add_crawled_url_document(
    session: AsyncSession,
    url: str,
    search_space_id: int
) -> Optional[Document]:
    """
    Crawl a web page, summarise it and store it as a CRAWLED_URL document.

    The page is loaded with FireCrawl when ``config.FIRECRAWL_API_KEY`` is set,
    otherwise with a headless Chromium loader whose output is passed through
    the module-level Markdownify transformer.

    Args:
        session: Database session
        url: Address of the page to crawl
        search_space_id: ID of the search space to store the document in

    Returns:
        The stored Document, refreshed after commit

    Raises:
        SQLAlchemyError: If a database operation fails (session is rolled back)
        RuntimeError: For any other failure, including an invalid URL or an
            empty crawl result (session is rolled back); the original
            exception is attached as ``__cause__``
    """
    try:
        if not validators.url(url):
            raise ValueError(f"Url {url} is not a valid URL address")

        # Decide once which loader is used; the same flag later selects how the
        # content and the title are extracted.
        use_firecrawl = bool(config.FIRECRAWL_API_KEY)
        if use_firecrawl:
            crawl_loader = FireCrawlLoader(
                url=url,
                api_key=config.FIRECRAWL_API_KEY,
                mode="scrape",
                params={
                    "formats": ["markdown"],
                    "excludeTags": ["a"],
                }
            )
        else:
            crawl_loader = AsyncChromiumLoader(urls=[url], headless=True)

        url_crawled = await crawl_loader.aload()
        if not url_crawled:
            # Fail with a readable message instead of "list index out of range".
            raise ValueError(f"No content could be crawled from {url}")

        if use_firecrawl:
            # FireCrawl was asked for markdown, so the page content is used as-is.
            content_in_markdown = url_crawled[0].page_content
        else:
            content_in_markdown = md.transform_documents(url_crawled)[0].page_content

        page_metadata = url_crawled[0].metadata

        # Assemble the pseudo-XML document handed to the summary prompt.
        metadata_lines = [
            f"{key.upper()}: {value}" for key, value in page_metadata.items()
        ]
        combined_document_string = '\n'.join([
            "<DOCUMENT>",
            "<METADATA>",
            *metadata_lines,
            "</METADATA>",
            "<CONTENT>",
            "FORMAT: markdown",
            "TEXT_START",
            content_in_markdown,
            "TEXT_END",
            "</CONTENT>",
            "</DOCUMENT>",
        ])

        # Generate summary
        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
        summary_result = await summary_chain.ainvoke({"document": combined_document_string})
        summary_content = summary_result.content
        summary_embedding = config.embedding_model_instance.embed(summary_content)

        # Process chunks
        chunks = [
            Chunk(content=chunk.text, embedding=chunk.embedding)
            for chunk in config.chunker_instance.chunk(content_in_markdown)
        ]

        # FireCrawl results are titled by page title (falling back to the URL
        # when the page has none); Chromium results by their 'source' entry.
        if use_firecrawl:
            title = page_metadata.get("title") or url
        else:
            title = page_metadata['source']

        # Create and store document
        document = Document(
            search_space_id=search_space_id,
            title=title,
            document_type=DocumentType.CRAWLED_URL,
            document_metadata=page_metadata,
            content=summary_content,
            embedding=summary_embedding,
            chunks=chunks
        )
        session.add(document)
        await session.commit()
        await session.refresh(document)
        return document
    except SQLAlchemyError:
        await session.rollback()
        raise
    except Exception as e:
        await session.rollback()
        raise RuntimeError(f"Failed to crawl URL: {str(e)}") from e
async def add_extension_received_document(
    session: AsyncSession,
    content: ExtensionDocumentContent,
    search_space_id: int
) -> Optional[Document]:
    """
    Persist a page captured by the SurfSense Extension as an EXTENSION document.

    The visit metadata and page content are wrapped into a single document
    string, summarised, embedded and chunked before being committed.

    Args:
        session: Database session
        content: Document content from extension
        search_space_id: ID of the search space

    Returns:
        The stored Document, refreshed after commit

    Raises:
        SQLAlchemyError: If a database operation fails (session is rolled back)
        RuntimeError: For any other failure (session is rolled back)
    """
    try:
        visit = content.metadata

        # Header lines describing the page visit
        visit_lines = [
            f"SESSION_ID: {visit.BrowsingSessionId}",
            f"URL: {visit.VisitedWebPageURL}",
            f"TITLE: {visit.VisitedWebPageTitle}",
            f"REFERRER: {visit.VisitedWebPageReffererURL}",
            f"TIMESTAMP: {visit.VisitedWebPageDateWithTimeInISOString}",
            f"DURATION_MS: {visit.VisitedWebPageVisitDurationInMilliseconds}",
        ]
        page_lines = [
            "FORMAT: markdown",
            "TEXT_START",
            content.pageContent,
            "TEXT_END",
        ]

        # Wrap both sections in their tags and join everything line by line
        combined_document_string = '\n'.join([
            "<DOCUMENT>",
            "<METADATA>",
            *visit_lines,
            "</METADATA>",
            "<CONTENT>",
            *page_lines,
            "</CONTENT>",
            "</DOCUMENT>",
        ])

        # Summarise the document and embed the summary
        summary_result = await (
            SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
        ).ainvoke({"document": combined_document_string})
        summary_content = summary_result.content
        summary_embedding = config.embedding_model_instance.embed(summary_content)

        # Chunk the raw page content (not the wrapped document string)
        chunks = []
        for piece in config.chunker_instance.chunk(content.pageContent):
            chunks.append(Chunk(content=piece.text, embedding=piece.embedding))

        document = Document(
            search_space_id=search_space_id,
            title=visit.VisitedWebPageTitle,
            document_type=DocumentType.EXTENSION,
            document_metadata=visit.model_dump(),
            content=summary_content,
            embedding=summary_embedding,
            chunks=chunks
        )
        session.add(document)
        await session.commit()
        await session.refresh(document)
        return document
    except SQLAlchemyError as db_error:
        await session.rollback()
        raise db_error
    except Exception as e:
        await session.rollback()
        raise RuntimeError(f"Failed to process extension document: {str(e)}")
async def add_received_file_document(
    session: AsyncSession,
    file_name: str,
    unstructured_processed_elements: List[LangChainDocument],
    search_space_id: int
) -> Optional[Document]:
    """
    Store an uploaded file as a FILE document.

    The processed elements are converted to markdown, which is then
    summarised, embedded and chunked before being committed.

    Args:
        session: Database session
        file_name: Name of the file; used as title and recorded in the metadata
        unstructured_processed_elements: Elements passed to
            convert_document_to_markdown
        search_space_id: ID of the search space

    Returns:
        The stored Document, refreshed after commit

    Raises:
        SQLAlchemyError: If a database operation fails (session is rolled back)
        RuntimeError: For any other failure (session is rolled back)
    """
    try:
        file_in_markdown = await convert_document_to_markdown(unstructured_processed_elements)
        # TODO: Check if file_markdown exceeds token limit of embedding model

        # Summarise the markdown and embed the summary
        summary_result = await (
            SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
        ).ainvoke({"document": file_in_markdown})
        summary_content = summary_result.content
        summary_embedding = config.embedding_model_instance.embed(summary_content)

        # One Chunk row per chunk produced by the configured chunker
        chunks = []
        for piece in config.chunker_instance.chunk(file_in_markdown):
            chunks.append(Chunk(content=piece.text, embedding=piece.embedding))

        file_metadata = {
            "FILE_NAME": file_name,
            "SAVED_AT": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
        document = Document(
            search_space_id=search_space_id,
            title=file_name,
            document_type=DocumentType.FILE,
            document_metadata=file_metadata,
            content=summary_content,
            embedding=summary_embedding,
            chunks=chunks
        )
        session.add(document)
        await session.commit()
        await session.refresh(document)
        return document
    except SQLAlchemyError as db_error:
        await session.rollback()
        raise db_error
    except Exception as e:
        await session.rollback()
        raise RuntimeError(f"Failed to process file document: {str(e)}")

View file

@ -0,0 +1,486 @@
from typing import Optional, List, Dict, Any, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.future import select
from datetime import datetime, timedelta
from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType
from app.config import config
from app.prompts import SUMMARY_PROMPT_TEMPLATE
from app.connectors.slack_history import SlackHistory
from app.connectors.notion_history import NotionHistoryConnector
from slack_sdk.errors import SlackApiError
import logging
# Set up logging
logger = logging.getLogger(__name__)
async def index_slack_messages(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    update_last_indexed: bool = True
) -> Tuple[int, Optional[str]]:
    """
    Index Slack messages from all accessible channels.

    One Document is created per channel that has at least one non-bot,
    non-system message in the date range. Channels that fail are skipped and
    reported in the returned message; they do not abort the run.

    Args:
        session: Database session
        connector_id: ID of the Slack connector
        search_space_id: ID of the search space to store documents in
        update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)

    Returns:
        Tuple containing (number of documents indexed, message or None).
        The message is an error description when nothing could be indexed, or
        a summary of skipped channels when some channels were skipped.
    """
    try:
        # Get the connector
        result = await session.execute(
            select(SearchSourceConnector)
            .filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR
            )
        )
        connector = result.scalars().first()
        if not connector:
            return 0, f"Connector with ID {connector_id} not found or is not a Slack connector"

        # Get the Slack token from the connector config
        slack_token = connector.config.get("SLACK_BOT_TOKEN")
        if not slack_token:
            return 0, "Slack token not found in connector config"

        # Initialize Slack client
        slack_client = SlackHistory(token=slack_token)

        # Calculate date range.
        # Use last_indexed_at as start date if available, otherwise 365 days ago.
        end_date = datetime.now()
        if connector.last_indexed_at:
            if connector.last_indexed_at.date() == end_date.date():
                # If last indexed today, go back 1 day to ensure we don't miss
                # anything (same window as the Notion indexer).
                start_date = end_date - timedelta(days=1)
            else:
                start_date = connector.last_indexed_at
        else:
            start_date = end_date - timedelta(days=365)

        # Format dates for Slack API
        start_date_str = start_date.strftime("%Y-%m-%d")
        end_date_str = end_date.strftime("%Y-%m-%d")

        # Get all channels
        try:
            channels = slack_client.get_all_channels()
        except Exception as e:
            return 0, f"Failed to get Slack channels: {str(e)}"
        if not channels:
            return 0, "No Slack channels found"

        # Track the number of documents indexed
        documents_indexed = 0
        skipped_channels = []

        # Process each channel
        for channel_name, channel_id in channels.items():
            try:
                # Private channels can only be read when the bot is a member.
                try:
                    channel_info = slack_client.client.conversations_info(channel=channel_id)
                    channel_details = channel_info.get("channel", {})
                    if channel_details.get("is_private", False):
                        is_member = channel_info.get("channel", {}).get("is_member", False)
                        if not is_member:
                            logger.warning(f"Bot is not a member of private channel {channel_name} ({channel_id}). Skipping.")
                            skipped_channels.append(f"{channel_name} (private, bot not a member)")
                            continue
                except SlackApiError as e:
                    if "not_in_channel" in str(e) or "channel_not_found" in str(e):
                        logger.warning(f"Bot cannot access channel {channel_name} ({channel_id}). Skipping.")
                        skipped_channels.append(f"{channel_name} (access error)")
                        continue
                    # Any other Slack error is handled by the outer handler
                    raise

                # Get messages for this channel
                messages, error = slack_client.get_history_by_date_range(
                    channel_id=channel_id,
                    start_date=start_date_str,
                    end_date=end_date_str,
                    limit=1000  # Limit to 1000 messages per channel
                )
                if error:
                    logger.warning(f"Error getting messages from channel {channel_name}: {error}")
                    skipped_channels.append(f"{channel_name} (error: {error})")
                    continue  # Skip this channel if there's an error
                if not messages:
                    logger.info(f"No messages found in channel {channel_name} for the specified date range.")
                    continue  # Skip if no messages

                # Format messages with user info, dropping bot and system messages
                formatted_messages = [
                    slack_client.format_message(msg, include_user_info=True)
                    for msg in messages
                    if msg.get("subtype") not in ["bot_message", "channel_join", "channel_leave"]
                ]
                if not formatted_messages:
                    logger.info(f"No valid messages found in channel {channel_name} after filtering.")
                    continue  # Skip if no valid messages after filtering

                # Convert messages to markdown format
                channel_content = f"# Slack Channel: {channel_name}\n\n" + "".join(
                    f"## {msg.get('user_name', 'Unknown User')} "
                    f"({msg.get('datetime', 'Unknown Time')})\n\n"
                    f"{msg.get('text', '')}\n\n---\n\n"
                    for msg in formatted_messages
                )

                # Build the document string handed to the summary prompt
                combined_document_string = '\n'.join([
                    "<DOCUMENT>",
                    "<METADATA>",
                    f"CHANNEL_NAME: {channel_name}",
                    f"CHANNEL_ID: {channel_id}",
                    f"START_DATE: {start_date_str}",
                    f"END_DATE: {end_date_str}",
                    f"MESSAGE_COUNT: {len(formatted_messages)}",
                    f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
                    "</METADATA>",
                    "<CONTENT>",
                    "FORMAT: markdown",
                    "TEXT_START",
                    channel_content,
                    "TEXT_END",
                    "</CONTENT>",
                    "</DOCUMENT>",
                ])

                # Generate summary
                summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
                summary_result = await summary_chain.ainvoke({"document": combined_document_string})
                summary_content = summary_result.content
                summary_embedding = config.embedding_model_instance.embed(summary_content)

                # Process chunks
                chunks = [
                    Chunk(content=chunk.text, embedding=chunk.embedding)
                    for chunk in config.chunker_instance.chunk(channel_content)
                ]

                # Create and store document
                document = Document(
                    search_space_id=search_space_id,
                    title=f"Slack - {channel_name}",
                    document_type=DocumentType.SLACK_CONNECTOR,
                    document_metadata={
                        "channel_name": channel_name,
                        "channel_id": channel_id,
                        "start_date": start_date_str,
                        "end_date": end_date_str,
                        "message_count": len(formatted_messages),
                        "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    },
                    content=summary_content,
                    embedding=summary_embedding,
                    chunks=chunks
                )
                session.add(document)
                documents_indexed += 1
                logger.info(f"Successfully indexed channel {channel_name} with {len(formatted_messages)} messages")
            except SlackApiError as slack_error:
                logger.error(f"Slack API error for channel {channel_name}: {str(slack_error)}")
                skipped_channels.append(f"{channel_name} (Slack API error)")
                continue  # Skip this channel and continue with others
            except Exception as e:
                logger.error(f"Error processing channel {channel_name}: {str(e)}")
                skipped_channels.append(f"{channel_name} (processing error)")
                continue  # Skip this channel and continue with others

        # Update the last_indexed_at timestamp for the connector only if requested
        # and if we successfully indexed at least one channel
        if update_last_indexed and documents_indexed > 0:
            connector.last_indexed_at = datetime.now()

        # Commit all changes
        await session.commit()

        # Prepare result message
        result_message = None
        if skipped_channels:
            result_message = f"Indexed {documents_indexed} channels. Skipped {len(skipped_channels)} channels: {', '.join(skipped_channels)}"
        return documents_indexed, result_message
    except SQLAlchemyError as db_error:
        await session.rollback()
        logger.error(f"Database error: {str(db_error)}")
        return 0, f"Database error: {str(db_error)}"
    except Exception as e:
        await session.rollback()
        logger.error(f"Failed to index Slack messages: {str(e)}")
        return 0, f"Failed to index Slack messages: {str(e)}"
async def index_notion_pages(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    update_last_indexed: bool = True
) -> Tuple[int, Optional[str]]:
    """
    Index the Notion pages returned by a Notion connector.

    One Document is created per page that has content. Pages without content
    or that fail during processing are skipped and reported in the returned
    message; they do not abort the run.

    Args:
        session: Database session
        connector_id: ID of the Notion connector
        search_space_id: ID of the search space to store documents in
        update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)

    Returns:
        Tuple containing (number of documents indexed, message or None).
        The message is an error description when nothing could be indexed, or
        a summary of skipped pages when some pages were skipped.
    """
    try:
        # Look up the connector, restricted to Notion connectors
        connector_query = select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR
        )
        connector = (await session.execute(connector_query)).scalars().first()
        if not connector:
            return 0, f"Connector with ID {connector_id} not found or is not a Notion connector"

        # The integration token lives in the connector's config
        notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN")
        if not notion_token:
            return 0, "Notion integration token not found in connector config"

        logger.info(f"Initializing Notion client for connector {connector_id}")
        notion_client = NotionHistoryConnector(token=notion_token)

        # Date window: from the last indexing run (or one day back if that run
        # was today, or 365 days back if there was none) up to now.
        end_date = datetime.now()
        last_indexed = connector.last_indexed_at
        if not last_indexed:
            start_date = end_date - timedelta(days=365)
        elif last_indexed.date() == datetime.now().date():
            start_date = end_date - timedelta(days=1)
        else:
            start_date = last_indexed

        # NOTE(review): naive local times are formatted with a trailing "Z";
        # confirm whether the Notion connector expects UTC here.
        iso_format = "%Y-%m-%dT%H:%M:%SZ"
        start_date_str = start_date.strftime(iso_format)
        end_date_str = end_date.strftime(iso_format)
        logger.info(f"Fetching Notion pages from {start_date_str} to {end_date_str}")

        # Fetch the pages for the window
        try:
            pages = notion_client.get_all_pages(start_date=start_date_str, end_date=end_date_str)
            logger.info(f"Found {len(pages)} Notion pages")
        except Exception as e:
            logger.error(f"Error fetching Notion pages: {str(e)}", exc_info=True)
            return 0, f"Failed to get Notion pages: {str(e)}"
        if not pages:
            logger.info("No Notion pages found to index")
            return 0, "No Notion pages found"

        documents_indexed = 0
        skipped_pages = []

        for page in pages:
            try:
                page_id = page.get("page_id")
                page_title = page.get("title", f"Untitled page ({page_id})")
                page_content = page.get("content", [])
                logger.info(f"Processing Notion page: {page_title} ({page_id})")
                if not page_content:
                    logger.info(f"No content found in page {page_title}. Skipping.")
                    skipped_pages.append(f"{page_title} (no content)")
                    continue

                # Render the block tree as markdown under a page heading
                logger.debug(f"Converting {len(page_content)} blocks to markdown for page {page_title}")
                markdown_content = (
                    f"# Notion Page: {page_title}\n\n"
                    + _notion_blocks_to_markdown(page_content)
                )

                # Document string handed to the summary prompt
                combined_document_string = '\n'.join([
                    "<DOCUMENT>",
                    "<METADATA>",
                    f"PAGE_TITLE: {page_title}",
                    f"PAGE_ID: {page_id}",
                    f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
                    "</METADATA>",
                    "<CONTENT>",
                    "FORMAT: markdown",
                    "TEXT_START",
                    markdown_content,
                    "TEXT_END",
                    "</CONTENT>",
                    "</DOCUMENT>",
                ])

                logger.debug(f"Generating summary for page {page_title}")
                summary_result = await (
                    SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
                ).ainvoke({"document": combined_document_string})
                summary_content = summary_result.content
                summary_embedding = config.embedding_model_instance.embed(summary_content)

                logger.debug(f"Chunking content for page {page_title}")
                chunks = []
                for piece in config.chunker_instance.chunk(markdown_content):
                    chunks.append(Chunk(content=piece.text, embedding=piece.embedding))

                page_metadata = {
                    "page_title": page_title,
                    "page_id": page_id,
                    "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                }
                session.add(Document(
                    search_space_id=search_space_id,
                    title=f"Notion - {page_title}",
                    document_type=DocumentType.NOTION_CONNECTOR,
                    document_metadata=page_metadata,
                    content=summary_content,
                    embedding=summary_embedding,
                    chunks=chunks
                ))
                documents_indexed += 1
                logger.info(f"Successfully indexed Notion page: {page_title}")
            except Exception as e:
                # A failing page is recorded and skipped; the run continues
                logger.error(f"Error processing Notion page {page.get('title', 'Unknown')}: {str(e)}", exc_info=True)
                skipped_pages.append(f"{page.get('title', 'Unknown')} (processing error)")
                continue

        # Only move last_indexed_at forward when asked to and when at least
        # one page was actually indexed
        if update_last_indexed and documents_indexed > 0:
            connector.last_indexed_at = datetime.now()
            logger.info(f"Updated last_indexed_at for connector {connector_id}")

        await session.commit()

        result_message = None
        if skipped_pages:
            result_message = f"Indexed {documents_indexed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"
        logger.info(f"Notion indexing completed: {documents_indexed} pages indexed, {len(skipped_pages)} pages skipped")
        return documents_indexed, result_message
    except SQLAlchemyError as db_error:
        await session.rollback()
        logger.error(f"Database error during Notion indexing: {str(db_error)}", exc_info=True)
        return 0, f"Database error: {str(db_error)}"
    except Exception as e:
        await session.rollback()
        logger.error(f"Failed to index Notion pages: {str(e)}", exc_info=True)
        return 0, f"Failed to index Notion pages: {str(e)}"


# (prefix, suffix) placed around a block's content for each known block type
_NOTION_BLOCK_MARKUP = {
    "paragraph": ("", "\n\n"),
    "text": ("", "\n\n"),
    "heading_1": ("# ", "\n\n"),
    "header": ("# ", "\n\n"),
    "heading_2": ("## ", "\n\n"),
    "heading_3": ("### ", "\n\n"),
    "bulleted_list_item": ("* ", "\n"),
    "numbered_list_item": ("1. ", "\n"),
    "to_do": ("- [ ] ", "\n"),
    "toggle": ("> ", "\n"),
    "code": ("```\n", "\n```\n\n"),
    "quote": ("> ", "\n\n"),
    "callout": ("> **Note:** ", "\n\n"),
    "image": ("![Image](", ")\n\n"),
}


def _notion_blocks_to_markdown(blocks, level=0):
    """Render a list of Notion block dicts (and their children) as markdown.

    Each nesting level is indented by one more space. Block types without an
    entry in the markup table are emitted as plain paragraphs, and only when
    they have content.
    """
    indent = " " * level
    pieces = []
    for block in blocks:
        block_type = block.get("type")
        block_content = block.get("content", "")
        children = block.get("children", [])
        markup = _NOTION_BLOCK_MARKUP.get(block_type)
        if markup is not None:
            prefix, suffix = markup
            pieces.append(f"{indent}{prefix}{block_content}{suffix}")
        elif block_content:
            pieces.append(f"{indent}{block_content}\n\n")
        # Children follow their parent, one level deeper
        if children:
            pieces.append(_notion_blocks_to_markdown(children, level + 1))
    return "".join(pieces)

View file

@ -0,0 +1,340 @@
import json
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, AsyncGenerator, Dict, Any
import asyncio
import re
from app.utils.connector_service import ConnectorService
from app.utils.research_service import ResearchService
from app.utils.streaming_service import StreamingService
from app.utils.reranker_service import RerankerService
from app.config import config
from app.utils.document_converters import convert_chunks_to_langchain_documents
async def stream_connector_search_results(
    user_query: str,
    user_id: int,
    search_space_id: int,
    session: AsyncSession,
    research_mode: str,
    selected_connectors: List[str]
) -> AsyncGenerator[str, None]:
    """
    Stream connector search results to the client

    Searches every selected connector, optionally reranks the combined
    chunks, then runs the research service in a background task and forwards
    its progress as it arrives. The background task is cancelled if the
    consumer stops iterating before the research has finished.

    Args:
        user_query: The user's query
        user_id: The user's ID
        search_space_id: The search space ID
        session: The database session
        research_mode: The research mode ("GENERAL", "DEEP" or "DEEPER";
            anything else uses the GENERAL result count)
        selected_connectors: List of selected connectors; unknown names are ignored

    Yields:
        str: Formatted response strings
    """
    # Initialize services
    connector_service = ConnectorService(session)
    streaming_service = StreamingService()
    reranker_service = RerankerService.get_reranker_instance(config)

    all_raw_documents = []  # Store all raw documents before reranking
    all_sources = []

    # Number of chunks requested from each connector
    top_k = {"GENERAL": 20, "DEEP": 40, "DEEPER": 60}.get(research_mode, 20)

    # connector name -> (start message, ConnectorService method name,
    #                    whether the search takes search_space_id,
    #                    wording used in the "Found N relevant ..." message)
    connector_searches = {
        "CRAWLED_URL": ("Starting to search for crawled URLs...", "search_crawled_urls", True, "crawled URLs"),
        "FILE": ("Starting to search for files...", "search_files", True, "files"),
        "TAVILY_API": ("Starting to search with Tavily API...", "search_tavily", False, "results from Tavily"),
        "SLACK_CONNECTOR": ("Starting to search for slack connector...", "search_slack", True, "results from Slack"),
        "NOTION_CONNECTOR": ("Starting to search for notion connector...", "search_notion", True, "results from Notion"),
    }

    # Process each selected connector
    for connector in selected_connectors:
        search_spec = connector_searches.get(connector)
        if search_spec is None:
            continue
        start_message, method_name, uses_search_space, found_label = search_spec

        # Send terminal message about starting search
        yield streaming_service.add_terminal_message(start_message)

        search_kwargs = {"user_query": user_query, "user_id": user_id, "top_k": top_k}
        if uses_search_space:
            search_kwargs["search_space_id"] = search_space_id
        result_object, found_chunks = await getattr(connector_service, method_name)(**search_kwargs)

        # Send terminal message about search results
        yield streaming_service.add_terminal_message(
            f"Found {len(result_object['sources'])} relevant {found_label}",
            "success"
        )

        # Update sources
        all_sources.append(result_object)
        yield streaming_service.update_sources(all_sources)

        # Add documents to collection
        all_raw_documents.extend(found_chunks)

    # If we have documents to research
    if all_raw_documents:
        # Rerank all documents if reranker is available
        if reranker_service:
            yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")

            # Convert documents to format expected by reranker
            reranker_input_docs = [
                {
                    "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
                    "content": doc.get("content", ""),
                    "score": doc.get("score", 0.0),
                    "document": {
                        "id": doc.get("document", {}).get("id", ""),
                        "title": doc.get("document", {}).get("title", ""),
                        "document_type": doc.get("document", {}).get("document_type", ""),
                        "metadata": doc.get("document", {}).get("metadata", {})
                    }
                } for i, doc in enumerate(all_raw_documents)
            ]

            # Rerank documents and sort by score in descending order
            reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)
            reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)

            # Convert back to langchain documents format
            from langchain.schema import Document as LangchainDocument
            all_langchain_documents_to_research = []
            for doc in reranked_docs:
                source_id = doc.get("document", {}).get("id", "")
                all_langchain_documents_to_research.append(
                    LangchainDocument(
                        page_content=f"<document><metadata><source_id>{source_id}</source_id></metadata><content>{doc.get('content', '')}</content></document>",
                        # Explicitly set source_id for citation purposes
                        metadata={"source_id": str(source_id)}
                    )
                )
            yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
        else:
            # Use raw documents if no reranker is available
            all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)

        # Send terminal message about starting research
        yield streaming_service.add_terminal_message("Starting to research...", "info")

        # Buffer collecting report content as it streams in
        report_buffer = []
        yield streaming_service.add_terminal_message("Generating report...", "info")

        # Matches numeric citations wrapped in markdown links, e.g.
        # ([1](https://github.com/...)), but not ([Click here](https://...))
        linked_citation = re.compile(r'\(\[(\d+)\]\((https?://[^\)]+)\)\)')
        # Matches numeric citations wrapped in parentheses, e.g. ([1])
        wrapped_citation = re.compile(r'\(\[(\d+)\]\)')

        # Progress produced by the research task is handed to this generator
        # through a queue.
        progress_queue = asyncio.Queue()

        async def handle_progress(data):
            """Format one progress event and queue it for streaming."""
            result = None
            if data.get("type") == "logs":
                # Handle log messages
                result = streaming_service.add_terminal_message(data.get("output", ""), "info")
            elif data.get("type") == "report":
                # Normalise both incorrect citation formats to plain [X]
                content = linked_citation.sub(r'[\1]', data.get("output", ""))
                content = wrapped_citation.sub(r'[\1]', content)
                report_buffer.append(content)
                # Update the answer with the accumulated content
                result = streaming_service.update_answer(report_buffer)
            if result:
                await progress_queue.put(result)
            return result

        # Start the research process in a separate task
        research_task = asyncio.create_task(
            ResearchService.stream_research(
                user_query=user_query,
                documents=all_langchain_documents_to_research,
                on_progress=handle_progress,
                research_mode=research_mode
            )
        )
        try:
            # Stream results as they become available, until the task has
            # finished and the queue has been drained
            while not research_task.done() or not progress_queue.empty():
                try:
                    # Short timeout so task completion is noticed promptly
                    result = await asyncio.wait_for(progress_queue.get(), timeout=0.1)
                except asyncio.TimeoutError:
                    # Nothing queued yet; the loop condition decides whether to stop
                    continue
                progress_queue.task_done()
                yield result

            # Get the final report
            try:
                final_report = await research_task
                # Send terminal message about research completion
                yield streaming_service.add_terminal_message("Research completed", "success")
                # Update the answer with the final report
                final_report_lines = final_report.split('\n')
                yield streaming_service.update_answer(final_report_lines)
            except Exception as e:
                # Handle any exceptions
                yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")
        finally:
            # If the consumer stopped iterating (or an error escaped) while the
            # research is still running, do not leave the task running unattended.
            if not research_task.done():
                research_task.cancel()

    # Send completion message
    yield streaming_service.format_completion()