mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
feat: Document Selector in Chat.
- Still need improvements but lets use it first.
This commit is contained in:
parent
e8a19c496b
commit
d7bb31f894
12 changed files with 599 additions and 67 deletions
|
|
@ -21,6 +21,237 @@ from app.utils.query_service import QueryService
|
|||
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
# Additional imports for document fetching
|
||||
from sqlalchemy.future import select
|
||||
from app.db import Document, SearchSpace
|
||||
|
||||
|
||||
async def fetch_documents_by_ids(
|
||||
document_ids: List[int],
|
||||
user_id: str,
|
||||
db_session: AsyncSession
|
||||
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Fetch documents by their IDs with ownership check using DOCUMENTS mode approach.
|
||||
|
||||
This function ensures that only documents belonging to the user are fetched,
|
||||
providing security by checking ownership through SearchSpace association.
|
||||
Similar to SearchMode.DOCUMENTS, it fetches full documents and concatenates their chunks.
|
||||
Also creates source objects for UI display, grouped by document type.
|
||||
|
||||
Args:
|
||||
document_ids: List of document IDs to fetch
|
||||
user_id: The user ID to check ownership
|
||||
db_session: The database session
|
||||
|
||||
Returns:
|
||||
Tuple of (source_objects, document_chunks) - similar to ConnectorService pattern
|
||||
"""
|
||||
if not document_ids:
|
||||
return [], []
|
||||
|
||||
try:
|
||||
# Query documents with ownership check
|
||||
result = await db_session.execute(
|
||||
select(Document)
|
||||
.join(SearchSpace)
|
||||
.filter(
|
||||
Document.id.in_(document_ids),
|
||||
SearchSpace.user_id == user_id
|
||||
)
|
||||
)
|
||||
documents = result.scalars().all()
|
||||
|
||||
# Group documents by type for source object creation
|
||||
documents_by_type = {}
|
||||
formatted_documents = []
|
||||
|
||||
for doc in documents:
|
||||
# Fetch associated chunks for this document (similar to DocumentHybridSearchRetriever)
|
||||
from app.db import Chunk
|
||||
chunks_query = select(Chunk).where(Chunk.document_id == doc.id).order_by(Chunk.id)
|
||||
chunks_result = await db_session.execute(chunks_query)
|
||||
chunks = chunks_result.scalars().all()
|
||||
|
||||
# Concatenate chunks content (similar to SearchMode.DOCUMENTS approach)
|
||||
concatenated_chunks_content = " ".join([chunk.content for chunk in chunks]) if chunks else doc.content
|
||||
|
||||
# Format to match connector service return format
|
||||
formatted_doc = {
|
||||
"chunk_id": f"user_doc_{doc.id}",
|
||||
"content": concatenated_chunks_content, # Use concatenated content like DOCUMENTS mode
|
||||
"score": 0.5, # High score since user explicitly selected these
|
||||
"document": {
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"document_type": doc.document_type.value if doc.document_type else "UNKNOWN",
|
||||
"metadata": doc.document_metadata or {},
|
||||
},
|
||||
"source": doc.document_type.value if doc.document_type else "UNKNOWN"
|
||||
}
|
||||
formatted_documents.append(formatted_doc)
|
||||
|
||||
# Group by document type for source objects
|
||||
doc_type = doc.document_type.value if doc.document_type else "UNKNOWN"
|
||||
if doc_type not in documents_by_type:
|
||||
documents_by_type[doc_type] = []
|
||||
documents_by_type[doc_type].append(doc)
|
||||
|
||||
# Create source objects for each document type (similar to ConnectorService)
|
||||
source_objects = []
|
||||
connector_id_counter = 100 # Start from 100 to avoid conflicts with regular connectors
|
||||
|
||||
for doc_type, docs in documents_by_type.items():
|
||||
sources_list = []
|
||||
|
||||
for doc in docs:
|
||||
metadata = doc.document_metadata or {}
|
||||
|
||||
# Create type-specific source formatting (similar to ConnectorService)
|
||||
if doc_type == "LINEAR_CONNECTOR":
|
||||
# Extract Linear-specific metadata
|
||||
issue_identifier = metadata.get('issue_identifier', '')
|
||||
issue_title = metadata.get('issue_title', doc.title)
|
||||
issue_state = metadata.get('state', '')
|
||||
comment_count = metadata.get('comment_count', 0)
|
||||
|
||||
# Create a more descriptive title for Linear issues
|
||||
title = f"Linear: {issue_identifier} - {issue_title}" if issue_identifier else f"Linear: {issue_title}"
|
||||
if issue_state:
|
||||
title += f" ({issue_state})"
|
||||
|
||||
# Create description
|
||||
description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content
|
||||
if comment_count:
|
||||
description += f" | Comments: {comment_count}"
|
||||
|
||||
# Create URL
|
||||
url = f"https://linear.app/issue/{issue_identifier}" if issue_identifier else ""
|
||||
|
||||
elif doc_type == "SLACK_CONNECTOR":
|
||||
# Extract Slack-specific metadata
|
||||
channel_name = metadata.get('channel_name', 'Unknown Channel')
|
||||
channel_id = metadata.get('channel_id', '')
|
||||
message_date = metadata.get('start_date', '')
|
||||
|
||||
title = f"Slack: {channel_name}"
|
||||
if message_date:
|
||||
title += f" ({message_date})"
|
||||
|
||||
description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content
|
||||
url = f"https://slack.com/app_redirect?channel={channel_id}" if channel_id else ""
|
||||
|
||||
elif doc_type == "NOTION_CONNECTOR":
|
||||
# Extract Notion-specific metadata
|
||||
page_title = metadata.get('page_title', doc.title)
|
||||
page_id = metadata.get('page_id', '')
|
||||
|
||||
title = f"Notion: {page_title}"
|
||||
description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content
|
||||
url = f"https://notion.so/{page_id.replace('-', '')}" if page_id else ""
|
||||
|
||||
elif doc_type == "GITHUB_CONNECTOR":
|
||||
title = f"GitHub: {doc.title}"
|
||||
description = metadata.get('description', doc.content[:100] + "..." if len(doc.content) > 100 else doc.content)
|
||||
url = metadata.get('url', '')
|
||||
|
||||
elif doc_type == "YOUTUBE_VIDEO":
|
||||
# Extract YouTube-specific metadata
|
||||
video_title = metadata.get('video_title', doc.title)
|
||||
video_id = metadata.get('video_id', '')
|
||||
channel_name = metadata.get('channel_name', '')
|
||||
|
||||
title = video_title
|
||||
if channel_name:
|
||||
title += f" - {channel_name}"
|
||||
|
||||
description = metadata.get('description', doc.content[:100] + "..." if len(doc.content) > 100 else doc.content)
|
||||
url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""
|
||||
|
||||
elif doc_type == "DISCORD_CONNECTOR":
|
||||
# Extract Discord-specific metadata
|
||||
channel_name = metadata.get('channel_name', 'Unknown Channel')
|
||||
channel_id = metadata.get('channel_id', '')
|
||||
guild_id = metadata.get('guild_id', '')
|
||||
message_date = metadata.get('start_date', '')
|
||||
|
||||
title = f"Discord: {channel_name}"
|
||||
if message_date:
|
||||
title += f" ({message_date})"
|
||||
|
||||
description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content
|
||||
|
||||
if guild_id and channel_id:
|
||||
url = f"https://discord.com/channels/{guild_id}/{channel_id}"
|
||||
elif channel_id:
|
||||
url = f"https://discord.com/channels/@me/{channel_id}"
|
||||
else:
|
||||
url = ""
|
||||
|
||||
elif doc_type == "EXTENSION":
|
||||
# Extract Extension-specific metadata
|
||||
webpage_title = metadata.get('VisitedWebPageTitle', doc.title)
|
||||
webpage_url = metadata.get('VisitedWebPageURL', '')
|
||||
visit_date = metadata.get('VisitedWebPageDateWithTimeInISOString', '')
|
||||
|
||||
title = webpage_title
|
||||
if visit_date:
|
||||
formatted_date = visit_date.split('T')[0] if 'T' in visit_date else visit_date
|
||||
title += f" (visited: {formatted_date})"
|
||||
|
||||
description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content
|
||||
url = webpage_url
|
||||
|
||||
elif doc_type == "CRAWLED_URL":
|
||||
title = doc.title
|
||||
description = metadata.get('og:description', metadata.get('ogDescription', doc.content[:100] + "..." if len(doc.content) > 100 else doc.content))
|
||||
url = metadata.get('url', '')
|
||||
|
||||
else: # FILE and other types
|
||||
title = doc.title
|
||||
description = doc.content[:100] + "..." if len(doc.content) > 100 else doc.content
|
||||
url = metadata.get('url', '')
|
||||
|
||||
# Create source entry
|
||||
source = {
|
||||
"id": doc.id,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"url": url
|
||||
}
|
||||
sources_list.append(source)
|
||||
|
||||
# Create source object for this document type
|
||||
friendly_type_names = {
|
||||
"LINEAR_CONNECTOR": "Linear Issues (Selected)",
|
||||
"SLACK_CONNECTOR": "Slack (Selected)",
|
||||
"NOTION_CONNECTOR": "Notion (Selected)",
|
||||
"GITHUB_CONNECTOR": "GitHub (Selected)",
|
||||
"YOUTUBE_VIDEO": "YouTube Videos (Selected)",
|
||||
"DISCORD_CONNECTOR": "Discord (Selected)",
|
||||
"EXTENSION": "Browser Extension (Selected)",
|
||||
"CRAWLED_URL": "Web Pages (Selected)",
|
||||
"FILE": "Files (Selected)"
|
||||
}
|
||||
|
||||
source_object = {
|
||||
"id": connector_id_counter,
|
||||
"name": friendly_type_names.get(doc_type, f"{doc_type} (Selected)"),
|
||||
"type": f"USER_SELECTED_{doc_type}",
|
||||
"sources": sources_list,
|
||||
}
|
||||
source_objects.append(source_object)
|
||||
connector_id_counter += 1
|
||||
|
||||
print(f"Fetched {len(formatted_documents)} user-selected documents (with concatenated chunks) from {len(document_ids)} requested IDs")
|
||||
print(f"Created {len(source_objects)} source objects for UI display")
|
||||
|
||||
return source_objects, formatted_documents
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error fetching documents by IDs: {str(e)}")
|
||||
return [], []
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""A section in the answer outline."""
|
||||
|
|
@ -150,7 +381,8 @@ async def fetch_relevant_documents(
|
|||
state: State = None,
|
||||
top_k: int = 10,
|
||||
connector_service: ConnectorService = None,
|
||||
search_mode: SearchMode = SearchMode.CHUNKS
|
||||
search_mode: SearchMode = SearchMode.CHUNKS,
|
||||
user_selected_sources: List[Dict[str, Any]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch relevant documents for research questions using the provided connectors.
|
||||
|
|
@ -436,6 +668,21 @@ async def fetch_relevant_documents(
|
|||
deduplicated_sources = []
|
||||
seen_source_keys = set()
|
||||
|
||||
# First add user-selected sources (if any)
|
||||
if user_selected_sources:
|
||||
for source_obj in user_selected_sources:
|
||||
source_id = source_obj.get('id')
|
||||
source_type = source_obj.get('type')
|
||||
|
||||
if source_id and source_type:
|
||||
source_key = f"{source_type}_{source_id}"
|
||||
if source_key not in seen_source_keys:
|
||||
seen_source_keys.add(source_key)
|
||||
deduplicated_sources.append(source_obj)
|
||||
else:
|
||||
deduplicated_sources.append(source_obj)
|
||||
|
||||
# Then add connector sources
|
||||
for source_obj in all_sources:
|
||||
# Use combination of source ID and type as a unique identifier
|
||||
# This ensures we don't accidentally deduplicate sources from different connectors
|
||||
|
|
@ -453,7 +700,9 @@ async def fetch_relevant_documents(
|
|||
|
||||
# Stream info about deduplicated sources
|
||||
if streaming_service and writer:
|
||||
streaming_service.only_update_terminal(f"📚 Collected {len(deduplicated_sources)} unique sources across all connectors")
|
||||
user_source_count = len(user_selected_sources) if user_selected_sources else 0
|
||||
connector_source_count = len(deduplicated_sources) - user_source_count
|
||||
streaming_service.only_update_terminal(f"📚 Collected {len(deduplicated_sources)} total sources ({user_source_count} user-selected + {connector_source_count} from connectors)")
|
||||
writer({"yeild_value": streaming_service._format_annotations()})
|
||||
|
||||
# After all sources are collected and deduplicated, stream them
|
||||
|
|
@ -576,8 +825,26 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
|
|||
TOP_K = 10
|
||||
|
||||
relevant_documents = []
|
||||
user_selected_documents = []
|
||||
user_selected_sources = []
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
try:
|
||||
# First, fetch user-selected documents if any
|
||||
if configuration.document_ids_to_add_in_context:
|
||||
streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
|
||||
writer({"yeild_value": streaming_service._format_annotations()})
|
||||
|
||||
user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
|
||||
document_ids=configuration.document_ids_to_add_in_context,
|
||||
user_id=configuration.user_id,
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
if user_selected_documents:
|
||||
streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
|
||||
writer({"yeild_value": streaming_service._format_annotations()})
|
||||
|
||||
# Create connector service inside the db_session scope
|
||||
connector_service = ConnectorService(db_session, user_id=configuration.user_id)
|
||||
await connector_service.initialize_counter()
|
||||
|
|
@ -592,7 +859,8 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
|
|||
state=state,
|
||||
top_k=TOP_K,
|
||||
connector_service=connector_service,
|
||||
search_mode=configuration.search_mode
|
||||
search_mode=configuration.search_mode,
|
||||
user_selected_sources=user_selected_sources
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = f"Error fetching relevant documents: {str(e)}"
|
||||
|
|
@ -603,8 +871,14 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
|
|||
# This allows the process to continue, but the report might lack information
|
||||
relevant_documents = []
|
||||
|
||||
# Combine user-selected documents with connector-fetched documents
|
||||
all_documents = user_selected_documents + relevant_documents
|
||||
|
||||
print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
|
||||
streaming_service.only_update_terminal(f"✨ Starting to draft {len(answer_outline.answer_outline)} sections using {len(relevant_documents)} relevant document chunks")
|
||||
print(f"Added {len(user_selected_documents)} user-selected documents for all sections")
|
||||
print(f"Total documents for sections: {len(all_documents)}")
|
||||
|
||||
streaming_service.only_update_terminal(f"✨ Starting to draft {len(answer_outline.answer_outline)} sections using {len(all_documents)} total document chunks ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)")
|
||||
writer({"yeild_value": streaming_service._format_annotations()})
|
||||
|
||||
# Create tasks to process each section in parallel with the same document set
|
||||
|
|
@ -635,7 +909,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
|
|||
user_query=configuration.user_query,
|
||||
user_id=configuration.user_id,
|
||||
search_space_id=configuration.search_space_id,
|
||||
relevant_documents=relevant_documents,
|
||||
relevant_documents=all_documents, # Use combined documents
|
||||
state=state,
|
||||
writer=writer,
|
||||
sub_section_type=sub_section_type,
|
||||
|
|
@ -875,8 +1149,26 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
|
|||
TOP_K = 15
|
||||
|
||||
relevant_documents = []
|
||||
user_selected_documents = []
|
||||
user_selected_sources = []
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
try:
|
||||
# First, fetch user-selected documents if any
|
||||
if configuration.document_ids_to_add_in_context:
|
||||
streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
|
||||
writer({"yeild_value": streaming_service._format_annotations()})
|
||||
|
||||
user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
|
||||
document_ids=configuration.document_ids_to_add_in_context,
|
||||
user_id=configuration.user_id,
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
if user_selected_documents:
|
||||
streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
|
||||
writer({"yeild_value": streaming_service._format_annotations()})
|
||||
|
||||
# Create connector service inside the db_session scope
|
||||
connector_service = ConnectorService(db_session, user_id=configuration.user_id)
|
||||
await connector_service.initialize_counter()
|
||||
|
|
@ -894,7 +1186,8 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
|
|||
state=state,
|
||||
top_k=TOP_K,
|
||||
connector_service=connector_service,
|
||||
search_mode=configuration.search_mode
|
||||
search_mode=configuration.search_mode,
|
||||
user_selected_sources=user_selected_sources
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = f"Error fetching relevant documents for QNA: {str(e)}"
|
||||
|
|
@ -904,15 +1197,21 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
|
|||
# Continue with empty documents - the QNA agent will handle this gracefully
|
||||
relevant_documents = []
|
||||
|
||||
# Combine user-selected documents with connector-fetched documents
|
||||
all_documents = user_selected_documents + relevant_documents
|
||||
|
||||
print(f"Fetched {len(relevant_documents)} relevant documents for QNA")
|
||||
streaming_service.only_update_terminal(f"🧠 Generating comprehensive answer using {len(relevant_documents)} relevant sources...")
|
||||
print(f"Added {len(user_selected_documents)} user-selected documents for QNA")
|
||||
print(f"Total documents for QNA: {len(all_documents)}")
|
||||
|
||||
streaming_service.only_update_terminal(f"🧠 Generating comprehensive answer using {len(all_documents)} total sources ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)...")
|
||||
writer({"yeild_value": streaming_service._format_annotations()})
|
||||
|
||||
# Prepare configuration for the QNA agent
|
||||
qna_config = {
|
||||
"configurable": {
|
||||
"user_query": reformulated_query, # Use the reformulated query
|
||||
"relevant_documents": relevant_documents,
|
||||
"relevant_documents": all_documents, # Use combined documents
|
||||
"user_id": configuration.user_id,
|
||||
"search_space_id": configuration.search_space_id
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue