feat: Improved sub section writer agent & Chat UI

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-05-09 22:14:22 -07:00
parent 1b9d7a0d96
commit 2cee5acaa3
4 changed files with 304 additions and 240 deletions

View file

@ -14,6 +14,8 @@ from .configuration import Configuration
from .prompts import get_answer_outline_system_prompt
from .state import State
from .sub_section_writer.graph import graph as sub_section_writer_graph
from .sub_section_writer.configuration import SubSectionType
from langgraph.types import StreamWriter
@ -41,14 +43,14 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
"""
streaming_service = state.streaming_service
streaming_service.only_update_terminal("Generating answer outline...")
streaming_service.only_update_terminal("🔍 Generating answer outline...")
writer({"yeild_value": streaming_service._format_annotations()})
# Get configuration from runnable config
configuration = Configuration.from_runnable_config(config)
user_query = configuration.user_query
num_sections = configuration.num_sections
streaming_service.only_update_terminal(f"Planning research approach for query: {user_query[:100]}...")
streaming_service.only_update_terminal(f"🤔 Planning research approach for: \"{user_query[:100]}...\"")
writer({"yeild_value": streaming_service._format_annotations()})
# Initialize LLM
@ -78,7 +80,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
Your output MUST be valid JSON in exactly this format. Do not include any other text or explanation.
"""
streaming_service.only_update_terminal("Designing structured outline with AI...")
streaming_service.only_update_terminal("📝 Designing structured outline with AI...")
writer({"yeild_value": streaming_service._format_annotations()})
# Create messages for the LLM
@ -88,7 +90,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
]
# Call the LLM directly without using structured output
streaming_service.only_update_terminal("Processing answer structure...")
streaming_service.only_update_terminal("⚙️ Processing answer structure...")
writer({"yeild_value": streaming_service._format_annotations()})
response = await llm.ainvoke(messages)
@ -111,7 +113,7 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
answer_outline = AnswerOutline(**parsed_data)
total_questions = sum(len(section.questions) for section in answer_outline.answer_outline)
streaming_service.only_update_terminal(f"Successfully generated outline with {len(answer_outline.answer_outline)} sections and {total_questions} research questions")
streaming_service.only_update_terminal(f"Successfully generated outline with {len(answer_outline.answer_outline)} sections and {total_questions} research questions!")
writer({"yeild_value": streaming_service._format_annotations()})
print(f"Successfully generated answer outline with {len(answer_outline.answer_outline)} sections")
@ -121,14 +123,14 @@ async def write_answer_outline(state: State, config: RunnableConfig, writer: Str
else:
# If JSON structure not found, raise a clear error
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
streaming_service.only_update_terminal(error_message, "error")
streaming_service.only_update_terminal(f"{error_message}", "error")
writer({"yeild_value": streaming_service._format_annotations()})
raise ValueError(error_message)
except (json.JSONDecodeError, ValueError) as e:
# Log the error and re-raise it
error_message = f"Error parsing LLM response: {str(e)}"
streaming_service.only_update_terminal(error_message, "error")
streaming_service.only_update_terminal(f"{error_message}", "error")
writer({"yeild_value": streaming_service._format_annotations()})
print(f"Error parsing LLM response: {str(e)}")
@ -149,6 +151,11 @@ async def fetch_relevant_documents(
"""
Fetch relevant documents for research questions using the provided connectors.
This function searches across multiple data sources for information related to the
research questions. It provides user-friendly feedback during the search process by
displaying connector names (like "Web Search" instead of "TAVILY_API") and adding
relevant emojis to indicate the type of source being searched.
Args:
research_questions: List of research questions to find documents for
user_id: The user ID
@ -158,6 +165,7 @@ async def fetch_relevant_documents(
writer: StreamWriter for sending progress updates
state: The current state containing the streaming service
top_k: Number of top results to retrieve per connector per question
connector_service: An initialized connector service to use for searching
Returns:
List of relevant documents
@ -170,7 +178,9 @@ async def fetch_relevant_documents(
# Stream initial status update
if streaming_service and writer:
streaming_service.only_update_terminal(f"Starting research on {len(research_questions)} questions using {len(connectors_to_search)} connectors...")
connector_names = [get_connector_friendly_name(connector) for connector in connectors_to_search]
connector_names_str = ", ".join(connector_names)
streaming_service.only_update_terminal(f"🔎 Starting research on {len(research_questions)} questions using {connector_names_str} data sources")
writer({"yeild_value": streaming_service._format_annotations()})
all_raw_documents = [] # Store all raw documents
@ -179,7 +189,7 @@ async def fetch_relevant_documents(
for i, user_query in enumerate(research_questions):
# Stream question being researched
if streaming_service and writer:
streaming_service.only_update_terminal(f"Researching question {i+1}/{len(research_questions)}: {user_query[:100]}...")
streaming_service.only_update_terminal(f"🧠 Researching question {i+1}/{len(research_questions)}: \"{user_query[:100]}...\"")
writer({"yeild_value": streaming_service._format_annotations()})
# Use original research question as the query
@ -189,7 +199,9 @@ async def fetch_relevant_documents(
for connector in connectors_to_search:
# Stream connector being searched
if streaming_service and writer:
streaming_service.only_update_terminal(f"Searching {connector} for relevant information...")
connector_emoji = get_connector_emoji(connector)
friendly_name = get_connector_friendly_name(connector)
streaming_service.only_update_terminal(f"{connector_emoji} Searching {friendly_name} for relevant information...")
writer({"yeild_value": streaming_service._format_annotations()})
try:
@ -208,7 +220,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(youtube_chunks)} YouTube chunks relevant to the query")
streaming_service.only_update_terminal(f"📹 Found {len(youtube_chunks)} YouTube chunks related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "EXTENSION":
@ -226,7 +238,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(extension_chunks)} extension chunks relevant to the query")
streaming_service.only_update_terminal(f"🧩 Found {len(extension_chunks)} Browser Extension chunks related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "CRAWLED_URL":
@ -244,7 +256,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(crawled_urls_chunks)} crawled URL chunks relevant to the query")
streaming_service.only_update_terminal(f"🌐 Found {len(crawled_urls_chunks)} Web Pages chunks related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "FILE":
@ -262,7 +274,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(files_chunks)} file chunks relevant to the query")
streaming_service.only_update_terminal(f"📄 Found {len(files_chunks)} Files chunks related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
@ -281,7 +293,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(slack_chunks)} Slack messages relevant to the query")
streaming_service.only_update_terminal(f"💬 Found {len(slack_chunks)} Slack messages related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "NOTION_CONNECTOR":
@ -299,7 +311,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(notion_chunks)} Notion pages/blocks relevant to the query")
streaming_service.only_update_terminal(f"📘 Found {len(notion_chunks)} Notion pages/blocks related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "GITHUB_CONNECTOR":
@ -317,7 +329,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(github_chunks)} GitHub files/issues relevant to the query")
streaming_service.only_update_terminal(f"🐙 Found {len(github_chunks)} GitHub files/issues related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "LINEAR_CONNECTOR":
@ -335,7 +347,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(linear_chunks)} Linear issues relevant to the query")
streaming_service.only_update_terminal(f"📊 Found {len(linear_chunks)} Linear issues related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "TAVILY_API":
@ -352,7 +364,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(tavily_chunks)} web search results relevant to the query")
streaming_service.only_update_terminal(f"🔍 Found {len(tavily_chunks)} Web Search results related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
elif connector == "LINKUP_API":
@ -374,7 +386,7 @@ async def fetch_relevant_documents(
# Stream found document count
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(linkup_chunks)} Linkup chunks relevant to the query")
streaming_service.only_update_terminal(f"🔗 Found {len(linkup_chunks)} Linkup results related to your query")
writer({"yeild_value": streaming_service._format_annotations()})
@ -384,7 +396,8 @@ async def fetch_relevant_documents(
# Stream error message
if streaming_service and writer:
streaming_service.only_update_terminal(error_message, "error")
friendly_name = get_connector_friendly_name(connector)
streaming_service.only_update_terminal(f"⚠️ Error searching {friendly_name}: {str(e)}", "error")
writer({"yeild_value": streaming_service._format_annotations()})
# Continue with other connectors on error
@ -411,7 +424,7 @@ async def fetch_relevant_documents(
# Stream info about deduplicated sources
if streaming_service and writer:
streaming_service.only_update_terminal(f"Collected {len(deduplicated_sources)} unique sources across all connectors")
streaming_service.only_update_terminal(f"📚 Collected {len(deduplicated_sources)} unique sources across all connectors")
writer({"yeild_value": streaming_service._format_annotations()})
# After all sources are collected and deduplicated, stream them
@ -441,12 +454,44 @@ async def fetch_relevant_documents(
# Stream info about deduplicated documents
if streaming_service and writer:
streaming_service.only_update_terminal(f"Found {len(deduplicated_docs)} unique document chunks after deduplication")
streaming_service.only_update_terminal(f"🧹 Found {len(deduplicated_docs)} unique document chunks after removing duplicates")
writer({"yeild_value": streaming_service._format_annotations()})
# Return deduplicated documents
return deduplicated_docs
def get_connector_emoji(connector_name: str) -> str:
    """Return the emoji used to label a connector in terminal/status output.

    Falls back to a generic magnifying-glass emoji for connector types
    that have no dedicated icon.
    """
    emoji_by_connector = dict(
        YOUTUBE_VIDEO="📹",
        EXTENSION="🧩",
        CRAWLED_URL="🌐",
        FILE="📄",
        SLACK_CONNECTOR="💬",
        NOTION_CONNECTOR="📘",
        GITHUB_CONNECTOR="🐙",
        LINEAR_CONNECTOR="📊",
        TAVILY_API="🔍",
        LINKUP_API="🔗",
    )
    return emoji_by_connector.get(connector_name, "🔎")
def get_connector_friendly_name(connector_name: str) -> str:
    """Map a technical connector ID to the human-readable name shown to users.

    Unrecognized connector IDs are returned unchanged so the UI never
    displays an empty label.
    """
    friendly_by_connector = dict(
        YOUTUBE_VIDEO="YouTube",
        EXTENSION="Browser Extension",
        CRAWLED_URL="Web Pages",
        FILE="Files",
        SLACK_CONNECTOR="Slack",
        NOTION_CONNECTOR="Notion",
        GITHUB_CONNECTOR="GitHub",
        LINEAR_CONNECTOR="Linear",
        TAVILY_API="Tavily Search",
        LINKUP_API="Linkup Search",
    )
    return friendly_by_connector.get(connector_name, connector_name)
async def process_sections(state: State, config: RunnableConfig, writer: StreamWriter) -> Dict[str, Any]:
"""
Process all sections in parallel and combine the results.
@ -463,13 +508,13 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
answer_outline = state.answer_outline
streaming_service = state.streaming_service
streaming_service.only_update_terminal(f"Starting to process research sections...")
streaming_service.only_update_terminal(f"🚀 Starting to process research sections...")
writer({"yeild_value": streaming_service._format_annotations()})
print(f"Processing sections from outline: {answer_outline is not None}")
if not answer_outline:
streaming_service.only_update_terminal("Error: No answer outline was provided. Cannot generate report.", "error")
streaming_service.only_update_terminal("Error: No answer outline was provided. Cannot generate report.", "error")
writer({"yeild_value": streaming_service._format_annotations()})
return {
"final_written_report": "No answer outline was provided. Cannot generate final report."
@ -481,11 +526,11 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
all_questions.extend(section.questions)
print(f"Collected {len(all_questions)} questions from all sections")
streaming_service.only_update_terminal(f"Found {len(all_questions)} research questions across {len(answer_outline.answer_outline)} sections")
streaming_service.only_update_terminal(f"🧩 Found {len(all_questions)} research questions across {len(answer_outline.answer_outline)} sections")
writer({"yeild_value": streaming_service._format_annotations()})
# Fetch relevant documents once for all questions
streaming_service.only_update_terminal("Searching for relevant information across all connectors...")
streaming_service.only_update_terminal("🔍 Searching for relevant information across all connectors...")
writer({"yeild_value": streaming_service._format_annotations()})
if configuration.num_sections == 1:
@ -515,7 +560,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
except Exception as e:
error_message = f"Error fetching relevant documents: {str(e)}"
print(error_message)
streaming_service.only_update_terminal(error_message, "error")
streaming_service.only_update_terminal(f"{error_message}", "error")
writer({"yeild_value": streaming_service._format_annotations()})
# Log the error and continue with an empty list of documents
# This allows the process to continue, but the report might lack information
@ -523,15 +568,22 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
# Consider adding more robust error handling or reporting if needed
print(f"Fetched {len(relevant_documents)} relevant documents for all sections")
streaming_service.only_update_terminal(f"Starting to draft {len(answer_outline.answer_outline)} sections using {len(relevant_documents)} relevant document chunks")
streaming_service.only_update_terminal(f"Starting to draft {len(answer_outline.answer_outline)} sections using {len(relevant_documents)} relevant document chunks")
writer({"yeild_value": streaming_service._format_annotations()})
# Create tasks to process each section in parallel with the same document set
section_tasks = []
streaming_service.only_update_terminal("Creating processing tasks for each section...")
streaming_service.only_update_terminal("⚙️ Creating processing tasks for each section...")
writer({"yeild_value": streaming_service._format_annotations()})
for section in answer_outline.answer_outline:
for i, section in enumerate(answer_outline.answer_outline):
if i == 0:
sub_section_type = SubSectionType.START
elif i == len(answer_outline.answer_outline) - 1:
sub_section_type = SubSectionType.END
else:
sub_section_type = SubSectionType.MIDDLE
section_tasks.append(
process_section_with_documents(
section_title=section.section_title,
@ -541,19 +593,20 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
search_space_id=configuration.search_space_id,
relevant_documents=relevant_documents,
state=state,
writer=writer
writer=writer,
sub_section_type=sub_section_type
)
)
# Run all section processing tasks in parallel
print(f"Running {len(section_tasks)} section processing tasks in parallel")
streaming_service.only_update_terminal(f"Processing {len(section_tasks)} sections simultaneously...")
streaming_service.only_update_terminal(f"Processing {len(section_tasks)} sections simultaneously...")
writer({"yeild_value": streaming_service._format_annotations()})
section_results = await asyncio.gather(*section_tasks, return_exceptions=True)
# Handle any exceptions in the results
streaming_service.only_update_terminal("Combining section results into final report...")
streaming_service.only_update_terminal("🧵 Combining section results into final report...")
writer({"yeild_value": streaming_service._format_annotations()})
processed_results = []
@ -562,7 +615,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
section_title = answer_outline.answer_outline[i].section_title
error_message = f"Error processing section '{section_title}': {str(result)}"
print(error_message)
streaming_service.only_update_terminal(error_message, "error")
streaming_service.only_update_terminal(f"⚠️ {error_message}", "error")
writer({"yeild_value": streaming_service._format_annotations()})
processed_results.append(error_message)
else:
@ -580,7 +633,7 @@ async def process_sections(state: State, config: RunnableConfig, writer: StreamW
final_written_report = "\n".join(final_report)
print(f"Generated final report with {len(final_report)} parts")
streaming_service.only_update_terminal("Final research report generated successfully!")
streaming_service.only_update_terminal("🎉 Final research report generated successfully!")
writer({"yeild_value": streaming_service._format_annotations()})
if hasattr(state, 'streaming_service') and state.streaming_service:
@ -612,7 +665,8 @@ async def process_section_with_documents(
relevant_documents: List[Dict[str, Any]],
user_query: str,
state: State = None,
writer: StreamWriter = None
writer: StreamWriter = None,
sub_section_type: SubSectionType = SubSectionType.MIDDLE
) -> str:
"""
Process a single section using pre-fetched documents.
@ -635,14 +689,14 @@ async def process_section_with_documents(
# Send status update via streaming if available
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Writing section: {section_title} with {len(section_questions)} research questions")
state.streaming_service.only_update_terminal(f"📝 Writing section: \"{section_title}\" with {len(section_questions)} research questions")
writer({"yeild_value": state.streaming_service._format_annotations()})
# Fallback if no documents found
if not documents_to_use:
print(f"No relevant documents found for section: {section_title}")
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Warning: No relevant documents found for section: {section_title}", "warning")
state.streaming_service.only_update_terminal(f"⚠️ Warning: No relevant documents found for section: \"{section_title}\"", "warning")
writer({"yeild_value": state.streaming_service._format_annotations()})
documents_to_use = [
@ -657,6 +711,7 @@ async def process_section_with_documents(
"configurable": {
"sub_section_title": section_title,
"sub_section_questions": section_questions,
"sub_section_type": sub_section_type,
"user_query": user_query,
"relevant_documents": documents_to_use,
"user_id": user_id,
@ -670,7 +725,7 @@ async def process_section_with_documents(
# Invoke the sub-section writer graph
print(f"Invoking sub_section_writer for: {section_title}")
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Analyzing information and drafting content for section: {section_title}")
state.streaming_service.only_update_terminal(f"🧠 Analyzing information and drafting content for section: \"{section_title}\"")
writer({"yeild_value": state.streaming_service._format_annotations()})
result = await sub_section_writer_graph.ainvoke(sub_state, config)
@ -680,7 +735,7 @@ async def process_section_with_documents(
# Send section content update via streaming if available
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Completed writing section: {section_title}")
state.streaming_service.only_update_terminal(f"Completed writing section: \"{section_title}\"")
writer({"yeild_value": state.streaming_service._format_annotations()})
return final_answer
@ -689,7 +744,7 @@ async def process_section_with_documents(
# Send error update via streaming if available
if state and state.streaming_service and writer:
state.streaming_service.only_update_terminal(f"Error processing section '{section_title}': {str(e)}", "error")
state.streaming_service.only_update_terminal(f"❌ Error processing section \"{section_title}\": {str(e)}", "error")
writer({"yeild_value": state.streaming_service._format_annotations()})
return f"Error processing section: {section_title}. Details: {str(e)}"