feat: Document Selector in Chat.

- Still needs improvements, but let's use it first.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-06-04 21:46:50 -07:00
parent e8a19c496b
commit d7bb31f894
12 changed files with 599 additions and 67 deletions

View file

@ -21,6 +21,237 @@ from app.utils.query_service import QueryService
from langgraph.types import StreamWriter
# Additional imports for document fetching
from sqlalchemy.future import select
from app.db import Document, SearchSpace
# Display names for the per-type groups of user-selected documents.
# Types missing from this map fall back to "<TYPE> (Selected)".
_USER_SELECTED_TYPE_LABELS = {
    "LINEAR_CONNECTOR": "Linear Issues (Selected)",
    "SLACK_CONNECTOR": "Slack (Selected)",
    "NOTION_CONNECTOR": "Notion (Selected)",
    "GITHUB_CONNECTOR": "GitHub (Selected)",
    "YOUTUBE_VIDEO": "YouTube Videos (Selected)",
    "DISCORD_CONNECTOR": "Discord (Selected)",
    "EXTENSION": "Browser Extension (Selected)",
    "CRAWLED_URL": "Web Pages (Selected)",
    "FILE": "Files (Selected)",
}

# Ids of the synthetic "selected" source groups start here; the original
# comment says this is to avoid conflicts with regular connector ids.
_USER_SELECTED_CONNECTOR_ID_START = 100


def _content_preview(content: Any, limit: int = 100) -> str:
    """Return the first `limit` characters of `content`, with "..." appended if truncated.

    A missing (`None`) content is treated as an empty string so that one
    document without content cannot abort the whole fetch.
    """
    text = content or ""
    if len(text) > limit:
        return text[:limit] + "..."
    return text


def _build_user_selected_source(doc: Any, doc_type: str) -> Dict[str, Any]:
    """Build the UI source entry (id, title, description, url) for one document.

    Title, description and URL are derived from `doc.document_metadata` with
    formatting that depends on `doc_type`; unknown types use the document title,
    a content preview and the metadata `url` key.
    """
    metadata = doc.document_metadata or {}
    preview = _content_preview(doc.content)

    if doc_type == "LINEAR_CONNECTOR":
        issue_identifier = metadata.get('issue_identifier', '')
        issue_title = metadata.get('issue_title', doc.title)
        issue_state = metadata.get('state', '')
        comment_count = metadata.get('comment_count', 0)

        if issue_identifier:
            title = f"Linear: {issue_identifier} - {issue_title}"
        else:
            title = f"Linear: {issue_title}"
        if issue_state:
            title += f" ({issue_state})"

        description = preview
        if comment_count:
            description += f" | Comments: {comment_count}"

        url = f"https://linear.app/issue/{issue_identifier}" if issue_identifier else ""

    elif doc_type == "SLACK_CONNECTOR":
        channel_name = metadata.get('channel_name', 'Unknown Channel')
        channel_id = metadata.get('channel_id', '')
        message_date = metadata.get('start_date', '')

        title = f"Slack: {channel_name}"
        if message_date:
            title += f" ({message_date})"

        description = preview
        url = f"https://slack.com/app_redirect?channel={channel_id}" if channel_id else ""

    elif doc_type == "NOTION_CONNECTOR":
        page_title = metadata.get('page_title', doc.title)
        page_id = metadata.get('page_id', '')

        title = f"Notion: {page_title}"
        description = preview
        # The URL is built from the page id with its dashes removed.
        url = f"https://notion.so/{page_id.replace('-', '')}" if page_id else ""

    elif doc_type == "GITHUB_CONNECTOR":
        title = f"GitHub: {doc.title}"
        description = metadata.get('description', preview)
        url = metadata.get('url', '')

    elif doc_type == "YOUTUBE_VIDEO":
        video_id = metadata.get('video_id', '')
        channel_name = metadata.get('channel_name', '')

        title = metadata.get('video_title', doc.title)
        if channel_name:
            title += f" - {channel_name}"

        description = metadata.get('description', preview)
        url = f"https://www.youtube.com/watch?v={video_id}" if video_id else ""

    elif doc_type == "DISCORD_CONNECTOR":
        channel_name = metadata.get('channel_name', 'Unknown Channel')
        channel_id = metadata.get('channel_id', '')
        guild_id = metadata.get('guild_id', '')
        message_date = metadata.get('start_date', '')

        title = f"Discord: {channel_name}"
        if message_date:
            title += f" ({message_date})"

        description = preview

        if guild_id and channel_id:
            url = f"https://discord.com/channels/{guild_id}/{channel_id}"
        elif channel_id:
            # No guild id available: fall back to the "@me" channel URL form.
            url = f"https://discord.com/channels/@me/{channel_id}"
        else:
            url = ""

    elif doc_type == "EXTENSION":
        visit_date = metadata.get('VisitedWebPageDateWithTimeInISOString', '')

        title = metadata.get('VisitedWebPageTitle', doc.title)
        if visit_date:
            # Keep only the date part of an ISO timestamp when a "T" separator is present.
            formatted_date = visit_date.split('T')[0] if 'T' in visit_date else visit_date
            title += f" (visited: {formatted_date})"

        description = preview
        url = metadata.get('VisitedWebPageURL', '')

    elif doc_type == "CRAWLED_URL":
        title = doc.title
        # Prefer "og:description", then "ogDescription", then a content preview.
        description = metadata.get('og:description', metadata.get('ogDescription', preview))
        url = metadata.get('url', '')

    else:  # FILE and other types
        title = doc.title
        description = preview
        url = metadata.get('url', '')

    return {
        "id": doc.id,
        "title": title,
        "description": description,
        "url": url
    }


