diff --git a/surfsense_backend/alembic/versions/49_migrate_old_chats_to_new_chat.py b/surfsense_backend/alembic/versions/49_migrate_old_chats_to_new_chat.py new file mode 100644 index 000000000..2d188b4e1 --- /dev/null +++ b/surfsense_backend/alembic/versions/49_migrate_old_chats_to_new_chat.py @@ -0,0 +1,216 @@ +"""Migrate old chats to new_chat_threads and remove old tables + +Revision ID: 49 +Revises: 48 +Create Date: 2025-12-21 + +This migration: +1. Migrates data from old 'chats' table to 'new_chat_threads' and 'new_chat_messages' +2. Drops the 'podcasts' table (podcast data is not migrated as per user request) +3. Drops the 'chats' table +4. Removes the 'chattype' enum +""" + +import json +from collections.abc import Sequence +from datetime import datetime, timezone + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "49" +down_revision: str | None = "48" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def extract_text_content(content: str | dict | list) -> str: + """Extract plain text content from various message formats.""" + if isinstance(content, str): + return content + if isinstance(content, dict): + # Handle dict with 'text' key + if "text" in content: + return content["text"] + return str(content) + if isinstance(content, list): + # Handle list of parts (e.g., [{"type": "text", "text": "..."}]) + texts = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + texts.append(part.get("text", "")) + elif isinstance(part, str): + texts.append(part) + return "\n".join(texts) if texts else "" + return "" + + +def parse_timestamp(ts, fallback): + """Parse ISO timestamp string to datetime object.""" + if ts is None: + return fallback + if isinstance(ts, datetime): + return ts + if isinstance(ts, str): + try: + # Handle ISO format like '2025-11-26T22:43:34.399Z' + ts = ts.replace("Z", "+00:00") + return 
datetime.fromisoformat(ts) + except (ValueError, TypeError): + return fallback + return fallback + + +def upgrade() -> None: + """Migrate old chats to new_chat_threads and remove old tables.""" + connection = op.get_bind() + + # Get all old chats + old_chats = connection.execute( + sa.text(""" + SELECT id, title, messages, search_space_id, created_at + FROM chats + ORDER BY created_at ASC + """) + ).fetchall() + + print(f"[Migration 49] Found {len(old_chats)} old chats to migrate") + + migrated_count = 0 + for chat_id, title, messages_json, search_space_id, created_at in old_chats: + try: + # Parse messages JSON + if isinstance(messages_json, str): + messages = json.loads(messages_json) + else: + messages = messages_json or [] + + # Skip empty chats + if not messages: + print(f"[Migration 49] Skipping empty chat {chat_id}") + continue + + # Create new thread + result = connection.execute( + sa.text(""" + INSERT INTO new_chat_threads + (title, archived, search_space_id, created_at, updated_at) + VALUES (:title, FALSE, :search_space_id, :created_at, :created_at) + RETURNING id + """), + { + "title": title or "Migrated Chat", + "search_space_id": search_space_id, + "created_at": created_at, + }, + ) + new_thread_id = result.fetchone()[0] + + # Migrate messages - only user and assistant roles, skip SOURCES/TERMINAL_INFO + message_count = 0 + for msg in messages: + role_lower = msg.get("role", "").lower() + + # Only migrate user and assistant messages + if role_lower not in ("user", "assistant"): + continue + + # Convert to uppercase for database enum + role = role_lower.upper() + + # Extract content - handle various formats + content_raw = msg.get("content", "") + content_text = extract_text_content(content_raw) + + # Skip empty messages + if not content_text.strip(): + continue + + # Parse message timestamp + msg_created_at = parse_timestamp(msg.get("createdAt"), created_at) + + # Store content as JSONB array format for assistant-ui compatibility + content_list = 
[{"type": "text", "text": content_text}] + + # Use direct SQL with string interpolation for the enum since CAST doesn't work + # The enum value comes from trusted source (our own code), not user input + connection.execute( + sa.text(f""" + INSERT INTO new_chat_messages + (thread_id, role, content, created_at) + VALUES (:thread_id, '{role}', CAST(:content AS jsonb), :created_at) + """), + { + "thread_id": new_thread_id, + "content": json.dumps(content_list), + "created_at": msg_created_at, + }, + ) + message_count += 1 + + print( + f"[Migration 49] Migrated chat {chat_id} -> thread {new_thread_id} ({message_count} messages)" + ) + migrated_count += 1 + + except Exception as e: + print(f"[Migration 49] Error migrating chat {chat_id}: {e}") + # Re-raise to abort migration - we don't want partial data + raise + + print(f"[Migration 49] Successfully migrated {migrated_count} chats") + + # Drop podcasts table (FK references chats, so drop first) + print("[Migration 49] Dropping podcasts table...") + op.drop_table("podcasts") + + # Drop chats table + print("[Migration 49] Dropping chats table...") + op.drop_table("chats") + + # Drop chattype enum + print("[Migration 49] Dropping chattype enum...") + op.execute(sa.text("DROP TYPE IF EXISTS chattype")) + + print("[Migration 49] Migration complete!") + + +def downgrade() -> None: + """Recreate old tables (data cannot be restored).""" + # Recreate chattype enum + op.execute( + sa.text(""" + CREATE TYPE chattype AS ENUM ('QNA') + """) + ) + + # Recreate chats table + op.create_table( + "chats", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column("type", sa.Enum("QNA", name="chattype"), nullable=False), + sa.Column("title", sa.String(), nullable=False, index=True), + sa.Column("initial_connectors", sa.ARRAY(sa.String()), nullable=True), + sa.Column("messages", sa.JSON(), nullable=False), + sa.Column("state_version", sa.BigInteger(), nullable=False, default=1), + sa.Column("search_space_id", sa.Integer(), 
sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.func.now()), + ) + + # Recreate podcasts table + op.create_table( + "podcasts", + sa.Column("id", sa.Integer(), primary_key=True, index=True), + sa.Column("title", sa.String(), nullable=False, index=True), + sa.Column("podcast_transcript", sa.JSON(), nullable=False, server_default="{}"), + sa.Column("file_location", sa.String(500), nullable=False, server_default=""), + sa.Column("chat_id", sa.Integer(), sa.ForeignKey("chats.id", ondelete="CASCADE"), nullable=True), + sa.Column("chat_state_version", sa.BigInteger(), nullable=True), + sa.Column("search_space_id", sa.Integer(), sa.ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False), + sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.func.now()), + ) + + print("[Migration 49 Downgrade] Tables recreated (data not restored)") + diff --git a/surfsense_backend/app/agents/researcher/__init__.py b/surfsense_backend/app/agents/researcher/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/surfsense_backend/app/agents/researcher/configuration.py b/surfsense_backend/app/agents/researcher/configuration.py deleted file mode 100644 index c89592c65..000000000 --- a/surfsense_backend/app/agents/researcher/configuration.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Define the configurable parameters for the agent.""" - -from __future__ import annotations - -from dataclasses import dataclass, fields - -from langchain_core.runnables import RunnableConfig - - -@dataclass(kw_only=True) -class Configuration: - """The configuration for the agent.""" - - # Input parameters provided at invocation - user_query: str - connectors_to_search: list[str] - user_id: str - search_space_id: int - document_ids_to_add_in_context: list[int] - language: str | None = None - top_k: int = 10 - - @classmethod - def 
from_runnable_config( - cls, config: RunnableConfig | None = None - ) -> Configuration: - """Create a Configuration instance from a RunnableConfig object.""" - configurable = (config.get("configurable") or {}) if config else {} - _fields = {f.name for f in fields(cls) if f.init} - return cls(**{k: v for k, v in configurable.items() if k in _fields}) diff --git a/surfsense_backend/app/agents/researcher/graph.py b/surfsense_backend/app/agents/researcher/graph.py deleted file mode 100644 index be2a1cff5..000000000 --- a/surfsense_backend/app/agents/researcher/graph.py +++ /dev/null @@ -1,47 +0,0 @@ -from langgraph.graph import StateGraph - -from .configuration import Configuration -from .nodes import ( - generate_further_questions, - handle_qna_workflow, - reformulate_user_query, -) -from .state import State - - -def build_graph(): - """ - Build and return the LangGraph workflow. - - This function constructs the researcher agent graph for Q&A workflow. - The workflow follows a simple path: - 1. Reformulate user query based on chat history - 2. Handle QNA workflow (fetch documents and generate answer) - 3. 
Generate follow-up questions - - Returns: - A compiled LangGraph workflow - """ - # Define a new graph with state class - workflow = StateGraph(State, config_schema=Configuration) - - # Add nodes to the graph - workflow.add_node("reformulate_user_query", reformulate_user_query) - workflow.add_node("handle_qna_workflow", handle_qna_workflow) - workflow.add_node("generate_further_questions", generate_further_questions) - - # Define the edges - simple linear flow for QNA - workflow.add_edge("__start__", "reformulate_user_query") - workflow.add_edge("reformulate_user_query", "handle_qna_workflow") - workflow.add_edge("handle_qna_workflow", "generate_further_questions") - workflow.add_edge("generate_further_questions", "__end__") - - # Compile the workflow into an executable graph - graph = workflow.compile() - graph.name = "Surfsense Researcher" # This defines the custom name in LangSmith - - return graph - - -# Compile the graph once when the module is loaded -graph = build_graph() diff --git a/surfsense_backend/app/agents/researcher/nodes.py b/surfsense_backend/app/agents/researcher/nodes.py deleted file mode 100644 index b16d4f0c1..000000000 --- a/surfsense_backend/app/agents/researcher/nodes.py +++ /dev/null @@ -1,1785 +0,0 @@ -import json -import logging -import traceback -from datetime import UTC, datetime, timedelta -from typing import Any - -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.runnables import RunnableConfig -from langgraph.types import StreamWriter -from sqlalchemy.ext.asyncio import AsyncSession - -# Additional imports for document fetching -from sqlalchemy.future import select - -from app.db import Document -from app.services.connector_service import ConnectorService -from app.services.query_service import QueryService - -from .configuration import Configuration -from .prompts import get_further_questions_system_prompt -from .qna_agent.graph import graph as qna_agent_graph -from .state import State -from 
.utils import get_connector_emoji, get_connector_friendly_name - -# Time filter constants - hardcoded 2 year time range for now -DEFAULT_TIME_FILTER_YEARS = 2 - - -def extract_sources_from_documents( - all_documents: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """ - Extract sources from **document-grouped** results and group them by document type. - - Args: - all_documents: List of document-grouped results from user-selected documents and connector-fetched documents - - Returns: - List of source objects grouped by type for streaming - """ - # Group sources by their source type - documents_by_type = {} - - for doc in all_documents: - document_info = doc.get("document", {}) or {} - source_type = doc.get("source", "UNKNOWN") - document_type = document_info.get("document_type", source_type) or source_type - group_type = document_type if document_type != "UNKNOWN" else source_type - if group_type not in documents_by_type: - documents_by_type[group_type] = [] - documents_by_type[group_type].append(doc) - - # Create source objects for each document type - source_objects = [] - for doc_type, docs in documents_by_type.items(): - sources_list = [] - - for doc in docs: - document_info = doc.get("document", {}) - metadata = document_info.get("metadata", {}) - url = ( - metadata.get("url") - or metadata.get("source") - or metadata.get("page_url") - or metadata.get("VisitedWebPageURL") - or "" - ) - - # Each chunk becomes a source entry so citations like [citation:] resolve in UI. - for chunk in doc.get("chunks", []) or []: - chunk_id = chunk.get("chunk_id") - chunk_content = (chunk.get("content") or "").strip() - description = ( - chunk_content - if len(chunk_content) <= 240 - else chunk_content[:240] + "..." 
- ) - sources_list.append( - { - "id": chunk_id, - "title": document_info.get("title", "Untitled Document"), - "description": description, - "url": url, - } - ) - - # Create group object - group_name = ( - get_connector_friendly_name(doc_type) - if doc_type != "UNKNOWN" - else "Unknown Sources" - ) - - source_object = { - "id": len(source_objects) + 1, - "name": group_name, - "type": doc_type, - "sources": sources_list, - } - - source_objects.append(source_object) - - return source_objects - - -async def fetch_documents_by_ids( - document_ids: list[int], search_space_id: int, db_session: AsyncSession -) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: - """ - Fetch documents by their IDs within a search space. - - This function ensures that only documents belonging to the search space are fetched. - It fetches full documents and returns their chunks individually. - Also creates source objects for UI display, grouped by document type. - - Args: - document_ids: List of document IDs to fetch - search_space_id: The search space ID to filter by - db_session: The database session - - Returns: - Tuple of (source_objects, document_chunks) - similar to ConnectorService pattern - """ - if not document_ids: - return [], [] - - try: - # Query documents filtered by search space - result = await db_session.execute( - select(Document).filter( - Document.id.in_(document_ids), - Document.search_space_id == search_space_id, - ) - ) - documents = result.scalars().all() - - # Group documents by type for source object creation - documents_by_type: dict[str, list[Document]] = {} - formatted_documents: list[dict[str, Any]] = [] - - from app.db import Chunk - - for doc in documents: - # Fetch associated chunks for this document - chunks_query = ( - select(Chunk).where(Chunk.document_id == doc.id).order_by(Chunk.id) - ) - chunks_result = await db_session.execute(chunks_query) - chunks = chunks_result.scalars().all() - - doc_type = doc.document_type.value if doc.document_type else 
"UNKNOWN" - documents_by_type.setdefault(doc_type, []).append(doc) - - doc_group = { - "document_id": doc.id, - "content": "\n\n".join(c.content for c in chunks) - if chunks - else (doc.content or ""), - "score": 0.5, # High score since user explicitly selected these - "chunks": [{"chunk_id": c.id, "content": c.content} for c in chunks] - if chunks - else [], - "document": { - "id": doc.id, - "title": doc.title, - "document_type": doc_type, - "metadata": doc.document_metadata or {}, - }, - "source": doc_type, - } - formatted_documents.append(doc_group) - - # Create source objects for each document type (similar to ConnectorService) - source_objects = [] - connector_id_counter = 100 - - for doc_type, docs in documents_by_type.items(): - sources_list = [] - - for doc in docs: - metadata = doc.document_metadata or {} - - # Create type-specific source formatting (similar to ConnectorService) - if doc_type == "LINEAR_CONNECTOR": - # Extract Linear-specific metadata - issue_identifier = metadata.get("issue_identifier", "") - issue_title = metadata.get("issue_title", doc.title) - issue_state = metadata.get("state", "") - comment_count = metadata.get("comment_count", 0) - - # Create a more descriptive title for Linear issues - title = ( - f"Linear: {issue_identifier} - {issue_title}" - if issue_identifier - else f"Linear: {issue_title}" - ) - if issue_state: - title += f" ({issue_state})" - - # Create description - description = doc.content - if comment_count: - description += f" | Comments: {comment_count}" - - # Create URL - url = ( - f"https://linear.app/issue/{issue_identifier}" - if issue_identifier - else "" - ) - - elif doc_type == "SLACK_CONNECTOR": - # Extract Slack-specific metadata - channel_name = metadata.get("channel_name", "Unknown Channel") - channel_id = metadata.get("channel_id", "") - message_date = metadata.get("start_date", "") - - title = f"Slack: {channel_name}" - if message_date: - title += f" ({message_date})" - - description = doc.content - url = 
( - f"https://slack.com/app_redirect?channel={channel_id}" - if channel_id - else "" - ) - - elif doc_type == "NOTION_CONNECTOR": - # Extract Notion-specific metadata - page_title = metadata.get("page_title", doc.title) - page_id = metadata.get("page_id", "") - - title = f"Notion: {page_title}" - description = doc.content - url = ( - f"https://notion.so/{page_id.replace('-', '')}" - if page_id - else "" - ) - - elif doc_type == "GITHUB_CONNECTOR": - title = f"GitHub: {doc.title}" - description = metadata.get( - "description", - (doc.content), - ) - url = metadata.get("url", "") - - elif doc_type == "YOUTUBE_VIDEO": - # Extract YouTube-specific metadata - video_title = metadata.get("video_title", doc.title) - video_id = metadata.get("video_id", "") - channel_name = metadata.get("channel_name", "") - - title = video_title - if channel_name: - title += f" - {channel_name}" - - description = metadata.get( - "description", - (doc.content), - ) - url = ( - f"https://www.youtube.com/watch?v={video_id}" - if video_id - else "" - ) - - elif doc_type == "DISCORD_CONNECTOR": - # Extract Discord-specific metadata - channel_name = metadata.get("channel_name", "Unknown Channel") - channel_id = metadata.get("channel_id", "") - guild_id = metadata.get("guild_id", "") - message_date = metadata.get("start_date", "") - - title = f"Discord: {channel_name}" - if message_date: - title += f" ({message_date})" - - description = doc.content - - if guild_id and channel_id: - url = f"https://discord.com/channels/{guild_id}/{channel_id}" - elif channel_id: - url = f"https://discord.com/channels/@me/{channel_id}" - else: - url = "" - - elif doc_type == "JIRA_CONNECTOR": - # Extract Jira-specific metadata - issue_key = metadata.get("issue_key", "Unknown Issue") - issue_title = metadata.get("issue_title", "Untitled Issue") - status = metadata.get("status", "") - priority = metadata.get("priority", "") - issue_type = metadata.get("issue_type", "") - - title = f"Jira: {issue_key} - {issue_title}" 
- if status: - title += f" ({status})" - - description = doc.content - if priority: - description += f" | Priority: {priority}" - if issue_type: - description += f" | Type: {issue_type}" - - # Construct Jira URL if we have the base URL - base_url = metadata.get("base_url", "") - if base_url and issue_key: - url = f"{base_url}/browse/{issue_key}" - else: - url = "" - - elif doc_type == "GOOGLE_CALENDAR_CONNECTOR": - # Extract Google Calendar-specific metadata - event_id = metadata.get("event_id", "Unknown Event") - event_summary = metadata.get("event_summary", "Untitled Event") - calendar_id = metadata.get("calendar_id", "") - start_time = metadata.get("start_time", "") - location = metadata.get("location", "") - - title = f"Calendar: {event_summary}" - if start_time: - # Format the start time for display - try: - if "T" in start_time: - from datetime import datetime - - start_dt = datetime.fromisoformat( - start_time.replace("Z", "+00:00") - ) - formatted_time = start_dt.strftime("%Y-%m-%d %H:%M") - title += f" ({formatted_time})" - else: - title += f" ({start_time})" - except Exception: - title += f" ({start_time})" - - elif doc_type == "AIRTABLE_CONNECTOR": - # Extract Airtable-specific metadata - base_name = metadata.get("base_name", "Unknown Base") - table_name = metadata.get("table_name", "Unknown Table") - record_id = metadata.get("record_id", "Unknown Record") - created_time = metadata.get("created_time", "") - - title = f"Airtable: {base_name} - {table_name}" - if record_id: - title += f" (Record: {record_id[:8]}...)" - if created_time: - # Format the created time for display - try: - if "T" in created_time: - from datetime import datetime - - created_dt = datetime.fromisoformat( - created_time.replace("Z", "+00:00") - ) - formatted_time = created_dt.strftime("%Y-%m-%d %H:%M") - title += f" - {formatted_time}" - except Exception: - pass - - description = doc.content - if location: - description += f" | Location: {location}" - if calendar_id and calendar_id 
!= "primary": - description += f" | Calendar: {calendar_id}" - - # Construct Google Calendar URL - if event_id: - url = ( - f"https://calendar.google.com/calendar/event?eid={event_id}" - ) - else: - url = "" - - elif doc_type == "LUMA_CONNECTOR": - # Extract Luma-specific metadata - event_id = metadata.get("event_id", "") - event_name = metadata.get("event_name", "Untitled Event") - event_url = metadata.get("event_url", "") - start_time = metadata.get("start_time", "") - location_name = metadata.get("location_name", "") - meeting_url = metadata.get("meeting_url", "") - - title = f"Luma: {event_name}" - if start_time: - # Format the start time for display - try: - if "T" in start_time: - from datetime import datetime - - start_dt = datetime.fromisoformat( - start_time.replace("Z", "+00:00") - ) - formatted_time = start_dt.strftime("%Y-%m-%d %H:%M") - title += f" ({formatted_time})" - except Exception: - pass - - description = doc.content - - if location_name: - description += f" | Venue: {location_name}" - elif meeting_url: - description += " | Online Event" - - url = event_url if event_url else "" - - elif doc_type == "EXTENSION": - # Extract Extension-specific metadata - webpage_title = metadata.get("VisitedWebPageTitle", doc.title) - webpage_url = metadata.get("VisitedWebPageURL", "") - visit_date = metadata.get( - "VisitedWebPageDateWithTimeInISOString", "" - ) - - title = webpage_title - if visit_date: - formatted_date = ( - visit_date.split("T")[0] - if "T" in visit_date - else visit_date - ) - title += f" (visited: {formatted_date})" - - description = doc.content - url = webpage_url - - elif doc_type == "CRAWLED_URL": - title = doc.title - description = metadata.get( - "og:description", - metadata.get( - "ogDescription", - (doc.content), - ), - ) - url = metadata.get("url", "") - - elif doc_type == "ELASTICSEARCH_CONNECTOR": - # Prefer explicit title in metadata/source, otherwise fallback to doc.title - es_title = ( - metadata.get("title") - or 
metadata.get("es_title") - or doc.title - or f"Elasticsearch: {metadata.get('elasticsearch_index', '')}" - ) - title = es_title - description = metadata.get("description") or ( - doc.content[:100] + "..." - if len(doc.content) > 100 - else doc.content - ) - # If a link or index info is stored, surface it - url = metadata.get("url", "") or metadata.get( - "elasticsearch_index", "" - ) - - else: # FILE and other types - title = doc.title - description = doc.content - - url = metadata.get("url", "") - - # Create source entry - source = { - "id": doc.id, - "title": title, - "description": description, - "url": url, - } - sources_list.append(source) - - # Create source object for this document type - friendly_type_names = { - "LINEAR_CONNECTOR": "Linear Issues (Selected)", - "SLACK_CONNECTOR": "Slack (Selected)", - "NOTION_CONNECTOR": "Notion (Selected)", - "GITHUB_CONNECTOR": "GitHub (Selected)", - "ELASTICSEARCH_CONNECTOR": "Elasticsearch (Selected)", - "YOUTUBE_VIDEO": "YouTube Videos (Selected)", - "DISCORD_CONNECTOR": "Discord (Selected)", - "JIRA_CONNECTOR": "Jira Issues (Selected)", - "EXTENSION": "Browser Extension (Selected)", - "CRAWLED_URL": "Web Pages (Selected)", - "FILE": "Files (Selected)", - "GOOGLE_CALENDAR_CONNECTOR": "Google Calendar (Selected)", - "GOOGLE_GMAIL_CONNECTOR": "Google Gmail (Selected)", - "CONFLUENCE_CONNECTOR": "Confluence (Selected)", - "CLICKUP_CONNECTOR": "ClickUp (Selected)", - "AIRTABLE_CONNECTOR": "Airtable (Selected)", - "LUMA_CONNECTOR": "Luma Events (Selected)", - "NOTE": "Notes (Selected)", - } - - source_object = { - "id": connector_id_counter, - "name": friendly_type_names.get(doc_type, f"{doc_type} (Selected)"), - "type": f"USER_SELECTED_{doc_type}", - "sources": sources_list, - } - source_objects.append(source_object) - connector_id_counter += 1 - - print( - f"Fetched {len(formatted_documents)} user-selected chunks from {len(document_ids)} requested document IDs" - ) - print(f"Created {len(source_objects)} source objects 
for UI display") - - return source_objects, formatted_documents - - except Exception as e: - print(f"Error fetching documents by IDs: {e!s}") - return [], [] - - -async def fetch_relevant_documents( - research_questions: list[str], - search_space_id: int, - db_session: AsyncSession, - connectors_to_search: list[str], - writer: StreamWriter = None, - state: State = None, - top_k: int = 10, - connector_service: ConnectorService = None, - user_selected_sources: list[dict[str, Any]] | None = None, - start_date: datetime | None = None, - end_date: datetime | None = None, -) -> list[dict[str, Any]]: - """ - Fetch relevant documents for research questions using the provided connectors. - - This function searches across multiple data sources for information related to the - research questions. It provides user-friendly feedback during the search process by - displaying connector names (like "Web Search" instead of "TAVILY_API") and adding - relevant emojis to indicate the type of source being searched. - - Uses combined chunk-level and document-level hybrid search with RRF fusion. 
- - Args: - research_questions: List of research questions to find documents for - search_space_id: The search space ID - db_session: The database session - connectors_to_search: List of connectors to search - writer: StreamWriter for sending progress updates - state: The current state containing the streaming service - top_k: Number of top results to retrieve per connector per question - connector_service: An initialized connector service to use for searching - user_selected_sources: Optional list of user-selected source objects - start_date: Optional start date for filtering documents by updated_at - end_date: Optional end date for filtering documents by updated_at - - Returns: - List of relevant documents - """ - # Initialize services - # connector_service = ConnectorService(db_session) - - # Only use streaming if both writer and state are provided - streaming_service = state.streaming_service if state is not None else None - - # Handle case when no connectors are selected - if not connectors_to_search or len(connectors_to_search) == 0: - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - "📹 No data sources selected. Research will be generated using general knowledge and any user-selected documents." - ) - } - ) - print("No connectors selected for research. 
Returning empty document list.") - return [] # Return empty list gracefully - - # Stream initial status update - if streaming_service and writer: - connector_names = [ - get_connector_friendly_name(connector) for connector in connectors_to_search - ] - connector_names_str = ", ".join(connector_names) - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🔎 Starting research on {len(research_questions)} questions using {connector_names_str} data sources" - ) - } - ) - - all_raw_documents = [] # Store all raw documents - all_sources = [] # Store all sources - - for i, user_query in enumerate(research_questions): - # Stream question being researched - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f'🧠 Researching question {i + 1}/{len(research_questions)}: "{user_query[:100]}..."' - ) - } - ) - - # Use original research question as the query - reformulated_query = user_query - - # Process each selected connector - for connector in connectors_to_search: - # Stream connector being searched - if streaming_service and writer: - connector_emoji = get_connector_emoji(connector) - friendly_name = get_connector_friendly_name(connector) - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"{connector_emoji} Searching {friendly_name} for relevant information..." 
- ) - } - ) - - try: - if connector == "YOUTUBE_VIDEO": - ( - source_object, - youtube_chunks, - ) = await connector_service.search_youtube( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(youtube_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📹 Found {len(youtube_chunks)} YouTube chunks related to your query" - ) - } - ) - - elif connector == "EXTENSION": - ( - source_object, - extension_chunks, - ) = await connector_service.search_extension( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(extension_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🧩 Found {len(extension_chunks)} Browser Extension chunks related to your query" - ) - } - ) - - elif connector == "CRAWLED_URL": - ( - source_object, - crawled_urls_chunks, - ) = await connector_service.search_crawled_urls( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(crawled_urls_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🌐 Found {len(crawled_urls_chunks)} Web Page chunks related to your query" - ) - } - ) - - elif connector == "FILE": - source_object, files_chunks = await 
connector_service.search_files( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(files_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📄 Found {len(files_chunks)} Files chunks related to your query" - ) - } - ) - - elif connector == "SLACK_CONNECTOR": - source_object, slack_chunks = await connector_service.search_slack( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(slack_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"💬 Found {len(slack_chunks)} Slack messages related to your query" - ) - } - ) - - elif connector == "NOTION_CONNECTOR": - ( - source_object, - notion_chunks, - ) = await connector_service.search_notion( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(notion_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📘 Found {len(notion_chunks)} Notion pages/blocks related to your query" - ) - } - ) - - elif connector == "GITHUB_CONNECTOR": - ( - source_object, - github_chunks, - ) = await connector_service.search_github( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - 
end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(github_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🐙 Found {len(github_chunks)} GitHub files/issues related to your query" - ) - } - ) - - elif connector == "LINEAR_CONNECTOR": - ( - source_object, - linear_chunks, - ) = await connector_service.search_linear( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(linear_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📊 Found {len(linear_chunks)} Linear issues related to your query" - ) - } - ) - - elif connector == "TAVILY_API": - ( - source_object, - tavily_chunks, - ) = await connector_service.search_tavily( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(tavily_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🔍 Found {len(tavily_chunks)} Web Search results related to your query" - ) - } - ) - - elif connector == "SEARXNG_API": - ( - source_object, - searx_chunks, - ) = await connector_service.search_searxng( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - ) - - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(searx_chunks) - - if streaming_service and writer: - writer( - { - "yield_value": 
streaming_service.format_terminal_info_delta( - f"🌐 Found {len(searx_chunks)} SearxNG results related to your query" - ) - } - ) - - elif connector == "LINKUP_API": - linkup_mode = "standard" - - ( - source_object, - linkup_chunks, - ) = await connector_service.search_linkup( - user_query=reformulated_query, - search_space_id=search_space_id, - mode=linkup_mode, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(linkup_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🔗 Found {len(linkup_chunks)} Linkup results related to your query" - ) - } - ) - - elif connector == "BAIDU_SEARCH_API": - ( - source_object, - baidu_chunks, - ) = await connector_service.search_baidu( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(baidu_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🇨🇳 Found {len(baidu_chunks)} Baidu Search results related to your query" - ) - } - ) - - elif connector == "DISCORD_CONNECTOR": - ( - source_object, - discord_chunks, - ) = await connector_service.search_discord( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(discord_chunks) - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🗨️ Found {len(discord_chunks)} Discord messages related to your query" - ) - } - ) - - elif connector == "JIRA_CONNECTOR": - 
source_object, jira_chunks = await connector_service.search_jira( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(jira_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🎫 Found {len(jira_chunks)} Jira issues related to your query" - ) - } - ) - elif connector == "GOOGLE_CALENDAR_CONNECTOR": - ( - source_object, - calendar_chunks, - ) = await connector_service.search_google_calendar( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(calendar_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📅 Found {len(calendar_chunks)} calendar events related to your query" - ) - } - ) - elif connector == "AIRTABLE_CONNECTOR": - ( - source_object, - airtable_chunks, - ) = await connector_service.search_airtable( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(airtable_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🗃️ Found {len(airtable_chunks)} Airtable records related to your query" - ) - } - ) - elif connector == "GOOGLE_GMAIL_CONNECTOR": - ( - source_object, - gmail_chunks, - ) = await connector_service.search_google_gmail( - 
user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(gmail_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📧 Found {len(gmail_chunks)} Gmail messages related to your query" - ) - } - ) - elif connector == "CONFLUENCE_CONNECTOR": - ( - source_object, - confluence_chunks, - ) = await connector_service.search_confluence( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(confluence_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📚 Found {len(confluence_chunks)} Confluence pages related to your query" - ) - } - ) - elif connector == "CLICKUP_CONNECTOR": - ( - source_object, - clickup_chunks, - ) = await connector_service.search_clickup( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(clickup_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📋 Found {len(clickup_chunks)} ClickUp tasks related to your query" - ) - } - ) - - elif connector == "LUMA_CONNECTOR": - ( - source_object, - luma_chunks, - ) = await connector_service.search_luma( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - 
end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(luma_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🎯 Found {len(luma_chunks)} Luma events related to your query" - ) - } - ) - - elif connector == "ELASTICSEARCH_CONNECTOR": - ( - source_object, - elasticsearch_chunks, - ) = await connector_service.search_elasticsearch( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(elasticsearch_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🔎 Found {len(elasticsearch_chunks)} Elasticsearch chunks related to your query" - ) - } - ) - - elif connector == "BOOKSTACK_CONNECTOR": - ( - source_object, - bookstack_chunks, - ) = await connector_service.search_bookstack( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - all_sources.append(source_object) - all_raw_documents.extend(bookstack_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📚 Found {len(bookstack_chunks)} BookStack pages related to your query" - ) - } - ) - - elif connector == "NOTE": - ( - source_object, - notes_chunks, - ) = await connector_service.search_notes( - user_query=reformulated_query, - search_space_id=search_space_id, - top_k=top_k, - start_date=start_date, - end_date=end_date, - ) - - # Add to sources and raw documents - if source_object: - 
all_sources.append(source_object) - all_raw_documents.extend(notes_chunks) - - # Stream found document count - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📝 Found {len(notes_chunks)} Notes related to your query" - ) - } - ) - - except Exception as e: - logging.error("Error in search_airtable: %s", traceback.format_exc()) - error_message = f"Error searching connector {connector}: {e!s}" - print(error_message) - - # Stream error message - if streaming_service and writer: - friendly_name = get_connector_friendly_name(connector) - writer( - { - "yield_value": streaming_service.format_error( - f"Error searching {friendly_name}: {e!s}" - ) - } - ) - - # Continue with other connectors on error - continue - - # Deduplicate source objects by ID before streaming - deduplicated_sources = [] - seen_source_keys = set() - - # First add user-selected sources (if any) - if user_selected_sources: - for source_obj in user_selected_sources: - source_id = source_obj.get("id") - source_type = source_obj.get("type") - - if source_id and source_type: - source_key = f"{source_type}_{source_id}" - if source_key not in seen_source_keys: - seen_source_keys.add(source_key) - deduplicated_sources.append(source_obj) - else: - deduplicated_sources.append(source_obj) - - # Then add connector sources - for source_obj in all_sources: - # Use combination of source ID and type as a unique identifier - # This ensures we don't accidentally deduplicate sources from different connectors - source_id = source_obj.get("id") - source_type = source_obj.get("type") - - if source_id and source_type: - source_key = f"{source_type}_{source_id}" - current_sources_count = len(source_obj.get("sources", [])) - - if source_key not in seen_source_keys: - seen_source_keys.add(source_key) - deduplicated_sources.append(source_obj) - print( - f"Debug: Added source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: 
{current_sources_count}" - ) - else: - # Check if this source object has more sources than the existing one - existing_index = None - for i, existing_source in enumerate(deduplicated_sources): - existing_id = existing_source.get("id") - existing_type = existing_source.get("type") - if existing_id == source_id and existing_type == source_type: - existing_index = i - break - - if existing_index is not None: - existing_sources_count = len( - deduplicated_sources[existing_index].get("sources", []) - ) - if current_sources_count > existing_sources_count: - # Replace the existing source object with the new one that has more sources - deduplicated_sources[existing_index] = source_obj - print( - f"Debug: Replaced source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {existing_sources_count} -> {current_sources_count}" - ) - else: - print( - f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key}, Sources count: {current_sources_count} <= {existing_sources_count}" - ) - else: - print( - f"Debug: Skipped duplicate source - ID: {source_id}, Type: {source_type}, Key: {source_key} (couldn't find existing)" - ) - else: - # If there's no ID or type, just add it to be safe - deduplicated_sources.append(source_obj) - print( - f"Debug: Added source without ID/type - {source_obj.get('name', 'UNKNOWN')}" - ) - - # Stream info about deduplicated sources - if streaming_service and writer: - user_source_count = len(user_selected_sources) if user_selected_sources else 0 - connector_source_count = len(deduplicated_sources) - user_source_count - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📚 Collected {len(deduplicated_sources)} total sources ({user_source_count} user-selected + {connector_source_count} from connectors)" - ) - } - ) - - # Deduplicate raw documents based on document_id (preferred) or content hash - seen_doc_ids = set() - seen_content_hashes = set() - deduplicated_docs: 
list[dict[str, Any]] = [] - - for doc in all_raw_documents: - doc_id = (doc.get("document", {}) or {}).get("id") - content = doc.get("content", "") or "" - content_hash = hash(content) - - # Skip if we've seen this document_id or content before - if (doc_id and doc_id in seen_doc_ids) or content_hash in seen_content_hashes: - continue - - if doc_id: - seen_doc_ids.add(doc_id) - seen_content_hashes.add(content_hash) - deduplicated_docs.append(doc) - - # Stream info about deduplicated documents - if streaming_service and writer: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🧹 Found {len(deduplicated_docs)} unique documents after removing duplicates" - ) - } - ) - - # Return deduplicated documents - return deduplicated_docs - - -async def reformulate_user_query( - state: State, config: RunnableConfig, writer: StreamWriter -) -> dict[str, Any]: - """ - Reforms the user query based on the chat history. - """ - - configuration = Configuration.from_runnable_config(config) - user_query = configuration.user_query - chat_history_str = await QueryService.langchain_chat_history_to_str( - state.chat_history - ) - if len(state.chat_history) == 0: - reformulated_query = user_query - else: - reformulated_query = await QueryService.reformulate_query_with_chat_history( - user_query=user_query, - session=state.db_session, - search_space_id=configuration.search_space_id, - chat_history_str=chat_history_str, - ) - - return {"reformulated_query": reformulated_query} - - -async def handle_qna_workflow( - state: State, config: RunnableConfig, writer: StreamWriter -) -> dict[str, Any]: - """ - Handle the QNA research workflow. - - This node fetches relevant documents for the user query and then uses the QNA agent - to generate a comprehensive answer with proper citations. - - Returns: - Dict containing the final answer in the "final_written_report" key for consistency. 
- """ - streaming_service = state.streaming_service - configuration = Configuration.from_runnable_config(config) - - reformulated_query = state.reformulated_query - user_query = configuration.user_query - - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - "🤔 Starting Q&A research workflow..." - ) - } - ) - - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f'🔍 Researching: "{user_query[:100]}..."' - ) - } - ) - - # Fetch relevant documents for the QNA query - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - "🔍 Searching for relevant information across all connectors..." - ) - } - ) - - # Use the top_k value from configuration - top_k = configuration.top_k - - relevant_documents = [] - user_selected_documents = [] - user_selected_sources = [] - - try: - # First, fetch user-selected documents if any - if configuration.document_ids_to_add_in_context: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"📋 Including {len(configuration.document_ids_to_add_in_context)} user-selected documents..." 
- ) - } - ) - - ( - user_selected_sources, - user_selected_documents, - ) = await fetch_documents_by_ids( - document_ids=configuration.document_ids_to_add_in_context, - search_space_id=configuration.search_space_id, - db_session=state.db_session, - ) - - if user_selected_documents: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"✅ Successfully added {len(user_selected_documents)} user-selected documents to context" - ) - } - ) - - # Create connector service using state db_session - connector_service = ConnectorService( - state.db_session, search_space_id=configuration.search_space_id - ) - await connector_service.initialize_counter() - - # Use the reformulated query as a single research question - research_questions = [reformulated_query, user_query] - - # Calculate time filter: last 2 years from now (hardcoded for now) - end_date = datetime.now(UTC) - start_date = end_date - timedelta(days=DEFAULT_TIME_FILTER_YEARS * 365) - - relevant_documents = await fetch_relevant_documents( - research_questions=research_questions, - search_space_id=configuration.search_space_id, - db_session=state.db_session, - connectors_to_search=configuration.connectors_to_search, - writer=writer, - state=state, - top_k=top_k, - connector_service=connector_service, - user_selected_sources=user_selected_sources, - start_date=start_date, - end_date=end_date, - ) - except Exception as e: - error_message = f"Error fetching relevant documents for QNA: {e!s}" - print(error_message) - writer({"yield_value": streaming_service.format_error(error_message)}) - # Continue with empty documents - the QNA agent will handle this gracefully - relevant_documents = [] - - # Combine user-selected documents with connector-fetched documents - all_documents = user_selected_documents + relevant_documents - - print(f"Fetched {len(relevant_documents)} relevant documents for QNA") - print(f"Added {len(user_selected_documents)} user-selected documents for QNA") - print(f"Total 
documents for QNA: {len(all_documents)}") - - # Extract and stream sources from all_documents - if all_documents: - sources_to_stream = extract_sources_from_documents(all_documents) - writer( - {"yield_value": streaming_service.format_sources_delta(sources_to_stream)} - ) - - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"🧠 Generating comprehensive answer using {len(all_documents)} total sources ({len(user_selected_documents)} user-selected + {len(relevant_documents)} connector-found)..." - ) - } - ) - - # Prepare configuration for the QNA agent - qna_config = { - "configurable": { - "user_query": user_query, # Use the reformulated query - "reformulated_query": reformulated_query, - "relevant_documents": all_documents, # Use combined documents - "search_space_id": configuration.search_space_id, - "language": configuration.language, - } - } - - # Create the state for the QNA agent (it has a different state structure) - # Pass streaming_service so the QNA agent can stream tokens directly - qna_state = { - "db_session": state.db_session, - "chat_history": state.chat_history, - "streaming_service": streaming_service, - } - - try: - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - "✍️ Writing comprehensive answer ..." 
- ) - } - ) - - # Track streaming content for real-time updates - complete_content = "" - captured_reranked_documents = [] - - # Call the QNA agent with both custom and values streaming modes - # - "custom" captures token-by-token streams from answer_question via writer() - # - "values" captures state updates including final_answer and reranked_documents - async for stream_mode, chunk in qna_agent_graph.astream( - qna_state, qna_config, stream_mode=["custom", "values"] - ): - if stream_mode == "custom": - # Handle custom stream events (token chunks from answer_question) - if isinstance(chunk, dict) and "yield_value" in chunk: - # Forward the streamed token to the parent writer - writer(chunk) - elif stream_mode == "values" and isinstance(chunk, dict): - # Handle state value updates - # Capture the final answer from state - if chunk.get("final_answer"): - complete_content = chunk["final_answer"] - - # Capture reranked documents from QNA agent for further question generation - if chunk.get("reranked_documents"): - captured_reranked_documents = chunk["reranked_documents"] - - # Set default if no content was received - if not complete_content: - complete_content = "I couldn't find relevant information in your knowledge base to answer this question." - - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - "🎉 Q&A answer generated successfully!" 
- ) - } - ) - - # Return the final answer and captured reranked documents for further question generation - return { - "final_written_report": complete_content, - "reranked_documents": captured_reranked_documents, - } - - except Exception as e: - error_message = f"Error generating QNA answer: {e!s}" - print(error_message) - writer({"yield_value": streaming_service.format_error(error_message)}) - - return {"final_written_report": f"Error generating answer: {e!s}"} - - -async def generate_further_questions( - state: State, config: RunnableConfig, writer: StreamWriter -) -> dict[str, Any]: - """ - Generate contextually relevant follow-up questions based on chat history and available documents. - - This node takes the chat history and reranked documents from the QNA agent - and uses an LLM to generate follow-up questions that would naturally extend the conversation - and provide additional value to the user. - - Returns: - Dict containing the further questions in the "further_questions" key for state update. - """ - from app.services.llm_service import get_fast_llm - - # Get configuration and state data - configuration = Configuration.from_runnable_config(config) - chat_history = state.chat_history - search_space_id = configuration.search_space_id - streaming_service = state.streaming_service - - # Get reranked documents from the state (will be populated by sub-agents) - reranked_documents = getattr(state, "reranked_documents", None) or [] - - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - "🤔 Generating follow-up questions..." 
- ) - } - ) - - # Get search space's fast LLM - llm = await get_fast_llm(state.db_session, search_space_id) - if not llm: - error_message = f"No fast LLM configured for search space {search_space_id}" - print(error_message) - writer({"yield_value": streaming_service.format_error(error_message)}) - - # Stream empty further questions to UI - writer({"yield_value": streaming_service.format_further_questions_delta([])}) - return {"further_questions": []} - - # Format chat history for the prompt - chat_history_xml = "\n" - for message in chat_history: - if hasattr(message, "type"): - if message.type == "human": - chat_history_xml += f"{message.content}\n" - elif message.type == "ai": - chat_history_xml += f"{message.content}\n" - else: - # Handle other message types if needed - chat_history_xml += f"{message!s}\n" - chat_history_xml += "" - - # Format available documents for the prompt - documents_xml = "\n" - for i, doc in enumerate(reranked_documents): - document_info = doc.get("document", {}) - source_id = document_info.get("id", f"doc_{i}") - source_type = document_info.get("document_type", "UNKNOWN") - content = doc.get("content", "") - - documents_xml += "\n" - documents_xml += "\n" - documents_xml += f"{source_id}\n" - documents_xml += f"{source_type}\n" - documents_xml += "\n" - documents_xml += f"\n{content}\n" - documents_xml += "\n" - documents_xml += "" - - # Create the human message content - human_message_content = f""" - {chat_history_xml} - - {documents_xml} - - Based on the chat history and available documents above, generate 3-5 contextually relevant follow-up questions that would naturally extend the conversation and provide additional value to the user. Make sure the questions can be reasonably answered using the available documents or knowledge base. 
- - Your response MUST be valid JSON in exactly this format: - {{ - "further_questions": [ - {{ - "id": 0, - "question": "further qn 1" - }}, - {{ - "id": 1, - "question": "further qn 2" - }} - ] - }} - - Do not include any other text or explanation. Only return the JSON. - """ - - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - "🧠 Analyzing conversation context to suggest relevant questions..." - ) - } - ) - - # Create messages for the LLM - messages = [ - SystemMessage(content=get_further_questions_system_prompt()), - HumanMessage(content=human_message_content), - ] - - try: - # Call the LLM - response = await llm.ainvoke(messages) - - # Parse the JSON response - content = response.content - - # Find the JSON in the content - json_start = content.find("{") - json_end = content.rfind("}") + 1 - if json_start >= 0 and json_end > json_start: - json_str = content[json_start:json_end] - - # Parse the JSON string - parsed_data = json.loads(json_str) - - # Extract the further_questions array - further_questions = parsed_data.get("further_questions", []) - - writer( - { - "yield_value": streaming_service.format_terminal_info_delta( - f"✅ Generated {len(further_questions)} contextual follow-up questions!" 
- ) - } - ) - - # Stream the further questions to the UI - writer( - { - "yield_value": streaming_service.format_further_questions_delta( - further_questions - ) - } - ) - - print(f"Successfully generated {len(further_questions)} further questions") - - return {"further_questions": further_questions} - else: - # If JSON structure not found, return empty list - error_message = ( - "Could not find valid JSON in LLM response for further questions" - ) - print(error_message) - writer( - { - "yield_value": streaming_service.format_error( - f"Warning: {error_message}" - ) - } - ) - - # Stream empty further questions to UI - writer( - {"yield_value": streaming_service.format_further_questions_delta([])} - ) - return {"further_questions": []} - - except (json.JSONDecodeError, ValueError) as e: - # Log the error and return empty list - error_message = f"Error parsing further questions response: {e!s}" - print(error_message) - writer( - {"yield_value": streaming_service.format_error(f"Warning: {error_message}")} - ) - - # Stream empty further questions to UI - writer({"yield_value": streaming_service.format_further_questions_delta([])}) - return {"further_questions": []} - - except Exception as e: - # Handle any other errors - error_message = f"Error generating further questions: {e!s}" - print(error_message) - writer( - {"yield_value": streaming_service.format_error(f"Warning: {error_message}")} - ) - - # Stream empty further questions to UI - writer({"yield_value": streaming_service.format_further_questions_delta([])}) - return {"further_questions": []} diff --git a/surfsense_backend/app/agents/researcher/prompts.py b/surfsense_backend/app/agents/researcher/prompts.py deleted file mode 100644 index 794a594f2..000000000 --- a/surfsense_backend/app/agents/researcher/prompts.py +++ /dev/null @@ -1,140 +0,0 @@ -import datetime - - -def _build_language_instruction(language: str | None = None): - """Build language instruction for prompts.""" - if language: - return 
f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}." - return "" - - -def get_further_questions_system_prompt(): - return f""" -Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")} - -You are an expert research assistant specializing in generating contextually relevant follow-up questions. Your task is to analyze the chat history and available documents to suggest further questions that would naturally extend the conversation and provide additional value to the user. - - -- chat_history: Provided in XML format within tags, containing and message pairs that show the chronological conversation flow. This provides context about what has already been discussed. -- available_documents: Provided in XML format within tags, containing individual elements with and sections. Each document contains multiple `...` blocks inside . This helps understand what information is accessible for answering potential follow-up questions. - - - -A JSON object with the following structure: -{{ - "further_questions": [ - {{ - "id": 0, - "question": "further qn 1" - }}, - {{ - "id": 1, - "question": "further qn 2" - }} - ] -}} - - - -1. **Analyze Chat History:** Review the entire conversation flow to understand: - * The main topics and themes discussed - * The user's interests and areas of focus - * Questions that have been asked and answered - * Any gaps or areas that could be explored further - * The depth level of the current discussion - -2. **Evaluate Available Documents:** Consider the documents in context to identify: - * Additional information that hasn't been explored yet - * Related topics that could be of interest - * Specific details or data points that could warrant deeper investigation - * Cross-references or connections between different documents - -3. 
**Generate Relevant Follow-up Questions:** Create 3-5 further questions that: - * Are directly related to the ongoing conversation but explore new angles - * Can be reasonably answered using the available documents or knowledge base - * Progress the conversation forward rather than repeating previous topics - * Match the user's apparent level of interest and expertise - * Are specific and actionable rather than overly broad - * Consider practical applications, comparisons, deeper analysis, or related concepts - -4. **Ensure Question Quality:** Each question should: - * Be clear and well-formulated - * Provide genuine value to the user - * Be distinct from other suggested questions - * Be answerable within the current context - * Encourage meaningful exploration of the topic - -5. **Prioritize and Order:** Arrange questions by relevance and natural progression: - * Most directly related to the current discussion first - * Questions that build upon previous answers - * Questions that explore practical applications or implications - * Questions that introduce related but new concepts - -6. **Adhere Strictly to Output Format:** Ensure the final output is a valid JSON object with: - * Correct field names (`further_questions`, `id`, `question`) - * Sequential numbering starting from 0 - * Proper data types and JSON formatting - - - -Input: - -What are the best machine learning algorithms for text classification? 
-For text classification, several algorithms work well depending on your specific needs: - -**Traditional Methods:** -- **Support Vector Machines (SVM)** - Excellent for high-dimensional text data -- **Naive Bayes** - Simple, fast, and works well with small datasets -- **Logistic Regression** - Good baseline with interpretable results - -**Modern Deep Learning:** -- **Neural Networks** - More complex but can capture intricate patterns -- **Transformer models** - State-of-the-art for most text classification tasks - -The choice depends on your dataset size, computational resources, and accuracy requirements. - - - - - -101 -FILE - - -# Machine Learning for Text Classification: A Comprehensive Guide - -## Performance Comparison -Recent studies show that transformer-based models achieve 95%+ accuracy on most text classification benchmarks, while traditional methods like SVM typically achieve 85-90% accuracy. - -## Dataset Considerations -- Small datasets (< 1000 samples): Naive Bayes, SVM -- Large datasets (> 10,000 samples): Neural networks, transformers -- Imbalanced datasets: Require special handling with techniques like SMOTE - - - - -Output: -{{ - "further_questions": [ - {{ - "id": 0, - "question": "What are the key differences in performance between traditional algorithms like SVM and modern deep learning approaches for text classification?" - }}, - {{ - "id": 1, - "question": "How do you handle imbalanced datasets when training text classification models?" - }}, - {{ - "id": 2, - "question": "What preprocessing techniques are most effective for improving text classification accuracy?" - }}, - {{ - "id": 3, - "question": "Are there specific domains or use cases where certain classification algorithms perform better than others?" 
- }} - ] -}} - - -""" diff --git a/surfsense_backend/app/agents/researcher/qna_agent/__init__.py b/surfsense_backend/app/agents/researcher/qna_agent/__init__.py deleted file mode 100644 index 163b8bf63..000000000 --- a/surfsense_backend/app/agents/researcher/qna_agent/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""QnA Agent.""" - -from .graph import graph - -__all__ = ["graph"] diff --git a/surfsense_backend/app/agents/researcher/qna_agent/configuration.py b/surfsense_backend/app/agents/researcher/qna_agent/configuration.py deleted file mode 100644 index e7dd9175e..000000000 --- a/surfsense_backend/app/agents/researcher/qna_agent/configuration.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Define the configurable parameters for the agent.""" - -from __future__ import annotations - -from dataclasses import dataclass, fields -from typing import Any - -from langchain_core.runnables import RunnableConfig - - -@dataclass(kw_only=True) -class Configuration: - """The configuration for the Q&A agent.""" - - # Configuration parameters for the Q&A agent - user_query: str # The user's question to answer - reformulated_query: str # The reformulated query - relevant_documents: list[ - Any - ] # Documents provided directly to the agent for answering - search_space_id: int # Search space identifier - language: str | None = None # Language for responses - - @classmethod - def from_runnable_config( - cls, config: RunnableConfig | None = None - ) -> Configuration: - """Create a Configuration instance from a RunnableConfig object.""" - configurable = (config.get("configurable") or {}) if config else {} - _fields = {f.name for f in fields(cls) if f.init} - return cls(**{k: v for k, v in configurable.items() if k in _fields}) diff --git a/surfsense_backend/app/agents/researcher/qna_agent/default_prompts.py b/surfsense_backend/app/agents/researcher/qna_agent/default_prompts.py deleted file mode 100644 index 72ae636cb..000000000 --- 
a/surfsense_backend/app/agents/researcher/qna_agent/default_prompts.py +++ /dev/null @@ -1,201 +0,0 @@ -"""Default system prompts for Q&A agent. - -The prompt system is modular with 3 parts: -- Part 1 (Base): Core instructions for answering questions (no citations) -- Part 2 (Citations): Citation-specific instructions and formatting rules -- Part 3 (Custom): User's custom instructions (empty by default) - -Combinations: -- Part 1 only: Answers without citations -- Part 1 + Part 2: Answers with citations -- Part 1 + Part 2 + Part 3: Answers with citations and custom instructions -""" - -# Part 1: Base system prompt for answering without citations -DEFAULT_QNA_BASE_PROMPT = """Today's date: {date} -You are SurfSense, an advanced AI research assistant that provides detailed, well-researched answers to user questions by synthesizing information from multiple personal knowledge sources.{language_instruction} -{chat_history_section} - -- EXTENSION: "Web content saved via SurfSense browser extension" (personal browsing history) -- FILE: "User-uploaded documents (PDFs, Word, etc.)" (personal files) -- SLACK_CONNECTOR: "Slack conversations and shared content" (personal workspace communications) -- NOTION_CONNECTOR: "Notion workspace pages and databases" (personal knowledge management) -- YOUTUBE_VIDEO: "YouTube video transcripts and metadata" (personally saved videos) -- GITHUB_CONNECTOR: "GitHub repository content and issues" (personal repositories and interactions) -- ELASTICSEARCH_CONNECTOR: "Elasticsearch indexed documents and data" (personal Elasticsearch instances and custom data sources) -- LINEAR_CONNECTOR: "Linear project issues and discussions" (personal project management) -- JIRA_CONNECTOR: "Jira project issues, tickets, and comments" (personal project tracking) -- CONFLUENCE_CONNECTOR: "Confluence pages and comments" (personal project documentation) -- CLICKUP_CONNECTOR: "ClickUp tasks and project data" (personal task management) -- GOOGLE_CALENDAR_CONNECTOR: 
"Google Calendar events, meetings, and schedules" (personal calendar and time management) -- GOOGLE_GMAIL_CONNECTOR: "Google Gmail emails and conversations" (personal emails and communications) -- DISCORD_CONNECTOR: "Discord server conversations and shared content" (personal community communications) -- AIRTABLE_CONNECTOR: "Airtable records, tables, and database content" (personal data management and organization) -- TAVILY_API: "Tavily search API results" (personalized search results) -- LINKUP_API: "Linkup search API results" (personalized search results) -- LUMA_CONNECTOR: "Luma events" -- WEBCRAWLER_CONNECTOR: "Webpages indexed by SurfSense" (personally selected websites) - - - -1. Review the chat history to understand the conversation context and any previous topics discussed. -2. Carefully analyze all provided documents in the sections. -3. Extract relevant information that directly addresses the user's question. -4. Provide a comprehensive, detailed answer using information from the user's personal knowledge sources. -5. Structure your answer logically and conversationally, as if having a detailed discussion with the user. -6. Use your own words to synthesize and connect ideas from the documents. -7. If documents contain conflicting information, acknowledge this and present both perspectives. -8. If the user's question cannot be fully answered with the provided documents, clearly state what information is missing. -9. Provide actionable insights and practical information when relevant to the user's question. -10. Use the chat history to maintain conversation continuity and refer to previous discussions when relevant. -11. Remember that all knowledge sources contain personal information - provide answers that reflect this personal context. -12. Be conversational and engaging while maintaining accuracy. 
- - - -- Write in a clear, conversational tone suitable for detailed Q&A discussions -- Provide comprehensive answers that thoroughly address the user's question -- Use appropriate paragraphs and structure for readability -- ALWAYS provide personalized answers that reflect the user's own knowledge and context -- Be thorough and detailed in your explanations while remaining focused on the user's specific question -- If asking follow-up questions would be helpful, suggest them at the end of your response - - - -When you see a user query, focus exclusively on providing a detailed, comprehensive answer using information from the provided documents, which contain the user's personal knowledge and data. - -Make sure your response: -1. Considers the chat history for context and conversation continuity -2. Directly and thoroughly answers the user's question with personalized information from their own knowledge sources -3. Is conversational, engaging, and detailed -4. Acknowledges the personal nature of the information being provided -5. Offers follow-up suggestions when appropriate - -""" - -# Part 2: Citation-specific instructions to add citation capabilities -DEFAULT_QNA_CITATION_INSTRUCTIONS = """ - -CRITICAL CITATION REQUIREMENTS: - -1. For EVERY piece of information you include from the documents, add a citation in the format [citation:chunk_id] where chunk_id is the exact value from the `` tag inside ``. -2. Make sure ALL factual statements from the documents have proper citations. -3. If multiple chunks support the same point, include all relevant citations [citation:chunk_id1], [citation:chunk_id2]. -4. You MUST use the exact chunk_id values from the `` attributes. Do not create your own citation numbers. -5. Every citation MUST be in the format [citation:chunk_id] where chunk_id is the exact chunk id value. -6. Never modify or change the chunk_id - always use the original values exactly as provided in the chunk tags. -7. 
Do not return citations as clickable links. -8. Never format citations as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only. -9. Citations must ONLY appear as [citation:chunk_id] or [citation:chunk_id1], [citation:chunk_id2] format - never with parentheses, hyperlinks, or other formatting. -10. Never make up chunk IDs. Only use chunk_id values that are explicitly provided in the `` tags. -11. If you are unsure about a chunk_id, do not include a citation rather than guessing or making one up. - - -The documents you receive are structured like this: - - - - 42 - GITHUB_CONNECTOR - <![CDATA[Some repo / file / issue title]]> - - - - - - - - - - -IMPORTANT: You MUST cite using the chunk ids (e.g. 123, 124). Do NOT cite document_id. - - - -- Every fact from the documents must have a citation in the format [citation:chunk_id] where chunk_id is the EXACT id value from a `` tag -- Citations should appear at the end of the sentence containing the information they support -- Multiple citations should be separated by commas: [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] -- No need to return references section. Just citations in answer. -- NEVER create your own citation format - use the exact chunk_id values from the documents in the [citation:chunk_id] format -- NEVER format citations as clickable links or as markdown links like "([citation:5](https://example.com))". Always use plain square brackets only -- NEVER make up chunk IDs if you are unsure about the chunk_id. 
It is better to omit the citation than to guess - - - -CORRECT citation formats: -- [citation:5] -- [citation:chunk_id1], [citation:chunk_id2], [citation:chunk_id3] - -INCORRECT citation formats (DO NOT use): -- Using parentheses and markdown links: ([citation:5](https://github.com/MODSetter/SurfSense)) -- Using parentheses around brackets: ([citation:5]) -- Using hyperlinked text: [link to source 5](https://example.com) -- Using footnote style: ... library¹ -- Making up source IDs when source_id is unknown -- Using old IEEE format: [1], [2], [3] -- Using source types instead of IDs: [citation:GITHUB_CONNECTOR] instead of [citation:5] - - - -Based on your GitHub repositories and video content, Python's asyncio library provides tools for writing concurrent code using the async/await syntax [citation:5]. It's particularly useful for I/O-bound and high-level structured network code [citation:5]. - -The key advantage of asyncio is that it can improve performance by allowing other code to run while waiting for I/O operations to complete [citation:12]. This makes it excellent for scenarios like web scraping, API calls, database operations, or any situation where your program spends time waiting for external resources. - -However, from your video learning, it's important to note that asyncio is not suitable for CPU-bound tasks as it runs on a single thread [citation:12]. For computationally intensive work, you'd want to use multiprocessing instead. 
- - -""" - -# Part 3: User's custom instructions (empty by default, can be set by user from UI) -DEFAULT_QNA_CUSTOM_INSTRUCTIONS = "" - -# Full prompt with all parts combined (for backward compatibility and migration) -DEFAULT_QNA_CITATION_PROMPT = ( - DEFAULT_QNA_BASE_PROMPT - + DEFAULT_QNA_CITATION_INSTRUCTIONS - + DEFAULT_QNA_CUSTOM_INSTRUCTIONS -) - -DEFAULT_QNA_NO_DOCUMENTS_PROMPT = """Today's date: {date} -You are SurfSense, an advanced AI research assistant that provides helpful, detailed answers to user questions in a conversational manner.{language_instruction} -{chat_history_section} - -The user has asked a question but there are no specific documents from their personal knowledge base available to answer it. You should provide a helpful response based on: -1. The conversation history and context -2. Your general knowledge and expertise -3. Understanding of the user's needs and interests based on our conversation - - - -1. Provide a comprehensive, helpful answer to the user's question -2. Draw upon the conversation history to understand context and the user's specific needs -3. Use your general knowledge to provide accurate, detailed information -4. Be conversational and engaging, as if having a detailed discussion with the user -5. Acknowledge when you're drawing from general knowledge rather than their personal sources -6. Provide actionable insights and practical information when relevant -7. Structure your answer logically and clearly -8. If the question would benefit from personalized information from their knowledge base, gently suggest they might want to add relevant content to SurfSense -9. Be honest about limitations while still being maximally helpful -10. 
Maintain the helpful, knowledgeable tone that users expect from SurfSense - - - -- Write in a clear, conversational tone suitable for detailed Q&A discussions -- Provide comprehensive answers that thoroughly address the user's question -- Use appropriate paragraphs and structure for readability -- No citations are needed since you're using general knowledge -- Be thorough and detailed in your explanations while remaining focused on the user's specific question -- If asking follow-up questions would be helpful, suggest them at the end of your response -- When appropriate, mention that adding relevant content to their SurfSense knowledge base could provide more personalized answers - - - -When answering the user's question without access to their personal documents: -1. Review the chat history to understand conversation context and maintain continuity -2. Provide the most helpful and comprehensive answer possible using general knowledge -3. Be conversational and engaging -4. Draw upon conversation history for context -5. Be clear that you're providing general information -6. 
Suggest ways the user could get more personalized answers by expanding their knowledge base when relevant - -""" diff --git a/surfsense_backend/app/agents/researcher/qna_agent/graph.py b/surfsense_backend/app/agents/researcher/qna_agent/graph.py deleted file mode 100644 index 0d9c8bac8..000000000 --- a/surfsense_backend/app/agents/researcher/qna_agent/graph.py +++ /dev/null @@ -1,21 +0,0 @@ -from langgraph.graph import StateGraph - -from .configuration import Configuration -from .nodes import answer_question, rerank_documents -from .state import State - -# Define a new graph -workflow = StateGraph(State, config_schema=Configuration) - -# Add the nodes to the graph -workflow.add_node("rerank_documents", rerank_documents) -workflow.add_node("answer_question", answer_question) - -# Connect the nodes -workflow.add_edge("__start__", "rerank_documents") -workflow.add_edge("rerank_documents", "answer_question") -workflow.add_edge("answer_question", "__end__") - -# Compile the workflow into an executable graph -graph = workflow.compile() -graph.name = "SurfSense QnA Agent" # This defines the custom name in LangSmith diff --git a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py b/surfsense_backend/app/agents/researcher/qna_agent/nodes.py deleted file mode 100644 index 28c35a20b..000000000 --- a/surfsense_backend/app/agents/researcher/qna_agent/nodes.py +++ /dev/null @@ -1,297 +0,0 @@ -import datetime -from typing import Any - -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.runnables import RunnableConfig -from langgraph.types import StreamWriter -from sqlalchemy import select - -from app.db import SearchSpace -from app.services.reranker_service import RerankerService - -from ..utils import ( - calculate_token_count, - format_documents_section, - langchain_chat_history_to_str, - optimize_documents_for_token_limit, -) -from .configuration import Configuration -from .default_prompts import ( - DEFAULT_QNA_BASE_PROMPT, - 
DEFAULT_QNA_CITATION_INSTRUCTIONS, - DEFAULT_QNA_NO_DOCUMENTS_PROMPT, -) -from .state import State - - -def _build_language_instruction(language: str | None = None): - """Build language instruction for prompts.""" - if language: - return f"\n\nIMPORTANT: Please respond in {language} language. All your responses, explanations, and analysis should be written in {language}." - return "" - - -def _build_chat_history_section(chat_history: str | None = None): - """Build chat history section for prompts.""" - if chat_history: - return f""" - -{chat_history if chat_history else "NO CHAT HISTORY PROVIDED"} - -""" - return """ - -NO CHAT HISTORY PROVIDED - -""" - - -def _format_system_prompt( - prompt_template: str, - chat_history: str | None = None, - language: str | None = None, -): - """Format a system prompt template with dynamic values.""" - date = datetime.datetime.now().strftime("%Y-%m-%d") - language_instruction = _build_language_instruction(language) - chat_history_section = _build_chat_history_section(chat_history) - - return prompt_template.format( - date=date, - language_instruction=language_instruction, - chat_history_section=chat_history_section, - ) - - -async def rerank_documents(state: State, config: RunnableConfig) -> dict[str, Any]: - """ - Rerank the documents based on relevance to the user's question. - - This node takes the relevant documents provided in the configuration, - reranks them using the reranker service based on the user's query, - and updates the state with the reranked documents. - - Documents are now document-grouped with a `chunks` list. Reranking is done - using the concatenated `content` field, and the full structure (including - `chunks`) is preserved for proper citation formatting. - - If reranking is disabled, returns the original documents without processing. - - Returns: - Dict containing the reranked documents. 
- """ - # Get configuration and relevant documents - configuration = Configuration.from_runnable_config(config) - documents = configuration.relevant_documents - user_query = configuration.user_query - reformulated_query = configuration.reformulated_query - - # If no documents were provided, return empty list - if not documents or len(documents) == 0: - return {"reranked_documents": []} - - # Get reranker service from app config - reranker_service = RerankerService.get_reranker_instance() - - # If reranking is not enabled, sort by existing score and return - if not reranker_service: - print("Reranking is disabled. Sorting documents by existing score.") - sorted_documents = sorted( - documents, key=lambda x: x.get("score", 0), reverse=True - ) - return {"reranked_documents": sorted_documents} - - # Perform reranking - try: - # Pass documents directly to reranker - it will use: - # - "content" (concatenated chunk text) for scoring - # - "chunk_id" (primary chunk id) for matching - # The full document structure including "chunks" is preserved - reranked_docs = reranker_service.rerank_documents( - user_query + "\n" + reformulated_query, documents - ) - - # Sort by score in descending order - reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True) - - print(f"Reranked {len(reranked_docs)} documents for Q&A query: {user_query}") - - return {"reranked_documents": reranked_docs} - - except Exception as e: - print(f"Error during reranking: {e!s}") - # Fall back to original documents if reranking fails - return {"reranked_documents": documents} - - -async def answer_question( - state: State, config: RunnableConfig, writer: StreamWriter -) -> dict[str, Any]: - """ - Answer the user's question using the provided documents with real-time streaming. - - This node takes the relevant documents provided in the configuration and uses - an LLM to generate a comprehensive answer to the user's question with - proper citations. 
The citations follow [citation:chunk_id] format using chunk IDs from the - `` tags in the provided documents. If no documents are provided, it will use chat history to generate - an answer. - - The response is streamed token-by-token for real-time updates to the frontend. - - Returns: - Dict containing the final answer in the "final_answer" key. - """ - from app.services.llm_service import get_fast_llm - - # Get configuration and relevant documents from configuration - configuration = Configuration.from_runnable_config(config) - documents = state.reranked_documents - user_query = configuration.user_query - search_space_id = configuration.search_space_id - language = configuration.language - - # Get streaming service from state - streaming_service = state.streaming_service - - # Fetch search space to get QnA configuration - result = await state.db_session.execute( - select(SearchSpace).where(SearchSpace.id == search_space_id) - ) - search_space = result.scalar_one_or_none() - - if not search_space: - error_message = f"Search space {search_space_id} not found" - print(error_message) - raise RuntimeError(error_message) - - # Get QnA configuration from search space - citations_enabled = search_space.citations_enabled - custom_instructions_text = search_space.qna_custom_instructions or "" - - # Use constants for base prompt and citation instructions - qna_base_prompt = DEFAULT_QNA_BASE_PROMPT - qna_citation_instructions = ( - DEFAULT_QNA_CITATION_INSTRUCTIONS if citations_enabled else "" - ) - qna_custom_instructions = ( - f"\n\n{custom_instructions_text}\n" - if custom_instructions_text - else "" - ) - - # Get search space's fast LLM - llm = await get_fast_llm(state.db_session, search_space_id) - if not llm: - error_message = f"No fast LLM configured for search space {search_space_id}" - print(error_message) - raise RuntimeError(error_message) - - # Determine if we have documents and optimize for token limits - has_documents_initially = documents and len(documents) > 0 
- chat_history_str = langchain_chat_history_to_str(state.chat_history) - - if has_documents_initially: - # Compose the full citation prompt: base + citation instructions + custom instructions - full_citation_prompt_template = ( - qna_base_prompt + qna_citation_instructions + qna_custom_instructions - ) - - # Create base message template for token calculation (without documents) - base_human_message_template = f""" - - User's question: - - {user_query} - - - Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner. - """ - - # Use initial system prompt for token calculation - initial_system_prompt = _format_system_prompt( - full_citation_prompt_template, chat_history_str, language - ) - base_messages = [ - SystemMessage(content=initial_system_prompt), - HumanMessage(content=base_human_message_template), - ] - - # Optimize documents to fit within token limits - optimized_documents, has_optimized_documents = ( - optimize_documents_for_token_limit(documents, base_messages, llm.model) - ) - - # Update state based on optimization result - documents = optimized_documents - has_documents = has_optimized_documents - else: - has_documents = False - - # Choose system prompt based on final document availability - # With documents: use base + citation instructions + custom instructions - # Without documents: use the default no-documents prompt from constants - if has_documents: - full_citation_prompt_template = ( - qna_base_prompt + qna_citation_instructions + qna_custom_instructions - ) - system_prompt = _format_system_prompt( - full_citation_prompt_template, chat_history_str, language - ) - else: - system_prompt = _format_system_prompt( - DEFAULT_QNA_NO_DOCUMENTS_PROMPT + qna_custom_instructions, - chat_history_str, - language, - ) - - # Generate documents section - documents_text = ( - format_documents_section( - 
documents, "Source material from your personal knowledge base" - ) - if has_documents - else "" - ) - - # Create final human message content - instruction_text = ( - "Please provide a detailed, comprehensive answer to the user's question using the information from their personal knowledge sources. Make sure to cite all information appropriately and engage in a conversational manner." - if has_documents - else "Please provide a helpful answer to the user's question based on our conversation history and your general knowledge. Engage in a conversational manner." - ) - - human_message_content = f""" - {documents_text} - - User's question: - - {user_query} - - - {instruction_text} - """ - - # Create final messages for the LLM - messages_with_chat_history = [ - SystemMessage(content=system_prompt), - HumanMessage(content=human_message_content), - ] - - # Log final token count - total_tokens = calculate_token_count(messages_with_chat_history, llm.model) - print(f"Final token count: {total_tokens}") - - # Stream the LLM response token by token - final_answer = "" - - async for chunk in llm.astream(messages_with_chat_history): - # Extract the content from the chunk - if hasattr(chunk, "content") and chunk.content: - token = chunk.content - final_answer += token - - # Stream the token to the frontend via custom stream - if streaming_service: - writer({"yield_value": streaming_service.format_text_chunk(token)}) - - return {"final_answer": final_answer} diff --git a/surfsense_backend/app/agents/researcher/qna_agent/state.py b/surfsense_backend/app/agents/researcher/qna_agent/state.py deleted file mode 100644 index 4113b9286..000000000 --- a/surfsense_backend/app/agents/researcher/qna_agent/state.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Define the state structures for the agent.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.services.streaming_service 
import StreamingService - - -@dataclass -class State: - """Defines the dynamic state for the Q&A agent during execution. - - This state tracks the database session, chat history, and the outputs - generated by the agent's nodes during question answering. - See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state - for more information. - """ - - # Runtime context - db_session: AsyncSession - - # Streaming service for real-time token streaming - streaming_service: StreamingService | None = None - - chat_history: list[Any] | None = field(default_factory=list) - # OUTPUT: Populated by agent nodes - reranked_documents: list[Any] | None = None - final_answer: str | None = None diff --git a/surfsense_backend/app/agents/researcher/state.py b/surfsense_backend/app/agents/researcher/state.py deleted file mode 100644 index 90f7039be..000000000 --- a/surfsense_backend/app/agents/researcher/state.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Define the state structures for the agent.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.services.streaming_service import StreamingService - - -@dataclass -class State: - """Defines the dynamic state for the agent during execution. - - This state tracks the database session and the outputs generated by the agent's nodes. - See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state - for more information. 
- """ - - # Runtime context (not part of actual graph state) - db_session: AsyncSession - - # Streaming service - streaming_service: StreamingService - - chat_history: list[Any] | None = field(default_factory=list) - - reformulated_query: str | None = field(default=None) - further_questions: Any | None = field(default=None) - - # Temporary field to hold reranked documents from sub-agents for further question generation - reranked_documents: list[Any] | None = field(default=None) - - # OUTPUT: Populated by agent nodes - # Using field to explicitly mark as part of state - final_written_report: str | None = field(default=None) diff --git a/surfsense_backend/app/agents/researcher/utils.py b/surfsense_backend/app/agents/researcher/utils.py deleted file mode 100644 index 9cb0dcbde..000000000 --- a/surfsense_backend/app/agents/researcher/utils.py +++ /dev/null @@ -1,292 +0,0 @@ -import json -from typing import Any, NamedTuple - -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage -from litellm import get_model_info, token_counter - - -class DocumentTokenInfo(NamedTuple): - """Information about a document and its token cost.""" - - index: int - document: dict[str, Any] - formatted_content: str - token_count: int - - -def get_connector_emoji(connector_name: str) -> str: - """Get an appropriate emoji for a connector type.""" - connector_emojis = { - "YOUTUBE_VIDEO": "📹", - "EXTENSION": "🧩", - "FILE": "📄", - "SLACK_CONNECTOR": "💬", - "NOTION_CONNECTOR": "📘", - "GITHUB_CONNECTOR": "🐙", - "LINEAR_CONNECTOR": "📊", - "JIRA_CONNECTOR": "🎫", - "DISCORD_CONNECTOR": "🗨️", - "TAVILY_API": "🔍", - "LINKUP_API": "🔗", - "BAIDU_SEARCH_API": "🇨🇳", - "GOOGLE_CALENDAR_CONNECTOR": "📅", - "AIRTABLE_CONNECTOR": "🗃️", - "LUMA_CONNECTOR": "✨", - "ELASTICSEARCH_CONNECTOR": "⚡", - "WEBCRAWLER_CONNECTOR": "🌐", - "BOOKSTACK_CONNECTOR": "📚", - "NOTE": "📝", - } - return connector_emojis.get(connector_name, "🔎") - - -def get_connector_friendly_name(connector_name: str) 
-> str: - """Convert technical connector IDs to user-friendly names.""" - connector_friendly_names = { - "YOUTUBE_VIDEO": "YouTube", - "EXTENSION": "Browser Extension", - "FILE": "Files", - "SLACK_CONNECTOR": "Slack", - "NOTION_CONNECTOR": "Notion", - "GITHUB_CONNECTOR": "GitHub", - "LINEAR_CONNECTOR": "Linear", - "JIRA_CONNECTOR": "Jira", - "CONFLUENCE_CONNECTOR": "Confluence", - "GOOGLE_CALENDAR_CONNECTOR": "Google Calendar", - "DISCORD_CONNECTOR": "Discord", - "TAVILY_API": "Tavily Search", - "LINKUP_API": "Linkup Search", - "BAIDU_SEARCH_API": "Baidu Search", - "AIRTABLE_CONNECTOR": "Airtable", - "LUMA_CONNECTOR": "Luma", - "ELASTICSEARCH_CONNECTOR": "Elasticsearch", - "WEBCRAWLER_CONNECTOR": "Web Pages", - "BOOKSTACK_CONNECTOR": "BookStack", - "NOTE": "Notes", - } - return connector_friendly_names.get(connector_name, connector_name) - - -def convert_langchain_messages_to_dict( - messages: list[BaseMessage], -) -> list[dict[str, str]]: - """Convert LangChain messages to format expected by token_counter.""" - role_mapping = {"system": "system", "human": "user", "ai": "assistant"} - - converted_messages = [] - for msg in messages: - role = role_mapping.get(getattr(msg, "type", None), "user") - converted_messages.append({"role": role, "content": str(msg.content)}) - - return converted_messages - - -def format_document_for_citation(document: dict[str, Any]) -> str: - """Format a single document for citation in the new document+chunks XML format. - - IMPORTANT: - - Citations must reference real DB chunk IDs: `[citation:]` - - Document metadata is included under , but citations are NOT document_id-based. 
- """ - - def _to_cdata(value: Any) -> str: - text = "" if value is None else str(value) - # Safely nest CDATA even if the content includes "]]>" - return "", "]]]]>") + "]]>" - - doc_info = document.get("document", {}) or {} - metadata = doc_info.get("metadata", {}) or {} - - doc_id = doc_info.get("id", "") - title = doc_info.get("title", "") - document_type = doc_info.get("document_type", "CRAWLED_URL") - url = ( - metadata.get("url") - or metadata.get("source") - or metadata.get("page_url") - or metadata.get("VisitedWebPageURL") - or "" - ) - - metadata_json = json.dumps(metadata, ensure_ascii=False) - - chunks = document.get("chunks") or [] - if not chunks: - # Fallback: treat `content` as a single chunk (no chunk_id available for citation) - chunks = [{"chunk_id": "", "content": document.get("content", "")}] - - chunks_xml = "\n".join( - [ - f"{_to_cdata(chunk.get('content', ''))}" - for chunk in chunks - ] - ) - - return f""" - -{doc_id} -{document_type} -{_to_cdata(title)} -{_to_cdata(url)} -{_to_cdata(metadata_json)} - - - -{chunks_xml} - -""" - - -def format_documents_section( - documents: list[dict[str, Any]], section_title: str = "Source material" -) -> str: - """Format multiple documents into a complete documents section.""" - if not documents: - return "" - - formatted_docs = [format_document_for_citation(doc) for doc in documents] - - return f"""{section_title}: - - {chr(10).join(formatted_docs)} - """ - - -def calculate_document_token_costs( - documents: list[dict[str, Any]], model: str -) -> list[DocumentTokenInfo]: - """Pre-calculate token costs for each document.""" - document_token_info = [] - - for i, doc in enumerate(documents): - formatted_doc = format_document_for_citation(doc) - - # Calculate token count for this document - token_count = token_counter( - messages=[{"role": "user", "content": formatted_doc}], model=model - ) - - document_token_info.append( - DocumentTokenInfo( - index=i, - document=doc, - formatted_content=formatted_doc, - 
token_count=token_count, - ) - ) - - return document_token_info - - -def find_optimal_documents_with_binary_search( - document_tokens: list[DocumentTokenInfo], available_tokens: int -) -> list[DocumentTokenInfo]: - """Use binary search to find the maximum number of documents that fit within token limit.""" - if not document_tokens or available_tokens <= 0: - return [] - - left, right = 0, len(document_tokens) - optimal_docs = [] - - while left <= right: - mid = (left + right) // 2 - current_docs = document_tokens[:mid] - current_token_sum = sum(doc_info.token_count for doc_info in current_docs) - - if current_token_sum <= available_tokens: - optimal_docs = current_docs - left = mid + 1 - else: - right = mid - 1 - - return optimal_docs - - -def get_model_context_window(model_name: str) -> int: - """Get the total context window size for a model (input + output tokens).""" - try: - model_info = get_model_info(model_name) - context_window = model_info.get("max_input_tokens", 4096) # Default fallback - return context_window - except Exception as e: - print( - f"Warning: Could not get model info for {model_name}, using default 4096 tokens. Error: {e}" - ) - return 4096 # Conservative fallback - - -def optimize_documents_for_token_limit( - documents: list[dict[str, Any]], base_messages: list[BaseMessage], model_name: str -) -> tuple[list[dict[str, Any]], bool]: - """ - Optimize documents to fit within token limits using binary search. 
- - Args: - documents: List of documents with content and metadata - base_messages: Base messages without documents (chat history + system + human message template) - model_name: Model name for token counting (required) - output_token_buffer: Number of tokens to reserve for model output - - Returns: - Tuple of (optimized_documents, has_documents_remaining) - """ - if not documents: - return [], False - - model = model_name - context_window = get_model_context_window(model) - - # Calculate base token cost - base_messages_dict = convert_langchain_messages_to_dict(base_messages) - base_tokens = token_counter(messages=base_messages_dict, model=model) - available_tokens_for_docs = context_window - base_tokens - - print( - f"Token optimization: Context window={context_window}, Base={base_tokens}, Available for docs={available_tokens_for_docs}" - ) - - if available_tokens_for_docs <= 0: - print("No tokens available for documents after base content and output buffer") - return [], False - - # Calculate token costs for all documents - document_token_info = calculate_document_token_costs(documents, model) - - # Find optimal number of documents using binary search - optimal_doc_info = find_optimal_documents_with_binary_search( - document_token_info, available_tokens_for_docs - ) - - # Extract the original document objects - optimized_documents = [doc_info.document for doc_info in optimal_doc_info] - has_documents_remaining = len(optimized_documents) > 0 - - print( - f"Token optimization result: Using {len(optimized_documents)}/{len(documents)} documents" - ) - - return optimized_documents, has_documents_remaining - - -def calculate_token_count(messages: list[BaseMessage], model_name: str) -> int: - """Calculate token count for a list of LangChain messages.""" - model = model_name - messages_dict = convert_langchain_messages_to_dict(messages) - return token_counter(messages=messages_dict, model=model) - - -def langchain_chat_history_to_str(chat_history: list[BaseMessage]) -> 
str: - """ - Convert a list of chat history messages to a string. - """ - chat_history_str = "" - - for chat_message in chat_history: - if isinstance(chat_message, HumanMessage): - chat_history_str += f"{chat_message.content}\n" - elif isinstance(chat_message, AIMessage): - chat_history_str += f"{chat_message.content}\n" - elif isinstance(chat_message, SystemMessage): - chat_history_str += f"{chat_message.content}\n" - - return chat_history_str diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 9b7811022..8a3a7e30e 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -77,10 +77,6 @@ class SearchSourceConnectorType(str, Enum): BOOKSTACK_CONNECTOR = "BOOKSTACK_CONNECTOR" -class ChatType(str, Enum): - QNA = "QNA" - - class LiteLLMProvider(str, Enum): """ Enum for LLM providers supported by LiteLLM. @@ -317,21 +313,6 @@ class BaseModel(Base): id = Column(Integer, primary_key=True, index=True) -class Chat(BaseModel, TimestampMixin): - __tablename__ = "chats" - - type = Column(SQLAlchemyEnum(ChatType), nullable=False) - title = Column(String, nullable=False, index=True) - initial_connectors = Column(ARRAY(String), nullable=True) - messages = Column(JSON, nullable=False) - state_version = Column(BigInteger, nullable=False, default=1) - - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="chats") - - class NewChatMessageRole(str, Enum): """Role enum for new chat messages.""" @@ -363,9 +344,6 @@ class NewChatThread(BaseModel, TimestampMixin): search_space_id = Column( Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False ) - user_id = Column( - UUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False - ) # Relationships search_space = relationship("SearchSpace", back_populates="new_chat_threads") @@ -445,23 +423,6 @@ class Chunk(BaseModel, TimestampMixin): document = 
relationship("Document", back_populates="chunks") -class Podcast(BaseModel, TimestampMixin): - __tablename__ = "podcasts" - - title = Column(String, nullable=False, index=True) - podcast_transcript = Column(JSON, nullable=False, default={}) - file_location = Column(String(500), nullable=False, default="") - chat_id = Column( - Integer, ForeignKey("chats.id", ondelete="CASCADE"), nullable=True - ) # If generated from a chat, this will be the chat id, else null ( can be from a document or a chat ) - chat_state_version = Column(BigInteger, nullable=True) - - search_space_id = Column( - Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False - ) - search_space = relationship("SearchSpace", back_populates="podcasts") - - class SearchSpace(BaseModel, TimestampMixin): __tablename__ = "searchspaces" @@ -492,18 +453,6 @@ class SearchSpace(BaseModel, TimestampMixin): order_by="Document.id", cascade="all, delete-orphan", ) - podcasts = relationship( - "Podcast", - back_populates="search_space", - order_by="Podcast.id", - cascade="all, delete-orphan", - ) - chats = relationship( - "Chat", - back_populates="search_space", - order_by="Chat.id", - cascade="all, delete-orphan", - ) new_chat_threads = relationship( "NewChatThread", back_populates="search_space", diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 5430e8b1e..732693eb5 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -3,7 +3,6 @@ from fastapi import APIRouter from .airtable_add_connector_route import ( router as airtable_add_connector_router, ) -from .chats_routes import router as chats_router from .documents_routes import router as documents_router from .editor_routes import router as editor_router from .google_calendar_add_connector_route import ( @@ -17,7 +16,6 @@ from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from 
.new_chat_routes import router as new_chat_router from .notes_routes import router as notes_router -from .podcasts_routes import router as podcasts_router from .rbac_routes import router as rbac_router from .search_source_connectors_routes import router as search_source_connectors_router from .search_spaces_routes import router as search_spaces_router @@ -29,9 +27,7 @@ router.include_router(rbac_router) # RBAC routes for roles, members, invites router.include_router(editor_router) router.include_router(documents_router) router.include_router(notes_router) -router.include_router(podcasts_router) -router.include_router(chats_router) -router.include_router(new_chat_router) # New chat with assistant-ui persistence +router.include_router(new_chat_router) # Chat with assistant-ui persistence router.include_router(search_source_connectors_router) router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) diff --git a/surfsense_backend/app/routes/chats_routes.py b/surfsense_backend/app/routes/chats_routes.py deleted file mode 100644 index 2a65b637c..000000000 --- a/surfsense_backend/app/routes/chats_routes.py +++ /dev/null @@ -1,617 +0,0 @@ -from fastapi import APIRouter, Depends, HTTPException -from fastapi.responses import StreamingResponse -from langchain_core.messages import AIMessage, HumanMessage -from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.future import select -from sqlalchemy.orm import selectinload - -from app.db import ( - Chat, - Permission, - SearchSpace, - SearchSpaceMembership, - User, - get_async_session, -) -from app.schemas import ( - AISDKChatRequest, - ChatCreate, - ChatRead, - ChatReadWithoutMessages, - ChatUpdate, - NewChatRequest, -) -from app.services.new_streaming_service import VercelStreamingService -from app.tasks.chat.stream_connector_search_results import ( - stream_connector_search_results, -) -from 
app.tasks.chat.stream_new_chat import stream_new_chat -from app.users import current_active_user -from app.utils.rbac import check_permission -from app.utils.validators import ( - validate_connectors, - validate_document_ids, - validate_messages, - validate_research_mode, - validate_search_space_id, - validate_top_k, -) - -router = APIRouter() - - -@router.post("/chat") -async def handle_chat_data( - request: AISDKChatRequest, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - # Validate and sanitize all input data - messages = validate_messages(request.messages) - - if messages[-1]["role"] != "user": - raise HTTPException( - status_code=400, detail="Last message must be a user message" - ) - - user_query = messages[-1]["content"] - - # Extract and validate data from request - request_data = request.data or {} - search_space_id = validate_search_space_id(request_data.get("search_space_id")) - research_mode = validate_research_mode(request_data.get("research_mode")) - selected_connectors = validate_connectors(request_data.get("selected_connectors")) - document_ids_to_add_in_context = validate_document_ids( - request_data.get("document_ids_to_add_in_context") - ) - top_k = validate_top_k(request_data.get("top_k")) - # print("RESQUEST DATA:", request_data) - # print("SELECTED CONNECTORS:", selected_connectors) - - # Check if the user has chat access to the search space - try: - await check_permission( - session, - user, - search_space_id, - Permission.CHATS_CREATE.value, - "You don't have permission to use chat in this search space", - ) - - # Get search space with LLM configs (preferences are now stored at search space level) - search_space_result = await session.execute( - select(SearchSpace) - .options(selectinload(SearchSpace.llm_configs)) - .filter(SearchSpace.id == search_space_id) - ) - search_space = search_space_result.scalars().first() - - language = None - llm_configs = [] # Initialize to empty list - - 
if search_space and search_space.llm_configs: - llm_configs = search_space.llm_configs - - # Get language from configured LLM preferences - # LLM preferences are now stored on the SearchSpace model - from app.config import config as app_config - - for llm_id in [ - search_space.fast_llm_id, - search_space.long_context_llm_id, - search_space.strategic_llm_id, - ]: - if llm_id is not None: - # Check if it's a global config (negative ID) - if llm_id < 0: - # Look in global configs - for global_cfg in app_config.GLOBAL_LLM_CONFIGS: - if global_cfg.get("id") == llm_id: - language = global_cfg.get("language") - if language: - break - else: - # Look in custom configs - for llm_config in llm_configs: - if llm_config.id == llm_id and getattr( - llm_config, "language", None - ): - language = llm_config.language - break - if language: - break - - if not language and llm_configs: - first_llm_config = llm_configs[0] - language = getattr(first_llm_config, "language", None) - - except HTTPException: - raise HTTPException( - status_code=403, detail="You don't have access to this search space" - ) from None - - langchain_chat_history = [] - for message in messages[:-1]: - if message["role"] == "user": - langchain_chat_history.append(HumanMessage(content=message["content"])) - elif message["role"] == "assistant": - langchain_chat_history.append(AIMessage(content=message["content"])) - - response = StreamingResponse( - stream_connector_search_results( - user_query, - user.id, - search_space_id, - session, - research_mode, - selected_connectors, - langchain_chat_history, - document_ids_to_add_in_context, - language, - top_k, - ) - ) - - response.headers["x-vercel-ai-data-stream"] = "v1" - return response - - -@router.post("/new_chat") -async def handle_new_chat( - request: NewChatRequest, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Handle new chat requests using the SurfSense deep agent. 
- - This endpoint uses the new deep agent with the Vercel AI SDK - Data Stream Protocol (SSE format). - - Args: - request: NewChatRequest containing chat_id, user_query, and search_space_id - session: Database session - user: Current authenticated user - - Returns: - StreamingResponse with SSE formatted data - """ - # Validate the user query - if not request.user_query or not request.user_query.strip(): - raise HTTPException(status_code=400, detail="User query cannot be empty") - - # Check if the user has chat access to the search space - try: - await check_permission( - session, - user, - request.search_space_id, - Permission.CHATS_CREATE.value, - "You don't have permission to use chat in this search space", - ) - except HTTPException: - raise HTTPException( - status_code=403, detail="You don't have access to this search space" - ) from None - - # Get LLM config ID from search space preferences (optional enhancement) - # For now, we use the default global config (-1) - llm_config_id = -1 - - # Optionally load LLM preferences from search space - try: - search_space_result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == request.search_space_id) - ) - search_space = search_space_result.scalars().first() - - if search_space: - # Use strategic_llm_id if available, otherwise fall back to fast_llm_id - if search_space.strategic_llm_id is not None: - llm_config_id = search_space.strategic_llm_id - elif search_space.fast_llm_id is not None: - llm_config_id = search_space.fast_llm_id - except Exception: - # Fall back to default config on any error - pass - - # Create the streaming response - # chat_id is used as LangGraph's thread_id for automatic chat history management - response = StreamingResponse( - stream_new_chat( - user_query=request.user_query.strip(), - user_id=user.id, - search_space_id=request.search_space_id, - chat_id=request.chat_id, - session=session, - llm_config_id=llm_config_id, - messages=request.messages, # Pass message history 
from frontend - ), - media_type="text/event-stream", - ) - - # Set the required headers for Vercel AI SDK - headers = VercelStreamingService.get_response_headers() - for key, value in headers.items(): - response.headers[key] = value - - return response - - -@router.post("/chats", response_model=ChatRead) -async def create_chat( - chat: ChatCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Create a new chat. - Requires CHATS_CREATE permission. - """ - try: - await check_permission( - session, - user, - chat.search_space_id, - Permission.CHATS_CREATE.value, - "You don't have permission to create chats in this search space", - ) - db_chat = Chat(**chat.model_dump()) - session.add(db_chat) - await session.commit() - await session.refresh(db_chat) - return db_chat - except HTTPException: - raise - except IntegrityError: - await session.rollback() - raise HTTPException( - status_code=400, - detail="Database constraint violation. Please check your input data.", - ) from None - except OperationalError: - await session.rollback() - raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later." - ) from None - except Exception: - await session.rollback() - raise HTTPException( - status_code=500, - detail="An unexpected error occurred while creating the chat.", - ) from None - - -@router.get("/chats", response_model=list[ChatReadWithoutMessages]) -async def read_chats( - skip: int = 0, - limit: int = 100, - search_space_id: int | None = None, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - List chats the user has access to. - Requires CHATS_READ permission for the search space(s). 
- """ - # Validate pagination parameters - if skip < 0: - raise HTTPException( - status_code=400, detail="skip must be a non-negative integer" - ) - - if limit <= 0 or limit > 1000: # Reasonable upper limit - raise HTTPException(status_code=400, detail="limit must be between 1 and 1000") - - # Validate search_space_id if provided - if search_space_id is not None and search_space_id <= 0: - raise HTTPException( - status_code=400, detail="search_space_id must be a positive integer" - ) - try: - if search_space_id is not None: - # Check permission for specific search space - await check_permission( - session, - user, - search_space_id, - Permission.CHATS_READ.value, - "You don't have permission to read chats in this search space", - ) - # Select specific fields excluding messages - query = ( - select( - Chat.id, - Chat.type, - Chat.title, - Chat.initial_connectors, - Chat.search_space_id, - Chat.created_at, - Chat.state_version, - ) - .filter(Chat.search_space_id == search_space_id) - .order_by(Chat.created_at.desc()) - ) - else: - # Get chats from all search spaces user has membership in - query = ( - select( - Chat.id, - Chat.type, - Chat.title, - Chat.initial_connectors, - Chat.search_space_id, - Chat.created_at, - Chat.state_version, - ) - .join(SearchSpace) - .join(SearchSpaceMembership) - .filter(SearchSpaceMembership.user_id == user.id) - .order_by(Chat.created_at.desc()) - ) - - result = await session.execute(query.offset(skip).limit(limit)) - return result.all() - except HTTPException: - raise - except OperationalError: - raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later." - ) from None - except Exception: - raise HTTPException( - status_code=500, detail="An unexpected error occurred while fetching chats." 
- ) from None - - -@router.get("/chats/search", response_model=list[ChatReadWithoutMessages]) -async def search_chats( - title: str, - skip: int = 0, - limit: int = 100, - search_space_id: int | None = None, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Search chats by title substring. - Requires CHATS_READ permission for the search space(s). - - Args: - title: Case-insensitive substring to match against chat titles. Required. - skip: Number of items to skip from the beginning. Default: 0. - limit: Maximum number of items to return. Default: 100. - search_space_id: Filter results to a specific search space. Default: None. - session: Database session (injected). - user: Current authenticated user (injected). - - Returns: - List of chats matching the search query. - - Notes: - - Title matching uses ILIKE (case-insensitive). - - Results are ordered by creation date (most recent first). - """ - # Validate pagination parameters - if skip < 0: - raise HTTPException( - status_code=400, detail="skip must be a non-negative integer" - ) - - if limit <= 0 or limit > 1000: - raise HTTPException(status_code=400, detail="limit must be between 1 and 1000") - - # Validate search_space_id if provided - if search_space_id is not None and search_space_id <= 0: - raise HTTPException( - status_code=400, detail="search_space_id must be a positive integer" - ) - - try: - if search_space_id is not None: - # Check permission for specific search space - await check_permission( - session, - user, - search_space_id, - Permission.CHATS_READ.value, - "You don't have permission to read chats in this search space", - ) - # Select specific fields excluding messages - query = ( - select( - Chat.id, - Chat.type, - Chat.title, - Chat.initial_connectors, - Chat.search_space_id, - Chat.created_at, - Chat.state_version, - ) - .filter(Chat.search_space_id == search_space_id) - .order_by(Chat.created_at.desc()) - ) - else: - # Get chats from 
all search spaces user has membership in - query = ( - select( - Chat.id, - Chat.type, - Chat.title, - Chat.initial_connectors, - Chat.search_space_id, - Chat.created_at, - Chat.state_version, - ) - .join(SearchSpace) - .join(SearchSpaceMembership) - .filter(SearchSpaceMembership.user_id == user.id) - .order_by(Chat.created_at.desc()) - ) - - # Apply title search filter (case-insensitive) - query = query.filter(Chat.title.ilike(f"%{title}%")) - - result = await session.execute(query.offset(skip).limit(limit)) - return result.all() - except HTTPException: - raise - except OperationalError: - raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later." - ) from None - except Exception: - raise HTTPException( - status_code=500, - detail="An unexpected error occurred while searching chats.", - ) from None - - -@router.get("/chats/{chat_id}", response_model=ChatRead) -async def read_chat( - chat_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get a specific chat by ID. - Requires CHATS_READ permission for the search space. - """ - try: - result = await session.execute(select(Chat).filter(Chat.id == chat_id)) - chat = result.scalars().first() - - if not chat: - raise HTTPException( - status_code=404, - detail="Chat not found", - ) - - # Check permission for the search space - await check_permission( - session, - user, - chat.search_space_id, - Permission.CHATS_READ.value, - "You don't have permission to read chats in this search space", - ) - - return chat - except HTTPException: - raise - except OperationalError: - raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later." 
- ) from None - except Exception: - raise HTTPException( - status_code=500, - detail="An unexpected error occurred while fetching the chat.", - ) from None - - -@router.put("/chats/{chat_id}", response_model=ChatRead) -async def update_chat( - chat_id: int, - chat_update: ChatUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Update a chat. - Requires CHATS_UPDATE permission for the search space. - """ - try: - result = await session.execute(select(Chat).filter(Chat.id == chat_id)) - db_chat = result.scalars().first() - - if not db_chat: - raise HTTPException(status_code=404, detail="Chat not found") - - # Check permission for the search space - await check_permission( - session, - user, - db_chat.search_space_id, - Permission.CHATS_UPDATE.value, - "You don't have permission to update chats in this search space", - ) - - update_data = chat_update.model_dump(exclude_unset=True) - for key, value in update_data.items(): - if key == "messages": - db_chat.state_version = len(update_data["messages"]) - setattr(db_chat, key, value) - - await session.commit() - await session.refresh(db_chat) - return db_chat - except HTTPException: - raise - except IntegrityError: - await session.rollback() - raise HTTPException( - status_code=400, - detail="Database constraint violation. Please check your input data.", - ) from None - except OperationalError: - await session.rollback() - raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later." - ) from None - except Exception: - await session.rollback() - raise HTTPException( - status_code=500, - detail="An unexpected error occurred while updating the chat.", - ) from None - - -@router.delete("/chats/{chat_id}", response_model=dict) -async def delete_chat( - chat_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Delete a chat. 
- Requires CHATS_DELETE permission for the search space. - """ - try: - result = await session.execute(select(Chat).filter(Chat.id == chat_id)) - db_chat = result.scalars().first() - - if not db_chat: - raise HTTPException(status_code=404, detail="Chat not found") - - # Check permission for the search space - await check_permission( - session, - user, - db_chat.search_space_id, - Permission.CHATS_DELETE.value, - "You don't have permission to delete chats in this search space", - ) - - await session.delete(db_chat) - await session.commit() - return {"message": "Chat deleted successfully"} - except HTTPException: - raise - except IntegrityError: - await session.rollback() - raise HTTPException( - status_code=400, detail="Cannot delete chat due to existing dependencies." - ) from None - except OperationalError: - await session.rollback() - raise HTTPException( - status_code=503, detail="Database operation failed. Please try again later." - ) from None - except Exception: - await session.rollback() - raise HTTPException( - status_code=500, - detail="An unexpected error occurred while deleting the chat.", - ) from None diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index 209c25a15..8d2734808 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -13,6 +13,7 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui: from datetime import UTC, datetime from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import StreamingResponse from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -23,12 +24,14 @@ from app.db import ( NewChatMessageRole, NewChatThread, Permission, + SearchSpace, User, get_async_session, ) from app.schemas.new_chat import ( NewChatMessageAppend, NewChatMessageRead, + NewChatRequest, 
NewChatThreadCreate, NewChatThreadRead, NewChatThreadUpdate, @@ -37,6 +40,7 @@ from app.schemas.new_chat import ( ThreadListItem, ThreadListResponse, ) +from app.tasks.chat.stream_new_chat import stream_new_chat from app.users import current_active_user from app.utils.rbac import check_permission @@ -74,13 +78,10 @@ async def list_threads( "You don't have permission to read chats in this search space", ) - # Get all threads for this user in this search space + # Get all threads in this search space query = ( select(NewChatThread) - .filter( - NewChatThread.search_space_id == search_space_id, - NewChatThread.user_id == user.id, - ) + .filter(NewChatThread.search_space_id == search_space_id) .order_by(NewChatThread.updated_at.desc()) ) @@ -153,7 +154,6 @@ async def search_threads( select(NewChatThread) .filter( NewChatThread.search_space_id == search_space_id, - NewChatThread.user_id == user.id, NewChatThread.title.ilike(f"%{title}%"), ) .order_by(NewChatThread.updated_at.desc()) @@ -211,7 +211,6 @@ async def create_thread( title=thread.title, archived=thread.archived, search_space_id=thread.search_space_id, - user_id=user.id, updated_at=now, ) session.add(db_thread) @@ -273,12 +272,6 @@ async def get_thread_messages( "You don't have permission to read chats in this search space", ) - # Ensure user owns this thread - if thread.user_id != user.id: - raise HTTPException( - status_code=403, detail="You don't have access to this thread" - ) - # Return messages in the format expected by assistant-ui messages = [ NewChatMessageRead( @@ -336,11 +329,6 @@ async def get_thread_full( "You don't have permission to read chats in this search space", ) - if thread.user_id != user.id: - raise HTTPException( - status_code=403, detail="You don't have access to this thread" - ) - return thread except HTTPException: @@ -386,11 +374,6 @@ async def update_thread( "You don't have permission to update chats in this search space", ) - if db_thread.user_id != user.id: - raise HTTPException( 
- status_code=403, detail="You don't have access to this thread" - ) - # Update fields update_data = thread_update.model_dump(exclude_unset=True) for key, value in update_data.items(): @@ -451,11 +434,6 @@ async def delete_thread( "You don't have permission to delete chats in this search space", ) - if db_thread.user_id != user.id: - raise HTTPException( - status_code=403, detail="You don't have access to this thread" - ) - await session.delete(db_thread) await session.commit() return {"message": "Thread deleted successfully"} @@ -530,11 +508,6 @@ async def append_message( "You don't have permission to update chats in this search space", ) - if thread.user_id != user.id: - raise HTTPException( - status_code=403, detail="You don't have access to this thread" - ) - # Convert string role to enum role_str = ( message.role.lower() if isinstance(message.role, str) else message.role @@ -639,11 +612,6 @@ async def list_messages( "You don't have permission to read chats in this search space", ) - if thread.user_id != user.id: - raise HTTPException( - status_code=403, detail="You don't have access to this thread" - ) - # Get messages query = ( select(NewChatMessage) @@ -667,3 +635,79 @@ async def list_messages( status_code=500, detail=f"An unexpected error occurred while fetching messages: {e!s}", ) from None + + +# ============================================================================= +# Chat Streaming Endpoint +# ============================================================================= + + +@router.post("/new_chat") +async def handle_new_chat( + request: NewChatRequest, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Stream chat responses from the deep agent. + + This endpoint handles the new chat functionality with streaming responses + using Server-Sent Events (SSE) format compatible with Vercel AI SDK. + + Requires CHATS_CREATE permission. 
+ """ + try: + # Verify thread exists and user has permission + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == request.chat_id) + ) + thread = result.scalars().first() + + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + await check_permission( + session, + user, + thread.search_space_id, + Permission.CHATS_CREATE.value, + "You don't have permission to chat in this search space", + ) + + # Get search space to check LLM config preferences + search_space_result = await session.execute( + select(SearchSpace).filter(SearchSpace.id == request.search_space_id) + ) + search_space = search_space_result.scalars().first() + + # Determine LLM config ID (use search space preference or default) + llm_config_id = -1 # Default to first global config + if search_space and search_space.fast_llm_id: + llm_config_id = search_space.fast_llm_id + + # Return streaming response + return StreamingResponse( + stream_new_chat( + user_query=request.user_query, + user_id=str(user.id), + search_space_id=request.search_space_id, + chat_id=request.chat_id, + session=session, + llm_config_id=llm_config_id, + messages=request.messages, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred: {e!s}", + ) from None diff --git a/surfsense_backend/app/routes/podcasts_routes.py b/surfsense_backend/app/routes/podcasts_routes.py deleted file mode 100644 index 904de20a3..000000000 --- a/surfsense_backend/app/routes/podcasts_routes.py +++ /dev/null @@ -1,509 +0,0 @@ -import os -from pathlib import Path - -from fastapi import APIRouter, Depends, HTTPException -from fastapi.responses import StreamingResponse -from sqlalchemy.exc import IntegrityError, SQLAlchemyError -from sqlalchemy.ext.asyncio import 
AsyncSession -from sqlalchemy.future import select - -from app.db import ( - Chat, - Permission, - Podcast, - SearchSpace, - SearchSpaceMembership, - User, - get_async_session, -) -from app.schemas import ( - PodcastCreate, - PodcastGenerateRequest, - PodcastRead, - PodcastUpdate, -) -from app.tasks.podcast_tasks import generate_chat_podcast -from app.users import current_active_user -from app.utils.rbac import check_permission - -router = APIRouter() - - -@router.post("/podcasts", response_model=PodcastRead) -async def create_podcast( - podcast: PodcastCreate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Create a new podcast. - Requires PODCASTS_CREATE permission. - """ - try: - await check_permission( - session, - user, - podcast.search_space_id, - Permission.PODCASTS_CREATE.value, - "You don't have permission to create podcasts in this search space", - ) - db_podcast = Podcast(**podcast.model_dump()) - session.add(db_podcast) - await session.commit() - await session.refresh(db_podcast) - return db_podcast - except HTTPException as he: - raise he - except IntegrityError: - await session.rollback() - raise HTTPException( - status_code=400, - detail="Podcast creation failed due to constraint violation", - ) from None - except SQLAlchemyError: - await session.rollback() - raise HTTPException( - status_code=500, detail="Database error occurred while creating podcast" - ) from None - except Exception: - await session.rollback() - raise HTTPException( - status_code=500, detail="An unexpected error occurred" - ) from None - - -@router.get("/podcasts", response_model=list[PodcastRead]) -async def read_podcasts( - skip: int = 0, - limit: int = 100, - search_space_id: int | None = None, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - List podcasts the user has access to. - Requires PODCASTS_READ permission for the search space(s). 
- """ - if skip < 0 or limit < 1: - raise HTTPException(status_code=400, detail="Invalid pagination parameters") - try: - if search_space_id is not None: - # Check permission for specific search space - await check_permission( - session, - user, - search_space_id, - Permission.PODCASTS_READ.value, - "You don't have permission to read podcasts in this search space", - ) - result = await session.execute( - select(Podcast) - .filter(Podcast.search_space_id == search_space_id) - .offset(skip) - .limit(limit) - ) - else: - # Get podcasts from all search spaces user has membership in - result = await session.execute( - select(Podcast) - .join(SearchSpace) - .join(SearchSpaceMembership) - .filter(SearchSpaceMembership.user_id == user.id) - .offset(skip) - .limit(limit) - ) - return result.scalars().all() - except HTTPException: - raise - except SQLAlchemyError: - raise HTTPException( - status_code=500, detail="Database error occurred while fetching podcasts" - ) from None - - -@router.get("/podcasts/{podcast_id}", response_model=PodcastRead) -async def read_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get a specific podcast by ID. - Requires PODCASTS_READ permission for the search space. 
- """ - try: - result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) - podcast = result.scalars().first() - - if not podcast: - raise HTTPException( - status_code=404, - detail="Podcast not found", - ) - - # Check permission for the search space - await check_permission( - session, - user, - podcast.search_space_id, - Permission.PODCASTS_READ.value, - "You don't have permission to read podcasts in this search space", - ) - - return podcast - except HTTPException as he: - raise he - except SQLAlchemyError: - raise HTTPException( - status_code=500, detail="Database error occurred while fetching podcast" - ) from None - - -@router.put("/podcasts/{podcast_id}", response_model=PodcastRead) -async def update_podcast( - podcast_id: int, - podcast_update: PodcastUpdate, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Update a podcast. - Requires PODCASTS_UPDATE permission for the search space. - """ - try: - result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) - db_podcast = result.scalars().first() - - if not db_podcast: - raise HTTPException(status_code=404, detail="Podcast not found") - - # Check permission for the search space - await check_permission( - session, - user, - db_podcast.search_space_id, - Permission.PODCASTS_UPDATE.value, - "You don't have permission to update podcasts in this search space", - ) - - update_data = podcast_update.model_dump(exclude_unset=True) - for key, value in update_data.items(): - setattr(db_podcast, key, value) - await session.commit() - await session.refresh(db_podcast) - return db_podcast - except HTTPException as he: - raise he - except IntegrityError: - await session.rollback() - raise HTTPException( - status_code=400, detail="Update failed due to constraint violation" - ) from None - except SQLAlchemyError: - await session.rollback() - raise HTTPException( - status_code=500, detail="Database error occurred while 
updating podcast" - ) from None - - -@router.delete("/podcasts/{podcast_id}", response_model=dict) -async def delete_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Delete a podcast. - Requires PODCASTS_DELETE permission for the search space. - """ - try: - result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) - db_podcast = result.scalars().first() - - if not db_podcast: - raise HTTPException(status_code=404, detail="Podcast not found") - - # Check permission for the search space - await check_permission( - session, - user, - db_podcast.search_space_id, - Permission.PODCASTS_DELETE.value, - "You don't have permission to delete podcasts in this search space", - ) - - await session.delete(db_podcast) - await session.commit() - return {"message": "Podcast deleted successfully"} - except HTTPException as he: - raise he - except SQLAlchemyError: - await session.rollback() - raise HTTPException( - status_code=500, detail="Database error occurred while deleting podcast" - ) from None - - -async def generate_chat_podcast_with_new_session( - chat_id: int, - search_space_id: int, - user_id: int, - podcast_title: str | None = None, - user_prompt: str | None = None, -): - """Create a new session and process chat podcast generation.""" - from app.db import async_session_maker - - async with async_session_maker() as session: - try: - await generate_chat_podcast( - session, chat_id, search_space_id, user_id, podcast_title, user_prompt - ) - except Exception as e: - import logging - - logging.error(f"Error generating podcast from chat: {e!s}") - - -@router.post("/podcasts/generate") -async def generate_podcast( - request: PodcastGenerateRequest, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Generate a podcast from a chat or document. - Requires PODCASTS_CREATE permission. 
- """ - try: - # Check if the user has permission to create podcasts - await check_permission( - session, - user, - request.search_space_id, - Permission.PODCASTS_CREATE.value, - "You don't have permission to create podcasts in this search space", - ) - - if request.type == "CHAT": - # Verify that all chat IDs belong to this user and search space - query = ( - select(Chat) - .filter( - Chat.id.in_(request.ids), - Chat.search_space_id == request.search_space_id, - ) - .join(SearchSpace) - .filter(SearchSpace.user_id == user.id) - ) - - result = await session.execute(query) - valid_chats = result.scalars().all() - valid_chat_ids = [chat.id for chat in valid_chats] - - # If any requested ID is not in valid IDs, raise error immediately - if len(valid_chat_ids) != len(request.ids): - raise HTTPException( - status_code=403, - detail="One or more chat IDs do not belong to this user or search space", - ) - - from app.tasks.celery_tasks.podcast_tasks import ( - generate_chat_podcast_task, - ) - - # Add Celery tasks for each chat ID - for chat_id in valid_chat_ids: - generate_chat_podcast_task.delay( - chat_id, - request.search_space_id, - user.id, - request.podcast_title, - request.user_prompt, - ) - - return { - "message": "Podcast generation started", - } - except HTTPException as he: - raise he - except IntegrityError: - await session.rollback() - raise HTTPException( - status_code=400, - detail="Podcast generation failed due to constraint violation", - ) from None - except SQLAlchemyError: - await session.rollback() - raise HTTPException( - status_code=500, detail="Database error occurred while generating podcast" - ) from None - except Exception as e: - await session.rollback() - raise HTTPException( - status_code=500, detail=f"An unexpected error occurred: {e!s}" - ) from e - - -@router.get("/podcasts/{podcast_id}/stream") -async def stream_podcast( - podcast_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), 
-): - """ - Stream a podcast audio file. - Requires PODCASTS_READ permission for the search space. - """ - try: - result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) - podcast = result.scalars().first() - - if not podcast: - raise HTTPException( - status_code=404, - detail="Podcast not found", - ) - - # Check permission for the search space - await check_permission( - session, - user, - podcast.search_space_id, - Permission.PODCASTS_READ.value, - "You don't have permission to access podcasts in this search space", - ) - - # Get the file path - file_path = podcast.file_location - - # Check if the file exists - if not os.path.isfile(file_path): - raise HTTPException(status_code=404, detail="Podcast audio file not found") - - # Define a generator function to stream the file - def iterfile(): - with open(file_path, mode="rb") as file_like: - yield from file_like - - # Return a streaming response with appropriate headers - return StreamingResponse( - iterfile(), - media_type="audio/mpeg", - headers={ - "Accept-Ranges": "bytes", - "Content-Disposition": f"inline; filename={Path(file_path).name}", - }, - ) - - except HTTPException as he: - raise he - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error streaming podcast: {e!s}" - ) from e - - -@router.get("/podcasts/by-chat/{chat_id}", response_model=PodcastRead | None) -async def get_podcast_by_chat_id( - chat_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_active_user), -): - """ - Get a podcast by its associated chat ID. - Requires PODCASTS_READ permission for the search space. 
- """ - try: - # First get the chat to find its search space - chat_result = await session.execute(select(Chat).filter(Chat.id == chat_id)) - chat = chat_result.scalars().first() - - if not chat: - return None - - # Check permission for the search space - await check_permission( - session, - user, - chat.search_space_id, - Permission.PODCASTS_READ.value, - "You don't have permission to read podcasts in this search space", - ) - - # Get the podcast - result = await session.execute( - select(Podcast).filter(Podcast.chat_id == chat_id) - ) - podcast = result.scalars().first() - - return podcast - except HTTPException as he: - raise he - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error fetching podcast: {e!s}" - ) from e - - -@router.get("/podcasts/task/{task_id}/status") -async def get_podcast_task_status( - task_id: str, - user: User = Depends(current_active_user), -): - """ - Get the status of a podcast generation task. - Used by new-chat frontend to poll for completion. 
- - Returns: - - status: "processing" | "success" | "error" - - podcast_id: (only if status == "success") - - title: (only if status == "success") - - error: (only if status == "error") - """ - try: - from celery.result import AsyncResult - - from app.celery_app import celery_app - - result = AsyncResult(task_id, app=celery_app) - - if result.ready(): - # Task completed - if result.successful(): - task_result = result.result - if isinstance(task_result, dict): - if task_result.get("status") == "success": - return { - "status": "success", - "podcast_id": task_result.get("podcast_id"), - "title": task_result.get("title"), - "transcript_entries": task_result.get("transcript_entries"), - } - else: - return { - "status": "error", - "error": task_result.get("error", "Unknown error"), - } - else: - return { - "status": "error", - "error": "Unexpected task result format", - } - else: - # Task failed - return { - "status": "error", - "error": str(result.result) if result.result else "Task failed", - } - else: - # Task still processing - return { - "status": "processing", - "state": result.state, - } - - except Exception as e: - raise HTTPException( - status_code=500, detail=f"Error checking task status: {e!s}" - ) from e diff --git a/surfsense_backend/app/schemas/__init__.py b/surfsense_backend/app/schemas/__init__.py index 92f9cdc78..4d0eb3595 100644 --- a/surfsense_backend/app/schemas/__init__.py +++ b/surfsense_backend/app/schemas/__init__.py @@ -1,13 +1,4 @@ from .base import IDModel, TimestampModel -from .chats import ( - AISDKChatRequest, - ChatBase, - ChatCreate, - ChatRead, - ChatReadWithoutMessages, - ChatUpdate, - NewChatRequest, -) from .chunks import ChunkBase, ChunkCreate, ChunkRead, ChunkUpdate from .documents import ( DocumentBase, @@ -22,9 +13,11 @@ from .documents import ( from .llm_config import LLMConfigBase, LLMConfigCreate, LLMConfigRead, LLMConfigUpdate from .logs import LogBase, LogCreate, LogFilter, LogRead, LogUpdate from .new_chat import ( + 
ChatMessage, NewChatMessageAppend, NewChatMessageCreate, NewChatMessageRead, + NewChatRequest, NewChatThreadCreate, NewChatThreadRead, NewChatThreadUpdate, @@ -33,13 +26,6 @@ from .new_chat import ( ThreadListItem, ThreadListResponse, ) -from .podcasts import ( - PodcastBase, - PodcastCreate, - PodcastGenerateRequest, - PodcastRead, - PodcastUpdate, -) from .rbac_schemas import ( InviteAcceptRequest, InviteAcceptResponse, @@ -73,44 +59,8 @@ from .search_space import ( from .users import UserCreate, UserRead, UserUpdate __all__ = [ - "AISDKChatRequest", - "ChatBase", - "ChatCreate", - "ChatRead", - "ChatReadWithoutMessages", - "ChatUpdate", - "ChunkBase", - "ChunkCreate", - "ChunkRead", - "ChunkUpdate", - "DocumentBase", - "DocumentRead", - "DocumentUpdate", - "DocumentWithChunksRead", - "DocumentsCreate", - "ExtensionDocumentContent", - "ExtensionDocumentMetadata", - "IDModel", - # RBAC schemas - "InviteAcceptRequest", - "InviteAcceptResponse", - "InviteCreate", - "InviteInfoResponse", - "InviteRead", - "InviteUpdate", - "LLMConfigBase", - "LLMConfigCreate", - "LLMConfigRead", - "LLMConfigUpdate", - "LogBase", - "LogCreate", - "LogFilter", - "LogRead", - "LogUpdate", - "MembershipRead", - "MembershipReadWithUser", - "MembershipUpdate", - # New chat schemas (assistant-ui integration) + # Chat schemas (assistant-ui integration) + "ChatMessage", "NewChatMessageAppend", "NewChatMessageCreate", "NewChatMessageRead", @@ -119,30 +69,64 @@ __all__ = [ "NewChatThreadRead", "NewChatThreadUpdate", "NewChatThreadWithMessages", + "ThreadHistoryLoadResponse", + "ThreadListItem", + "ThreadListResponse", + # Chunk schemas + "ChunkBase", + "ChunkCreate", + "ChunkRead", + "ChunkUpdate", + # Document schemas + "DocumentBase", + "DocumentRead", + "DocumentUpdate", + "DocumentWithChunksRead", + "DocumentsCreate", + "ExtensionDocumentContent", + "ExtensionDocumentMetadata", "PaginatedResponse", + # Base schemas + "IDModel", + "TimestampModel", + # LLM Config schemas + "LLMConfigBase", + 
"LLMConfigCreate", + "LLMConfigRead", + "LLMConfigUpdate", + # Log schemas + "LogBase", + "LogCreate", + "LogFilter", + "LogRead", + "LogUpdate", + # RBAC schemas + "InviteAcceptRequest", + "InviteAcceptResponse", + "InviteCreate", + "InviteInfoResponse", + "InviteRead", + "InviteUpdate", + "MembershipRead", + "MembershipReadWithUser", + "MembershipUpdate", "PermissionInfo", "PermissionsListResponse", - "PodcastBase", - "PodcastCreate", - "PodcastGenerateRequest", - "PodcastRead", - "PodcastUpdate", "RoleCreate", "RoleRead", "RoleUpdate", + # Search source connector schemas "SearchSourceConnectorBase", "SearchSourceConnectorCreate", "SearchSourceConnectorRead", "SearchSourceConnectorUpdate", + # Search space schemas "SearchSpaceBase", "SearchSpaceCreate", "SearchSpaceRead", "SearchSpaceUpdate", "SearchSpaceWithStats", - "ThreadHistoryLoadResponse", - "ThreadListItem", - "ThreadListResponse", - "TimestampModel", + # User schemas "UserCreate", "UserRead", "UserSearchSpaceAccess", diff --git a/surfsense_backend/app/schemas/chats.py b/surfsense_backend/app/schemas/chats.py deleted file mode 100644 index 3109130f5..000000000 --- a/surfsense_backend/app/schemas/chats.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, ConfigDict - -from app.db import ChatType - -from .base import IDModel, TimestampModel - - -class ChatBase(BaseModel): - type: ChatType - title: str - initial_connectors: list[str] | None = None - messages: list[Any] - search_space_id: int - state_version: int = 1 - - -class ChatBaseWithoutMessages(BaseModel): - type: ChatType - title: str - search_space_id: int - state_version: int = 1 - - -class ClientAttachment(BaseModel): - name: str - content_type: str - url: str - - -class ToolInvocation(BaseModel): - tool_call_id: str - tool_name: str - args: dict - result: dict - - -# class ClientMessage(BaseModel): -# role: str -# content: str -# experimental_attachments: Optional[List[ClientAttachment]] = None -# 
toolInvocations: Optional[List[ToolInvocation]] = None - - -class AISDKChatRequest(BaseModel): - messages: list[Any] - data: dict[str, Any] | None = None - - -class ChatMessage(BaseModel): - """A single message in the chat history.""" - - role: str # "user" or "assistant" - content: str - - -class NewChatRequest(BaseModel): - """Request schema for the new deep agent chat endpoint.""" - - chat_id: int - user_query: str - search_space_id: int - messages: list[ChatMessage] | None = None # Optional chat history from frontend - - -class ChatCreate(ChatBase): - pass - - -class ChatUpdate(ChatBase): - pass - - -class ChatRead(ChatBase, IDModel, TimestampModel): - model_config = ConfigDict(from_attributes=True) - - -class ChatReadWithoutMessages(ChatBaseWithoutMessages, IDModel, TimestampModel): - model_config = ConfigDict(from_attributes=True) diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index e1cf4efb8..1fdb50777 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -127,3 +127,24 @@ class ThreadListResponse(BaseModel): threads: list[ThreadListItem] archived_threads: list[ThreadListItem] + + +# ============================================================================= +# Chat Request Schemas (for deep agent) +# ============================================================================= + + +class ChatMessage(BaseModel): + """A single prior message (role plus plain-text content) in the chat history sent with a request.""" + + role: str # "user" or "assistant" + content: str # Message text + + +class NewChatRequest(BaseModel): + """Request body for the deep agent chat endpoint (the `/new_chat` route).""" + + chat_id: int # ID of the NewChatThread to chat in + user_query: str # Query text forwarded to stream_new_chat + search_space_id: int # Search space to run the chat against + messages: list[ChatMessage] | None = None # Optional chat history from frontend diff --git a/surfsense_backend/app/schemas/podcasts.py b/surfsense_backend/app/schemas/podcasts.py deleted file mode 100644 index b6a6a9a24..000000000 --- a/surfsense_backend/app/schemas/podcasts.py +++ /dev/null @@
-1,33 +0,0 @@ -from typing import Any, Literal - -from pydantic import BaseModel, ConfigDict - -from .base import IDModel, TimestampModel - - -class PodcastBase(BaseModel): - title: str - podcast_transcript: list[Any] - file_location: str = "" - search_space_id: int - chat_state_version: int | None = None - - -class PodcastCreate(PodcastBase): - pass - - -class PodcastUpdate(PodcastBase): - pass - - -class PodcastRead(PodcastBase, IDModel, TimestampModel): - model_config = ConfigDict(from_attributes=True) - - -class PodcastGenerateRequest(BaseModel): - type: Literal["DOCUMENT", "CHAT"] - ids: list[int] - search_space_id: int - podcast_title: str | None = None - user_prompt: str | None = None diff --git a/surfsense_backend/app/tasks/chat/stream_connector_search_results.py b/surfsense_backend/app/tasks/chat/stream_connector_search_results.py deleted file mode 100644 index a4b9b6665..000000000 --- a/surfsense_backend/app/tasks/chat/stream_connector_search_results.py +++ /dev/null @@ -1,75 +0,0 @@ -from collections.abc import AsyncGenerator -from typing import Any -from uuid import UUID - -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.researcher.graph import graph as researcher_graph -from app.agents.researcher.state import State -from app.services.streaming_service import StreamingService - - -async def stream_connector_search_results( - user_query: str, - user_id: str | UUID, - search_space_id: int, - session: AsyncSession, - research_mode: str, - selected_connectors: list[str], - langchain_chat_history: list[Any], - document_ids_to_add_in_context: list[int], - language: str | None = None, - top_k: int = 10, -) -> AsyncGenerator[str, None]: - """ - Stream connector search results to the client - - Args: - user_query: The user's query - user_id: The user's ID (can be UUID object or string) - search_space_id: The search space ID - session: The database session - research_mode: The research mode - selected_connectors: List of selected connectors - - 
Yields: - str: Formatted response strings - """ - streaming_service = StreamingService() - - # Convert UUID to string if needed - user_id_str = str(user_id) if isinstance(user_id, UUID) else user_id - - # Sample configuration - config = { - "configurable": { - "user_query": user_query, - "connectors_to_search": selected_connectors, - "user_id": user_id_str, - "search_space_id": search_space_id, - "document_ids_to_add_in_context": document_ids_to_add_in_context, - "language": language, # Add language to the configuration - "top_k": top_k, # Add top_k to the configuration - } - } - # print(f"Researcher configuration: {config['configurable']}") # Debug print - # Initialize state with database session and streaming service - initial_state = State( - db_session=session, - streaming_service=streaming_service, - chat_history=langchain_chat_history, - ) - - # Run the graph directly - print("\nRunning the complete researcher workflow...") - - # Use streaming with config parameter - async for chunk in researcher_graph.astream( - initial_state, - config=config, - stream_mode="custom", - ): - if isinstance(chunk, dict) and "yield_value" in chunk: - yield chunk["yield_value"] - - yield streaming_service.format_completion() diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 5ddd097e6..e1061e745 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -18,7 +18,7 @@ from app.agents.new_chat.llm_config import ( create_chat_litellm_from_config, load_llm_config_from_yaml, ) -from app.schemas.chats import ChatMessage +from app.schemas.new_chat import ChatMessage from app.services.connector_service import ConnectorService from app.services.new_streaming_service import VercelStreamingService