async def fetch_documents_by_ids(
    document_ids: List[int],
    user_id: str,
    db_session: AsyncSession
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Fetch user-selected documents by ID, restricted to documents the user owns.

    Ownership is enforced by joining Document to SearchSpace and filtering on
    SearchSpace.user_id, so requested IDs that do not exist or belong to another
    user are silently left out of the result.

    For every document found, the content of its chunks (ordered by chunk id) is
    joined with single spaces; a document without chunks falls back to its own
    `content`. All chunks are loaded with a single query rather than one query
    per document.

    Args:
        document_ids: List of document IDs to fetch
        user_id: The user ID to check ownership
        db_session: The database session

    Returns:
        Tuple of (source_objects, document_chunks):
        - source_objects: one entry per document type, with keys "id" (starting
          at 100), "name", "type" ("USER_SELECTED_<type>") and "sources" (the
          per-document UI entries).
        - document_chunks: one dict per document with keys "chunk_id", "content",
          "score", "document" and "source".
        Both lists are empty when `document_ids` is empty or when any exception
        occurs; the error is printed, not raised.
    """
    if not document_ids:
        return [], []

    try:
        # Kept function-local as in the original, but executed once per call
        # instead of once per document.
        from app.db import Chunk

        # Query documents with ownership check
        result = await db_session.execute(
            select(Document)
            .join(SearchSpace)
            .filter(
                Document.id.in_(document_ids),
                SearchSpace.user_id == user_id
            )
        )
        documents = result.scalars().all()

        # Load the chunks of all fetched documents in one round trip and group
        # them per document. Ordering by chunk id keeps each group's order.
        chunks_by_document = {}
        if documents:
            chunks_result = await db_session.execute(
                select(Chunk)
                .where(Chunk.document_id.in_([doc.id for doc in documents]))
                .order_by(Chunk.document_id, Chunk.id)
            )
            for chunk in chunks_result.scalars().all():
                chunks_by_document.setdefault(chunk.document_id, []).append(chunk)

        documents_by_type = {}
        formatted_documents = []

        for doc in documents:
            doc_type = doc.document_type.value if doc.document_type else "UNKNOWN"
            chunks = chunks_by_document.get(doc.id)

            if chunks:
                content = " ".join([chunk.content for chunk in chunks])
            else:
                content = doc.content

            formatted_documents.append({
                "chunk_id": f"user_doc_{doc.id}",
                "content": content,
                # Fixed constant: these documents were picked explicitly by the
                # user, not ranked by a retriever.
                "score": 0.5,
                "document": {
                    "id": doc.id,
                    "title": doc.title,
                    "document_type": doc_type,
                    "metadata": doc.document_metadata or {},
                },
                "source": doc_type
            })

            # Group by document type for source objects
            documents_by_type.setdefault(doc_type, []).append(doc)

        # One source object per document type; dict insertion order keeps the
        # groups in order of first appearance.
        source_objects = []
        for offset, (doc_type, docs) in enumerate(documents_by_type.items()):
            source_objects.append({
                "id": _USER_SELECTED_CONNECTOR_ID_START + offset,
                "name": _USER_SELECTED_TYPE_LABELS.get(doc_type, f"{doc_type} (Selected)"),
                "type": f"USER_SELECTED_{doc_type}",
                "sources": [_build_user_selected_source(doc, doc_type) for doc in docs],
            })

        print(f"Fetched {len(formatted_documents)} user-selected documents (with concatenated chunks) from {len(document_ids)} requested IDs")
        print(f"Created {len(source_objects)} source objects for UI display")

        return source_objects, formatted_documents

    except Exception as e:
        # Deliberate best-effort: a failure here must not abort the chat flow,
        # so report it and let the caller continue without selected documents.
        print(f"Error fetching documents by IDs: {str(e)}")
        return [], []
class Section(BaseModel):
"""A section in the answer outline."""
@ -150,7 +381,8 @@ async def fetch_relevant_documents(
state: State = None,
top_k: int = 10,
connector_service: ConnectorService = None,
search_mode: SearchMode = SearchMode.CHUNKS
search_mode: SearchMode = SearchMode.CHUNKS,
user_selected_sources: List[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
"""
Fetch relevant documents for research questions using the provided connectors.
@ -436,6 +668,21 @@ async def fetch_relevant_documents(
deduplicated_sources = []
seen_source_keys = set()
# First add user-selected sources (if any)
if user_selected_sources:
for source_obj in user_selected_sources:
source_id = source_obj.get('id')
source_type = source_obj.get('type')
if source_id and source_type:
source_key = f"{source_type}_{source_id}"
if source_key not in seen_source_keys:
seen_source_keys.add(source_key)
deduplicated_sources.append(source_obj)
else:
deduplicated_sources.append(source_obj)
# Then add connector sources
for source_obj in all_sources:
# Use combination of source ID and type as a unique identifier
# This ensures we don't accidentally deduplicate sources from different connectors
@ -453,7 +700,9 @@ async def fetch_relevant_documents(
# Stream info about deduplicated sources
if streaming_service and writer:
streaming_service.only_update_terminal(f"📚 Collected {len(deduplicated_sources)} unique sources across all connectors")
user_source_count = len(user_selected_sources) if user_selected_sources else 0
connector_source_count = len(deduplicated_sources) - user_source_count
streaming_service.only_update_terminal(f"📚 Collected {len(deduplicated_sources)} total sources ({user_source_count} user-selected + {connector_source_count} from connectors)")
writer({"yeild_value": streaming_service._format_annotations()})
# After all sources are collected and deduplicated, stream them
@ -576,8 +825,26 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
TOP_K = 10
relevant_documents = []
user_selected_documents = []
user_selected_sources = []
async with async_session_maker() as db_session:
try:
# First, fetch user-selected documents if any
if configuration.document_ids_to_add_in_context:
streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
writer({"yeild_value": streaming_service._format_annotations()})
user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
document_ids=configuration.document_ids_to_add_in_context,
user_id=configuration.user_id,
db_session=db_session
)
if user_selected_documents:
streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
writer({"yeild_value": streaming_service._format_annotations()})
# Create connector service inside the db_session scope
connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
@ -592,7 +859,8 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
state=state,
top_k=TOP_K,
connector_service=connector_service,
search_mode=configuration.search_mode
search_mode=configuration.search_mode,
user_selected_sources=user_selected_sources
)
except Exception as e:
error_message = f"Error fetching relevant documents: {str(e)}"
@ -603,8 +871,14 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
# This allows the process to continue, but the report might lack information
relevant_documents = []
# Combine user-selected documents with connector-fetched documents
all_documents = user_selected_documents + relevant_documents
print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
streaming_service.only_update_terminal(f"✨ Starting to draft {len(answer_outline.answer_outline)} sections using {len(relevant_documents)} relevant document chunks")
print(f"Added {len(user_selected_documents)} user-selected documents for all sections")
print(f"Total documents for sections: {len(all_documents)}")
streaming_service.only_update_terminal(f"✨ Starting to draft {len(answer_outline.answer_outline)} sections using {len(all_documents)} total document chunks ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)")
writer({"yeild_value": streaming_service._format_annotations()})
# Create tasks to process each section in parallel with the same document set
@ -635,7 +909,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
user_query=configuration.user_query,
user_id=configuration.user_id,
search_space_id=configuration.search_space_id,
relevant_documents=relevant_documents,
relevant_documents=all_documents, # Use combined documents
state=state,
writer=writer,
sub_section_type=sub_section_type,
@ -875,8 +1149,26 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
TOP_K = 15
relevant_documents = []
user_selected_documents = []
user_selected_sources = []
async with async_session_maker() as db_session:
try:
# First, fetch user-selected documents if any
if configuration.document_ids_to_add_in_context:
streaming_service.only_update_terminal(f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents...")
writer({"yeild_value": streaming_service._format_annotations()})
user_selected_sources, user_selected_documents = await fetch_documents_by_ids(
document_ids=configuration.document_ids_to_add_in_context,
user_id=configuration.user_id,
db_session=db_session
)
if user_selected_documents:
streaming_service.only_update_terminal(f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context")
writer({"yeild_value": streaming_service._format_annotations()})
# Create connector service inside the db_session scope
connector_service = ConnectorService(db_session, user_id=configuration.user_id)
await connector_service.initialize_counter()
@ -894,7 +1186,8 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
state=state,
top_k=TOP_K,
connector_service=connector_service,
search_mode=configuration.search_mode
search_mode=configuration.search_mode,
user_selected_sources=user_selected_sources
)
except Exception as e:
error_message = f"Error fetching relevant documents for QNA: {str(e)}"
@ -904,15 +1197,21 @@ async def handle_qna_workflow(state: State, config: RunnableConfig, writer: Stre
# Continue with empty documents - the QNA agent will handle this gracefully
relevant_documents = []
# Combine user-selected documents with connector-fetched documents
all_documents = user_selected_documents + relevant_documents
print(f"Fetched {len(relevant_documents)} relevant documents for QNA")
streaming_service.only_update_terminal(f"🧠 Generating comprehensive answer using {len(relevant_documents)} relevant sources...")
print(f"Added {len(user_selected_documents)} user-selected documents for QNA")
print(f"Total documents for QNA: {len(all_documents)}")
streaming_service.only_update_terminal(f"🧠 Generating comprehensive answer using {len(all_documents)} total sources ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)...")
writer({"yeild_value": streaming_service._format_annotations()})
# Prepare configuration for the QNA agent
qna_config = {
"configurable": {
"user_query": reformulated_query, # Use the reformulated query
"relevant_documents": relevant_documents,
"relevant_documents": all_documents, # Use combined documents
"user_id": configuration.user_id,
"search_space_id": configuration.search_space_id
}