mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-28 21:49:40 +02:00
feat: SurfSense v0.0.6 init
This commit is contained in:
parent
18fc19e8d9
commit
da23012970
58 changed files with 8284 additions and 2076 deletions
21
surfsense_backend/.env.example
Normal file
21
surfsense_backend/.env.example
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/surfsense"
|
||||
|
||||
SECRET_KEY="SECRET"
|
||||
GOOGLE_OAUTH_CLIENT_ID="924507538m"
|
||||
GOOGLE_OAUTH_CLIENT_SECRET="GOCSV"
|
||||
NEXT_FRONTEND_URL="http://localhost:3000"
|
||||
EMBEDDING_MODEL="mixedbread-ai/mxbai-embed-large-v1"
|
||||
|
||||
RERANKERS_MODEL_NAME="ms-marco-MiniLM-L-12-v2"
|
||||
RERANKERS_MODEL_TYPE="flashrank"
|
||||
|
||||
FAST_LLM="litellm:openai/gpt-4o-mini"
|
||||
SMART_LLM="litellm:openai/gpt-4o-mini"
|
||||
STRATEGIC_LLM="litellm:openai/gpt-4o-mini"
|
||||
LONG_CONTEXT_LLM="litellm:gemini/gemini-2.0-flash-thinking-exp-01-21"
|
||||
|
||||
OPENAI_API_KEY="sk-proj-iA"
|
||||
GEMINI_API_KEY="AIzaSyB6-1641124124124124124124124124124"
|
||||
|
||||
UNSTRUCTURED_API_KEY="Tpu3P0U8iy"
|
||||
FIRECRAWL_API_KEY="fcr-01J0000000000000000000000"
|
||||
6
surfsense_backend/.gitignore
vendored
Normal file
6
surfsense_backend/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
.env
|
||||
.venv
|
||||
venv/
|
||||
data/
|
||||
__pycache__/
|
||||
.flashrank_cache
|
||||
1
surfsense_backend/.python-version
Normal file
1
surfsense_backend/.python-version
Normal file
|
|
@ -0,0 +1 @@
|
|||
3.12
|
||||
16
surfsense_backend/.vscode/launch.json
vendored
Normal file
16
surfsense_backend/.vscode/launch.json
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: main.py",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/main.py",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
}
|
||||
]
|
||||
}
|
||||
119
surfsense_backend/README.md
Normal file
119
surfsense_backend/README.md
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
# Surf Backend
|
||||
|
||||
## Technology Stack Overview
|
||||
|
||||
This application is a modern AI-powered search and knowledge management platform built with the following technology stack:
|
||||
|
||||
### Core Framework and Environment
|
||||
- **Python 3.12+**: The application requires Python 3.12 or newer
|
||||
- **FastAPI**: Modern, fast web framework for building APIs with Python
|
||||
- **Uvicorn**: ASGI server implementation, running the FastAPI application
|
||||
- **PostgreSQL with pgvector**: Database with vector search capabilities for similarity searches
|
||||
- **SQLAlchemy**: SQL toolkit and ORM (Object-Relational Mapping) for database interactions
|
||||
- **FastAPI Users**: Authentication and user management with JWT and OAuth support
|
||||
|
||||
### Key Features and Components
|
||||
|
||||
#### Authentication and User Management
|
||||
- JWT-based authentication
|
||||
- OAuth integration (Google)
|
||||
- User registration, login, and password reset flows
|
||||
|
||||
#### Search and Retrieval System
|
||||
- **Hybrid Search**: Combines vector similarity and full-text search for optimal results using Reciprocal Rank Fusion (RRF)
|
||||
- **Vector Embeddings**: Document and text embeddings for semantic search
|
||||
- **pgvector**: PostgreSQL extension for efficient vector similarity operations
|
||||
- **Chonkie**: Advanced document chunking and embedding library
|
||||
- Uses `AutoEmbeddings` for flexible embedding model selection
|
||||
- `LateChunker` for optimized document chunking based on embedding model's max sequence length
|
||||
|
||||
#### AI and NLP Capabilities
|
||||
- **LangChain**: Framework for developing AI-powered applications
|
||||
- Used for document processing, research, and response generation
|
||||
- Integration with various LLM models through LiteLLM
|
||||
- Document conversion utilities for standardized processing
|
||||
- **GPT Integration**: Integration with LLM models through LiteLLM
|
||||
- Multiple LLM configurations for different use cases:
|
||||
- Fast LLM: Quick responses (default: gpt-4o-mini)
|
||||
- Smart LLM: More comprehensive analysis (default: gpt-4o-mini)
|
||||
- Strategic LLM: Complex reasoning (default: gpt-4o-mini)
|
||||
- Long Context LLM: For processing large documents (default: gemini-2.0-flash-thinking)
|
||||
- **Rerankers with FlashRank**: Advanced result ranking for improved search relevance
|
||||
- Configurable reranking models (default: ms-marco-MiniLM-L-12-v2)
|
||||
- Supports multiple reranking backends (FlashRank, Cohere, etc.)
|
||||
- Improves search result quality by reordering based on semantic relevance
|
||||
- **GPT-Researcher**: Advanced research capabilities
|
||||
- Multiple research modes (GENERAL, DEEP, DEEPER)
|
||||
- Customizable report formats with proper citations
|
||||
- Streaming research results for real-time updates
|
||||
|
||||
#### External Integrations
|
||||
- **Slack Connector**: Integration with Slack for data retrieval and notifications
|
||||
- **Notion Connector**: Integration with Notion for document retrieval
|
||||
- **Search APIs**: Integration with Tavily and Serper API for web search
|
||||
- **Firecrawl**: Web crawling and data extraction capabilities
|
||||
|
||||
#### Data Processing
|
||||
- **Unstructured**: Tools for processing unstructured data
|
||||
- **Markdownify**: Converting HTML to Markdown
|
||||
- **Playwright**: Web automation and scraping capabilities
|
||||
|
||||
#### Main Modules
|
||||
- **Search Spaces**: Isolated search environments for different contexts or projects
|
||||
- **Documents**: Storage and retrieval of various document types
|
||||
- **Chunks**: Document fragments for more precise retrieval
|
||||
- **Chats**: Conversation management with different depth levels (GENERAL, DEEP, DEEPER, DEEPEST)
|
||||
- **Podcasts**: Audio content management with generation capabilities
|
||||
- **Search Source Connectors**: Integration with various data sources
|
||||
|
||||
### Development Tools
|
||||
- **Poetry**: Python dependency management (indicated by pyproject.toml)
|
||||
- **CORS support**: Cross-Origin Resource Sharing enabled for API access
|
||||
- **Environment Variables**: Configuration through .env files
|
||||
|
||||
## Database Schema
|
||||
|
||||
The application uses a relational database with the following main entities:
|
||||
- Users: Authentication and user management
|
||||
- SearchSpaces: Isolated search environments owned by users
|
||||
- Documents: Various document types with content and embeddings
|
||||
- Chunks: Smaller pieces of documents for granular retrieval
|
||||
- Chats: Conversation tracking with different depth levels
|
||||
- Podcasts: Audio content with generation capabilities
|
||||
- SearchSourceConnectors: External data source integrations
|
||||
|
||||
## API Endpoints
|
||||
|
||||
The API is structured with the following main route groups:
|
||||
- `/auth/*`: Authentication endpoints (JWT, OAuth)
|
||||
- `/users/*`: User management
|
||||
- `/api/v1/search-spaces/*`: Search space management
|
||||
- `/api/v1/documents/*`: Document management
|
||||
- `/api/v1/podcasts/*`: Podcast functionality
|
||||
- `/api/v1/chats/*`: Chat and conversation endpoints
|
||||
- `/api/v1/search-source-connectors/*`: External data source management
|
||||
|
||||
## Deployment
|
||||
|
||||
The application is configured to run with Uvicorn and can be deployed with:
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
This will start the server on all interfaces (0.0.0.0) with info-level logging.
|
||||
|
||||
## Requirements
|
||||
|
||||
See pyproject.toml for detailed dependency information. Key dependencies include:
|
||||
- asyncpg: Asynchronous PostgreSQL client
|
||||
- chonkie: Document chunking and embedding library
|
||||
- fastapi and related packages
|
||||
- fastapi-users: Authentication and user management
|
||||
- firecrawl-py: Web crawling capabilities
|
||||
- gpt-researcher: Advanced research capabilities
|
||||
- langchain components for AI workflows
|
||||
- litellm: LLM model integration
|
||||
- pgvector: Vector similarity search in PostgreSQL
|
||||
- rerankers with FlashRank: Advanced result ranking
|
||||
- Various AI and NLP libraries
|
||||
- Integration clients for Slack, Notion, etc.
|
||||
0
surfsense_backend/app/__init__.py
Normal file
0
surfsense_backend/app/__init__.py
Normal file
80
surfsense_backend/app/app.py
Normal file
80
surfsense_backend/app/app.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import User, create_db_and_tables, get_async_session
|
||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.schemas import UserCreate, UserRead, UserUpdate
|
||||
from app.users import (
|
||||
SECRET,
|
||||
auth_backend,
|
||||
fastapi_users,
|
||||
google_oauth_client,
|
||||
current_active_user,
|
||||
)
|
||||
from app.routes import router as crud_router
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook passed to FastAPI: prepare the database, then serve.

    Awaits ``create_db_and_tables()`` before yielding control to the
    application. No code runs after the ``yield``, so nothing is done on
    shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Not needed if you setup a migration system like Alembic
    await create_db_and_tables()
    yield
|
||||
|
||||
|
||||
# Application instance; `lifespan` runs the startup work defined above.
app = FastAPI(lifespan=lifespan)

# Add CORS middleware
# NOTE(review): a wildcard origin together with allow_credentials=True
# presumably lets any site issue credentialed requests — confirm this is
# intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Routers generated by fastapi-users, mounted under /auth and /users.
# JWT auth backend routes
app.include_router(
    fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", tags=["auth"]
)
# Registration
app.include_router(
    fastapi_users.get_register_router(UserRead, UserCreate),
    prefix="/auth",
    tags=["auth"],
)
# Password reset
app.include_router(
    fastapi_users.get_reset_password_router(),
    prefix="/auth",
    tags=["auth"],
)
# Account verification
app.include_router(
    fastapi_users.get_verify_router(UserRead),
    prefix="/auth",
    tags=["auth"],
)
# User management
app.include_router(
    fastapi_users.get_users_router(UserRead, UserUpdate),
    prefix="/users",
    tags=["users"],
)
# Google OAuth; is_verified_by_default=True is passed so accounts created
# through this router start out verified.
app.include_router(
    fastapi_users.get_oauth_router(google_oauth_client, auth_backend, SECRET, is_verified_by_default=True),
    prefix="/auth/google",
    tags=["auth"],
)
# Project CRUD routes from app.routes
app.include_router(crud_router, prefix="/api/v1", tags=["crud"])
|
||||
|
||||
|
||||
@app.get("/authenticated-route")
async def authenticated_route(
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Run a fixed hybrid search on behalf of the authenticated user.

    The query text, result count, search space id and document type are all
    hard-coded; only the user id comes from the request.
    NOTE(review): this looks like a smoke-test endpoint — confirm before
    exposing it in production.

    Args:
        user: The currently active user, resolved by the auth dependency.
        session: Async database session for the request.

    Returns:
        Whatever ``ChucksHybridSearchRetriever.hybrid_search`` returns.
    """
    search_arguments = {
        "query_text": "SurfSense",
        "top_k": 1,
        "user_id": user.id,
        "search_space_id": 1,
        "document_type": "CRAWLED_URL",
    }
    hybrid_retriever = ChucksHybridSearchRetriever(session)
    return await hybrid_retriever.hybrid_search(**search_arguments)
|
||||
98
surfsense_backend/app/config/__init__.py
Normal file
98
surfsense_backend/app/config/__init__.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from chonkie import AutoEmbeddings, LateChunker
|
||||
from rerankers import Reranker
|
||||
from langchain_community.chat_models import ChatLiteLLM
|
||||
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Get the base directory of the project
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
env_file = BASE_DIR / ".env"
|
||||
load_dotenv(env_file)
|
||||
|
||||
|
||||
def extract_model_name(llm_string: str) -> str:
    """Extract the model name from an LLM string.

    Everything up to and including the first ":" is treated as a provider
    prefix and dropped; a string without ":" is returned unchanged.

    Example: "litellm:openai/gpt-4o-mini" -> "openai/gpt-4o-mini"

    Args:
        llm_string: The LLM string with optional prefix

    Returns:
        str: The extracted model name

    Raises:
        TypeError: If llm_string is not a string (for example ``None`` because
            the environment variable it was read from is not set).
    """
    if not isinstance(llm_string, str):
        # Same exception type the original raised for None, but with a message
        # that points at the likely cause instead of "NoneType is not iterable".
        raise TypeError(
            "LLM string must be a str such as 'litellm:openai/gpt-4o-mini', "
            f"got {type(llm_string).__name__}; "
            "is the corresponding environment variable set?"
        )

    _prefix, separator, model_name = llm_string.partition(":")
    return model_name if separator else llm_string
|
||||
|
||||
class Config:
    """Application settings read from environment variables.

    Every attribute is evaluated once, when this class body executes at import
    time. Besides reading the raw values, the body also constructs the LLM
    clients, the embedding model, the chunker and the reranker, and validates
    the embedding dimension, so importing this module raises immediately if
    one of those steps fails.
    """

    # Database
    DATABASE_URL = os.getenv("DATABASE_URL")

    # Google OAuth
    GOOGLE_OAUTH_CLIENT_ID = os.getenv("GOOGLE_OAUTH_CLIENT_ID")
    GOOGLE_OAUTH_CLIENT_SECRET = os.getenv("GOOGLE_OAUTH_CLIENT_SECRET")
    NEXT_FRONTEND_URL = os.getenv("NEXT_FRONTEND_URL")

    # LONG-CONTEXT LLMS
    LONG_CONTEXT_LLM = os.getenv("LONG_CONTEXT_LLM")
    long_context_llm_instance = ChatLiteLLM(model=extract_model_name(LONG_CONTEXT_LLM))

    # GPT Researcher
    FAST_LLM = os.getenv("FAST_LLM")
    SMART_LLM = os.getenv("SMART_LLM")
    STRATEGIC_LLM = os.getenv("STRATEGIC_LLM")
    fast_llm_instance = ChatLiteLLM(model=extract_model_name(FAST_LLM))
    smart_llm_instance = ChatLiteLLM(model=extract_model_name(SMART_LLM))
    strategic_llm_instance = ChatLiteLLM(model=extract_model_name(STRATEGIC_LLM))

    # Chonkie Configuration | Edit this to your needs
    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
    embedding_model_instance = AutoEmbeddings.get_embeddings(EMBEDDING_MODEL)
    chunker_instance = LateChunker(
        embedding_model=EMBEDDING_MODEL,
        chunk_size=embedding_model_instance.max_seq_length,
    )

    # Reranker's Configuration | Pinecone, Cohere etc. Read more at https://github.com/AnswerDotAI/rerankers?tab=readme-ov-file#usage
    RERANKERS_MODEL_NAME = os.getenv("RERANKERS_MODEL_NAME")
    RERANKERS_MODEL_TYPE = os.getenv("RERANKERS_MODEL_TYPE")
    reranker_instance = Reranker(
        model_name=RERANKERS_MODEL_NAME,
        model_type=RERANKERS_MODEL_TYPE,
    )

    # OAuth JWT
    SECRET_KEY = os.getenv("SECRET_KEY")

    # Unstructured API Key
    UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")

    # Firecrawl API Key (os.getenv already returns None when unset)
    FIRECRAWL_API_KEY = os.getenv("FIRECRAWL_API_KEY")

    # Validation Checks
    # Check embedding dimension
    if hasattr(embedding_model_instance, 'dimension') and embedding_model_instance.dimension > 2000:
        raise ValueError(
            f"Embedding dimension for Model: {EMBEDDING_MODEL} "
            f"has {embedding_model_instance.dimension} dimensions, which "
            f"exceeds the maximum of 2000 allowed by PGVector."
        )

    @classmethod
    def get_settings(cls):
        """Get all settings as a dictionary.

        Returns:
            dict: Attributes defined directly on this class whose names do not
            start with "_" and whose values are neither callable nor
            classmethod/staticmethod wrappers.
        """
        return {
            key: value
            for key, value in vars(cls).items()
            if not key.startswith("_")
            and not callable(value)
            # classmethod/staticmethod objects are not callable until bound,
            # so `callable` alone would let `get_settings` itself through.
            and not isinstance(value, (classmethod, staticmethod))
        }


# Create a config instance
config = Config()
|
||||
225
surfsense_backend/app/connectors/notion_history.py
Normal file
225
surfsense_backend/app/connectors/notion_history.py
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
from notion_client import Client
|
||||
|
||||
class NotionHistoryConnector:
    """Fetch pages shared with a Notion integration, including block content."""

    def __init__(self, token):
        """
        Initialize the NotionHistoryConnector with a token.

        Args:
            token (str): Notion integration token
        """
        self.notion = Client(auth=token)

    def get_all_pages(self, start_date=None, end_date=None):
        """
        Fetches all pages shared with your integration and their content.

        The search request is restricted to pages (not databases) and is
        followed across every result page. The search call itself is given no
        date filter, so when a date bound is supplied the results are
        requested newest-first and then filtered locally on each page's
        ``last_edited_time``. Pages whose ``last_edited_time`` is missing or
        unparsable are kept rather than dropped.

        Args:
            start_date (str, optional): ISO 8601 date string (e.g., "2023-01-01T00:00:00Z")
            end_date (str, optional): ISO 8601 date string (e.g., "2023-12-31T23:59:59Z")

        Returns:
            list: List of dictionaries containing page data
                ("page_id", "title", "content")

        Raises:
            ValueError: If start_date or end_date is not a valid ISO 8601 string
        """
        lower_bound = self._parse_iso_datetime(start_date) if start_date else None
        upper_bound = self._parse_iso_datetime(end_date) if end_date else None

        # Filter for pages only (not databases)
        search_params = {
            "filter": {
                "value": "page",
                "property": "object"
            }
        }

        if start_date or end_date:
            search_params["sort"] = {
                "direction": "descending",
                "timestamp": "last_edited_time"
            }

        # Collect every page the integration has access to, following cursors
        pages = []
        cursor = None
        while True:
            if cursor:
                search_results = self.notion.search(start_cursor=cursor, **search_params)
            else:
                search_results = self.notion.search(**search_params)

            pages.extend(search_results["results"])

            cursor = search_results.get("next_cursor")
            if not search_results.get("has_more") or not cursor:
                break

        all_page_data = []
        for page in pages:
            if not self._is_within_range(page.get("last_edited_time"), lower_bound, upper_bound):
                continue

            page_id = page["id"]
            all_page_data.append({
                "page_id": page_id,
                "title": self.get_page_title(page),
                "content": self.get_page_content(page_id)
            })

        return all_page_data

    @staticmethod
    def _parse_iso_datetime(value):
        """Parse an ISO 8601 string into an aware datetime (naive input is taken as UTC)."""
        from datetime import datetime, timezone

        if value.endswith("Z"):
            # Spell the UTC designator as an explicit offset for fromisoformat
            value = value[:-1] + "+00:00"
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def _is_within_range(self, last_edited, lower_bound, upper_bound):
        """Return True when last_edited lies inside the optional inclusive bounds."""
        if lower_bound is None and upper_bound is None:
            return True
        if not last_edited:
            return True
        try:
            edited_at = self._parse_iso_datetime(last_edited)
        except (ValueError, TypeError, AttributeError):
            # Unknown timestamp format: keep the page rather than lose data
            return True
        if lower_bound is not None and edited_at < lower_bound:
            return False
        if upper_bound is not None and edited_at > upper_bound:
            return False
        return True

    def get_page_title(self, page):
        """
        Extracts the title from a page object.

        Args:
            page (dict): Notion page object

        Returns:
            str: Page title or a fallback string
        """
        # Title can be in different properties depending on the page type
        if "properties" in page:
            # Try to find a non-empty title property
            for prop_data in page["properties"].values():
                if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
                    return " ".join(text_obj["plain_text"] for text_obj in prop_data["title"])

        # If no title found, return the page ID as fallback
        return f"Untitled page ({page['id']})"

    def _list_children(self, block_id):
        """Return every direct child block of block_id, following pagination cursors."""
        children = []
        cursor = None

        while True:
            if cursor:
                response = self.notion.blocks.children.list(block_id=block_id, start_cursor=cursor)
            else:
                response = self.notion.blocks.children.list(block_id=block_id)

            children.extend(response["results"])

            cursor = response.get("next_cursor")
            # Stop when there is nothing more, or no cursor to continue with
            if not response["has_more"] or not cursor:
                return children

    def get_page_content(self, page_id):
        """
        Fetches the content (blocks) of a specific page.

        Args:
            page_id (str): The ID of the page to fetch

        Returns:
            list: List of processed blocks from the page
        """
        return [self.process_block(block) for block in self._list_children(page_id)]

    def process_block(self, block):
        """
        Processes a block and recursively fetches any child blocks.

        Args:
            block (dict): The block to process

        Returns:
            dict: Processed block with "id", "type", "content" and "children"
        """
        block_id = block["id"]

        child_blocks = []
        if block.get("has_children", False):
            # Fetch (with pagination) and process child blocks
            child_blocks = [
                self.process_block(child_block)
                for child_block in self._list_children(block_id)
            ]

        return {
            "id": block_id,
            "type": block["type"],
            "content": self.extract_block_content(block),
            "children": child_blocks
        }

    def extract_block_content(self, block):
        """
        Extracts the content from a block based on its type.

        Code blocks are rendered as fenced code with their language, other
        blocks carrying "rich_text" as their concatenated plain text, images
        as a placeholder that never exposes the full URL, and equations as
        their expression.

        Args:
            block (dict): The block to extract content from

        Returns:
            str: Extracted content as a string ("" for unsupported block types)
        """
        block_type = block["type"]

        # Code must be handled before the generic rich_text branch: a code
        # block also carries "rich_text", so it would otherwise lose its fence.
        if block_type == "code":
            language = block["code"]["language"]
            code_text = "".join(text_obj["plain_text"] for text_obj in block["code"]["rich_text"])
            return f"```{language}\n{code_text}\n```"

        # Different block types have different structures
        if block_type in block and "rich_text" in block[block_type]:
            return "".join(text_obj["plain_text"] for text_obj in block[block_type]["rich_text"])

        if block_type == "image":
            # Instead of returning the raw URL which may contain sensitive AWS credentials,
            # return a placeholder or reference to the image
            if "file" in block["image"]:
                # For Notion-hosted images (which use AWS S3 pre-signed URLs)
                return "[Notion Image]"
            if "external" in block["image"]:
                url = block["image"]["external"]["url"]
                # Only return the domain part of external URLs to avoid potential sensitive parameters
                try:
                    from urllib.parse import urlparse
                    parsed_url = urlparse(url)
                    return f"[External Image from {parsed_url.netloc}]"
                except (ValueError, TypeError, AttributeError):
                    return "[External Image]"
        elif block_type == "equation":
            return block["equation"]["expression"]
        # Add more block types as needed

        # Return empty string for unsupported block types
        return ""
|
||||
|
||||
|
||||
# Example usage
|
||||
# if __name__ == "__main__":
|
||||
# # Simple example of how to use this module
|
||||
# import argparse
|
||||
|
||||
# parser = argparse.ArgumentParser(description="Fetch Notion pages using an integration token")
|
||||
# parser.add_argument("--token", help="Your Notion integration token")
|
||||
# parser.add_argument("--start-date", help="Start date in ISO format (e.g., 2023-01-01T00:00:00Z)")
|
||||
# parser.add_argument("--end-date", help="End date in ISO format (e.g., 2023-12-31T23:59:59Z)")
|
||||
# args = parser.parse_args()
|
||||
|
||||
# token = args.token
|
||||
# if not token:
|
||||
# token = input("Enter your Notion integration token: ")
|
||||
|
||||
# fetcher = NotionHistoryConnector(token)
|
||||
|
||||
# try:
|
||||
# pages = fetcher.get_all_pages(args.start_date, args.end_date)
|
||||
# print(f"Fetched {len(pages)} pages from Notion")
|
||||
# for page in pages:
|
||||
# print(f"- {page['title']}")
|
||||
# except Exception as e:
|
||||
# print(f"Error: {str(e)}")
|
||||
301
surfsense_backend/app/connectors/slack_history.py
Normal file
301
surfsense_backend/app/connectors/slack_history.py
Normal file
|
|
@ -0,0 +1,301 @@
|
|||
"""
|
||||
Slack History Module
|
||||
|
||||
A module for retrieving conversation history from Slack channels.
|
||||
Allows fetching channel lists and message history with date range filtering.
|
||||
"""
|
||||
|
||||
import os
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Union
|
||||
|
||||
|
||||
class SlackHistory:
    """Class for retrieving conversation history from Slack channels."""

    def __init__(self, token: Optional[str] = None):
        """
        Initialize the SlackHistory class.

        Args:
            token: Slack API token (optional, can be set later with set_token)
        """
        self.client = WebClient(token=token) if token else None

    def set_token(self, token: str) -> None:
        """
        Set the Slack API token.

        Args:
            token: Slack API token
        """
        self.client = WebClient(token=token)

    def get_all_channels(self, include_private: bool = True) -> Dict[str, str]:
        """
        Fetch all channels that the bot has access to.

        Args:
            include_private: Whether to include private channels

        Returns:
            Dictionary mapping channel names to channel IDs

        Raises:
            ValueError: If no Slack client has been initialized
            SlackApiError: If there's an error calling the Slack API
        """
        if not self.client:
            raise ValueError("Slack client not initialized. Call set_token() first.")

        types = "public_channel"
        if include_private:
            types += ",private_channel"

        channels_dict = {}
        next_cursor = None

        try:
            # Walk conversations.list page by page until no cursor is returned
            while True:
                kwargs = {
                    "types": types,
                    "limit": 1000,  # Maximum allowed by API
                }
                if next_cursor:
                    kwargs["cursor"] = next_cursor

                result = self.client.conversations_list(**kwargs)

                for channel in result["channels"]:
                    channels_dict[channel["name"]] = channel["id"]

                next_cursor = result.get("response_metadata", {}).get("next_cursor")
                if not next_cursor:
                    return channels_dict

        except SlackApiError as e:
            raise SlackApiError(f"Error retrieving channels: {e}", e.response) from e

    def get_conversation_history(
        self,
        channel_id: str,
        limit: int = 1000,
        oldest: Optional[int] = None,
        latest: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Fetch conversation history for a channel.

        Args:
            channel_id: The ID of the channel to fetch history for
            limit: Maximum number of messages to return (also used, capped at
                1000, as the per-request page size)
            oldest: Start of time range (Unix timestamp)
            latest: End of time range (Unix timestamp)

        Returns:
            List of message objects, at most `limit` long

        Raises:
            ValueError: If no Slack client has been initialized
            SlackApiError: If there's an error calling the Slack API
        """
        if not self.client:
            raise ValueError("Slack client not initialized. Call set_token() first.")

        try:
            messages = []
            next_cursor = None

            while True:
                kwargs = {
                    "channel": channel_id,
                    "limit": min(limit, 1000),  # API max is 1000
                }

                # Compare against None so an explicit timestamp of 0 is honoured
                if oldest is not None:
                    kwargs["oldest"] = oldest
                if latest is not None:
                    kwargs["latest"] = latest
                if next_cursor:
                    kwargs["cursor"] = next_cursor

                result = self.client.conversations_history(**kwargs)
                messages.extend(result["messages"])

                # Check if we need to paginate
                if result.get("has_more", False) and len(messages) < limit:
                    next_cursor = result["response_metadata"]["next_cursor"]
                else:
                    break

            # Respect the overall limit parameter
            return messages[:limit]

        except SlackApiError as e:
            raise SlackApiError(
                f"Error retrieving history for channel {channel_id}: {e}", e.response
            ) from e

    @staticmethod
    def convert_date_to_timestamp(date_str: str) -> Optional[int]:
        """
        Convert a date string in format YYYY-MM-DD to Unix timestamp.

        The date is parsed as a naive datetime, so midnight is interpreted in
        the local timezone of the running process.

        Args:
            date_str: Date string in YYYY-MM-DD format

        Returns:
            Unix timestamp (seconds since epoch) or None if invalid format
        """
        try:
            dt = datetime.strptime(date_str, "%Y-%m-%d")
            return int(dt.timestamp())
        except ValueError:
            return None

    def get_history_by_date_range(
        self,
        channel_id: str,
        start_date: str,
        end_date: str,
        limit: int = 1000
    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        """
        Fetch conversation history within a date range.

        Errors are reported through the second tuple element instead of being
        raised.

        Args:
            channel_id: The ID of the channel to fetch history for
            start_date: Start date in YYYY-MM-DD format
            end_date: End date in YYYY-MM-DD format (inclusive)
            limit: Maximum number of messages to return

        Returns:
            Tuple containing (messages list, error message or None)
        """
        oldest = self.convert_date_to_timestamp(start_date)
        # `is None`, not falsiness: a timestamp of 0 is a valid date
        if oldest is None:
            return [], f"Invalid start date format: {start_date}. Please use YYYY-MM-DD."

        latest = self.convert_date_to_timestamp(end_date)
        if latest is None:
            return [], f"Invalid end date format: {end_date}. Please use YYYY-MM-DD."

        # Add one day to end date to make it inclusive
        latest += 86400  # seconds in a day

        try:
            messages = self.get_conversation_history(
                channel_id=channel_id,
                limit=limit,
                oldest=oldest,
                latest=latest
            )
            return messages, None
        except ValueError as e:
            return [], str(e)
        except SlackApiError as e:
            return [], f"Slack API error: {str(e)}"

    def get_user_info(self, user_id: str) -> Dict[str, Any]:
        """
        Get information about a user.

        Args:
            user_id: The ID of the user to get info for

        Returns:
            User information dictionary

        Raises:
            ValueError: If no Slack client has been initialized
            SlackApiError: If there's an error calling the Slack API
        """
        if not self.client:
            raise ValueError("Slack client not initialized. Call set_token() first.")

        try:
            result = self.client.users_info(user=user_id)
            return result["user"]
        except SlackApiError as e:
            raise SlackApiError(f"Error retrieving user info for {user_id}: {e}", e.response) from e

    def format_message(self, msg: Dict[str, Any], include_user_info: bool = False) -> Dict[str, Any]:
        """
        Format a message for easier consumption.

        When user info is requested, both "user_name" and "user_email" are
        always added; if the lookup fails they fall back to "Unknown" and "".

        Args:
            msg: The message object from Slack API
            include_user_info: Whether to fetch and include user info

        Returns:
            Formatted message dictionary
        """
        timestamp = msg.get("ts")
        formatted = {
            "text": msg.get("text", ""),
            "timestamp": timestamp,
            # A missing/empty ts falls back to the epoch; rendered in local time
            "datetime": datetime.fromtimestamp(float(timestamp or 0)).strftime('%Y-%m-%d %H:%M:%S'),
            "user_id": msg.get("user", "UNKNOWN"),
            "has_attachments": bool(msg.get("attachments")),
            "has_files": bool(msg.get("files")),
            "thread_ts": msg.get("thread_ts"),
            "is_thread": "thread_ts" in msg,
        }

        if include_user_info and "user" in msg and self.client:
            try:
                user_info = self.get_user_info(msg["user"])
                formatted["user_name"] = user_info.get("real_name", "Unknown")
                formatted["user_email"] = user_info.get("profile", {}).get("email", "")
            except Exception:
                # Best effort: if we can't get user info, continue with
                # placeholders so callers always see the same keys
                formatted["user_name"] = "Unknown"
                formatted["user_email"] = ""

        return formatted
|
||||
|
||||
|
||||
# Example usage (uncomment to use):
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
# Set your token here or via environment variable
|
||||
token = os.environ.get("SLACK_API_TOKEN", "xoxb-your-token-here")
|
||||
|
||||
slack = SlackHistory(token)
|
||||
|
||||
# Get all channels
|
||||
try:
|
||||
channels = slack.get_all_channels()
|
||||
print("Available channels:")
|
||||
for name, channel_id in sorted(channels.items()):
|
||||
print(f"- {name}: {channel_id}")
|
||||
|
||||
# Example: Get history for a specific channel and date range
|
||||
channel_id = channels.get("general")
|
||||
if channel_id:
|
||||
messages, error = slack.get_history_by_date_range(
|
||||
channel_id=channel_id,
|
||||
start_date="2023-01-01",
|
||||
end_date="2023-01-31",
|
||||
limit=500
|
||||
)
|
||||
|
||||
if error:
|
||||
print(f"Error: {error}")
|
||||
else:
|
||||
print(f"\nRetrieved {len(messages)} messages from #general")
|
||||
|
||||
# Print formatted messages
|
||||
for msg in messages[:10]: # Show first 10 messages
|
||||
formatted = slack.format_message(msg, include_user_info=True)
|
||||
print(f"[{formatted['datetime']}] {formatted['user_name']}: {formatted['text']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
"""
|
||||
181
surfsense_backend/app/db.py
Normal file
181
surfsense_backend/app/db.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi_users.db import (
|
||||
SQLAlchemyBaseOAuthAccountTableUUID,
|
||||
SQLAlchemyBaseUserTableUUID,
|
||||
SQLAlchemyUserDatabase,
|
||||
)
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
Column,
|
||||
Enum as SQLAlchemyEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
TIMESTAMP
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, relationship
|
||||
|
||||
from app.config import config
|
||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
|
||||
DATABASE_URL = config.DATABASE_URL
|
||||
|
||||
|
||||
class DocumentType(str, Enum):
    """Kinds of document that can be stored (used by ``Document.document_type``).

    Members subclass ``str`` and each value equals its name, so a member
    compares equal to the plain string with the same spelling.
    """

    EXTENSION = "EXTENSION"
    CRAWLED_URL = "CRAWLED_URL"
    FILE = "FILE"
    SLACK_CONNECTOR = "SLACK_CONNECTOR"
    NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||||
|
||||
class SearchSourceConnectorType(str, Enum):
    """Connector kinds accepted by ``SearchSourceConnector.connector_type``.

    Members subclass ``str`` and each value equals its name.
    """

    SERPER_API = "SERPER_API"
    TAVILY_API = "TAVILY_API"
    SLACK_CONNECTOR = "SLACK_CONNECTOR"
    NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||||
|
||||
class ChatType(str, Enum):
    """Chat modes accepted by ``Chat.type``.

    Members subclass ``str`` and each value equals its name.
    """

    GENERAL = "GENERAL"
    DEEP = "DEEP"
    DEEPER = "DEEPER"
    DEEPEST = "DEEPEST"
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base class shared by every ORM model in this module."""
|
||||
|
||||
class TimestampMixin:
    """Mixin that adds a non-null, indexed ``created_at`` column."""

    @declared_attr
    def created_at(cls):
        """Return the ``created_at`` column for the model using this mixin."""

        def _utc_now():
            # Timezone-aware "now" in UTC; passed as a callable so it is not
            # frozen at import time.
            return datetime.now(timezone.utc)

        return Column(
            TIMESTAMP(timezone=True),
            nullable=False,
            default=_utc_now,
            index=True,
        )
|
||||
|
||||
class BaseModel(Base):
    """Abstract parent that gives each concrete model an integer primary key."""

    # Abstract: this class is not mapped to a table of its own.
    __abstract__ = True
    __allow_unmapped__ = True

    id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
class Chat(BaseModel, TimestampMixin):
    """A chat row in the ``chats`` table, belonging to one search space."""

    __tablename__ = "chats"

    type = Column(SQLAlchemyEnum(ChatType), nullable=False)
    title = Column(String(200), nullable=False, index=True)
    initial_connectors = Column(ARRAY(String), nullable=True)
    messages = Column(JSON, nullable=False)

    # ON DELETE CASCADE: removing the search space removes its chats.
    search_space_id = Column(
        Integer,
        ForeignKey("searchspaces.id", ondelete="CASCADE"),
        nullable=False,
    )
    search_space = relationship("SearchSpace", back_populates="chats")
|
||||
|
||||
class Document(BaseModel, TimestampMixin):
    """A document row in ``documents`` with its content, embedding and chunks."""

    __tablename__ = "documents"

    title = Column(String(200), nullable=False, index=True)
    document_type = Column(SQLAlchemyEnum(DocumentType), nullable=False)
    document_metadata = Column(JSON, nullable=True)

    content = Column(Text, nullable=False)
    # Vector width is taken from the configured embedding model.
    embedding = Column(Vector(config.embedding_model_instance.dimension))

    # ON DELETE CASCADE: removing the search space removes its documents.
    search_space_id = Column(
        Integer,
        ForeignKey("searchspaces.id", ondelete="CASCADE"),
        nullable=False,
    )
    search_space = relationship("SearchSpace", back_populates="documents")
    chunks = relationship(
        "Chunk",
        back_populates="document",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
class Chunk(BaseModel, TimestampMixin):
    """A piece of a document's text, stored in ``chunks`` with its own embedding."""

    __tablename__ = "chunks"

    content = Column(Text, nullable=False)
    # Vector width is taken from the configured embedding model.
    embedding = Column(Vector(config.embedding_model_instance.dimension))

    # ON DELETE CASCADE: removing the document removes its chunks.
    document_id = Column(
        Integer,
        ForeignKey("documents.id", ondelete="CASCADE"),
        nullable=False,
    )
    document = relationship("Document", back_populates="chunks")
|
||||
|
||||
class Podcast(BaseModel, TimestampMixin):
    """A podcast row in ``podcasts``, belonging to one search space."""

    __tablename__ = "podcasts"

    title = Column(String(200), nullable=False, index=True)
    is_generated = Column(Boolean, nullable=False, default=False)
    podcast_content = Column(Text, nullable=False, default="")
    file_location = Column(String(500), nullable=False, default="")

    # ON DELETE CASCADE: removing the search space removes its podcasts.
    search_space_id = Column(
        Integer,
        ForeignKey("searchspaces.id", ondelete="CASCADE"),
        nullable=False,
    )
    search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||
|
||||
class SearchSpace(BaseModel, TimestampMixin):
    """A user-owned container for documents, podcasts and chats (``searchspaces``)."""

    __tablename__ = "searchspaces"

    name = Column(String(100), nullable=False, index=True)
    description = Column(String(500), nullable=True)

    # ON DELETE CASCADE: removing the user removes their search spaces.
    user_id = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="CASCADE"),
        nullable=False,
    )
    user = relationship("User", back_populates="search_spaces")

    # Children are loaded ordered by id and deleted together with the space.
    documents = relationship(
        "Document",
        back_populates="search_space",
        order_by="Document.id",
        cascade="all, delete-orphan",
    )
    podcasts = relationship(
        "Podcast",
        back_populates="search_space",
        order_by="Podcast.id",
        cascade="all, delete-orphan",
    )
    chats = relationship(
        "Chat",
        back_populates="search_space",
        order_by="Chat.id",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
class SearchSourceConnector(BaseModel, TimestampMixin):
    """A configured search/source connector owned by a user (``search_source_connectors``)."""

    __tablename__ = "search_source_connectors"

    name = Column(String(100), nullable=False, index=True)
    # NOTE(review): unique=True is a table-wide constraint, so only one row per
    # connector type can exist across all users - confirm this is intended.
    connector_type = Column(
        SQLAlchemyEnum(SearchSourceConnectorType),
        nullable=False,
        unique=True,
    )
    is_indexable = Column(Boolean, nullable=False, default=False)
    last_indexed_at = Column(TIMESTAMP(timezone=True), nullable=True)
    config = Column(JSON, nullable=False)

    # ON DELETE CASCADE: removing the user removes their connectors.
    user_id = Column(
        UUID(as_uuid=True),
        ForeignKey("user.id", ondelete="CASCADE"),
        nullable=False,
    )
    user = relationship("User", back_populates="search_source_connectors")
|
||||
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
    """OAuth account table; all columns come from the fastapi-users base class."""
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
    """Application user: the fastapi-users base table plus owned resources."""

    # Joined eager loading: OAuth accounts are fetched together with the user.
    oauth_accounts: Mapped[list[OAuthAccount]] = relationship("OAuthAccount", lazy="joined")
    search_spaces = relationship("SearchSpace", back_populates="user")
    search_source_connectors = relationship(
        "SearchSourceConnector",
        back_populates="user",
    )
|
||||
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
async def setup_indexes():
    """Create the vector (HNSW) and full-text (GIN) indexes if they do not exist.

    Runs four ``CREATE INDEX IF NOT EXISTS`` statements inside one
    ``engine.begin()`` transaction: a cosine-distance HNSW index and an English
    ``to_tsvector`` GIN index for each of the ``documents`` and ``chunks`` tables.
    """
    index_statements = (
        # Document summary indexes
        "CREATE INDEX IF NOT EXISTS document_vector_index ON documents USING hnsw (embedding public.vector_cosine_ops)",
        "CREATE INDEX IF NOT EXISTS document_search_index ON documents USING gin (to_tsvector('english', content))",
        # Document chunk indexes
        "CREATE INDEX IF NOT EXISTS chucks_vector_index ON chunks USING hnsw (embedding public.vector_cosine_ops)",
        "CREATE INDEX IF NOT EXISTS chucks_search_index ON chunks USING gin (to_tsvector('english', content))",
    )

    async with engine.begin() as connection:
        for statement in index_statements:
            await connection.execute(text(statement))
|
||||
|
||||
async def create_db_and_tables():
    """Enable the ``vector`` extension, create all tables, then create the indexes."""
    async with engine.begin() as connection:
        # The extension must exist before tables with Vector columns are created.
        await connection.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))
        await connection.run_sync(Base.metadata.create_all)

    # Runs after the block above has committed, so the tables exist when the
    # indexes are created.
    await setup_indexes()
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession``; the context manager closes it once the consumer is done."""
    session_context = async_session_maker()
    async with session_context as db_session:
        yield db_session
|
||||
|
||||
|
||||
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
    """Yield a ``SQLAlchemyUserDatabase`` bound to the injected session and the user tables."""
    user_database = SQLAlchemyUserDatabase(session, User, OAuthAccount)
    yield user_database
|
||||
|
||||
async def get_chucks_hybrid_search_retriever(session: AsyncSession = Depends(get_async_session)):
    """Return a ``ChucksHybridSearchRetriever`` that queries through the injected session."""
    retriever = ChucksHybridSearchRetriever(session)
    return retriever
|
||||
103
surfsense_backend/app/prompts/__init__.py
Normal file
103
surfsense_backend/app/prompts/__init__.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from datetime import datetime, timezone
|
||||
|
||||
DATE_TODAY = "Today's date is " + datetime.now(timezone.utc).astimezone().isoformat() + '\n'
|
||||
|
||||
SUMMARY_PROMPT = DATE_TODAY + """
|
||||
<INSTRUCTIONS>
|
||||
<context>
|
||||
You are an expert document analyst and summarization specialist tasked with distilling complex information into clear,
|
||||
comprehensive summaries. Your role is to analyze documents thoroughly and create structured summaries that:
|
||||
1. Capture the complete essence and key insights of the source material
|
||||
2. Maintain perfect accuracy and factual precision
|
||||
3. Present information objectively without bias or interpretation
|
||||
4. Preserve critical context and logical relationships
|
||||
5. Structure content in a clear, hierarchical format
|
||||
</context>
|
||||
|
||||
<principles>
|
||||
<accuracy>
|
||||
- Maintain absolute factual accuracy and fidelity to source material
|
||||
- Avoid any subjective interpretation, inference or speculation
|
||||
- Preserve complete original meaning, nuance and contextual relationships
|
||||
- Report all quantitative data with precise values and appropriate units
|
||||
- Verify and cross-reference facts before inclusion
|
||||
- Flag any ambiguous or unclear information
|
||||
</accuracy>
|
||||
|
||||
<objectivity>
|
||||
- Present information with strict neutrality and impartiality
|
||||
- Exclude all forms of bias, personal opinions, and editorial commentary
|
||||
- Ensure balanced representation of all perspectives and viewpoints
|
||||
- Maintain objective professional distance from the content
|
||||
- Use precise, factual language free from emotional coloring
|
||||
- Focus solely on verifiable information and evidence
|
||||
</objectivity>
|
||||
|
||||
<comprehensiveness>
|
||||
- Capture all essential information, key themes, and central arguments
|
||||
- Preserve critical context and background necessary for understanding
|
||||
- Include relevant supporting details, examples, and evidence
|
||||
- Maintain logical flow and connections between concepts
|
||||
- Ensure hierarchical organization of information
|
||||
- Document relationships between different components
|
||||
- Highlight dependencies and causal links
|
||||
- Track chronological progression where relevant
|
||||
</comprehensiveness>
|
||||
</principles>
|
||||
|
||||
<output_format>
|
||||
<type>
|
||||
- Return summary in clean markdown format
|
||||
- Do not include markdown code block tags (```markdown ```)
|
||||
- Use standard markdown syntax for formatting (headers, lists, etc.)
|
||||
- Use # for main headings (e.g., # EXECUTIVE SUMMARY)
|
||||
- Use ## for subheadings where appropriate
|
||||
- Use bullet points (- item) for lists
|
||||
- Ensure proper indentation and spacing
|
||||
- Use appropriate emphasis (**bold**, *italic*) where needed
|
||||
</type>
|
||||
<style>
|
||||
- Use clear, concise language focused on key points
|
||||
- Maintain professional and objective tone throughout
|
||||
- Follow consistent formatting and style conventions
|
||||
- Provide descriptive section headings and subheadings
|
||||
- Utilize bullet points and lists for better readability
|
||||
- Structure content with clear hierarchy and organization
|
||||
- Avoid jargon and overly technical language
|
||||
- Include transition sentences between sections
|
||||
</style>
|
||||
</output_format>
|
||||
|
||||
<validation>
|
||||
<criteria>
|
||||
- Verify all facts and claims match source material exactly
|
||||
- Cross-reference and validate all numerical data points
|
||||
- Ensure logical flow and consistency throughout summary
|
||||
- Confirm comprehensive coverage of key information
|
||||
- Check for objective, unbiased language and tone
|
||||
- Validate accurate representation of source context
|
||||
- Review for proper attribution of ideas and quotes
|
||||
- Verify temporal accuracy and chronological order
|
||||
</criteria>
|
||||
</validation>
|
||||
|
||||
<length_guidelines>
|
||||
- Scale summary length proportionally to source document complexity and length
|
||||
- Minimum: 3-5 well-developed paragraphs per major section
|
||||
- Maximum: 8-10 paragraphs per section for highly complex documents
|
||||
- Adjust level of detail based on information density and importance
|
||||
- Ensure key concepts receive adequate coverage regardless of length
|
||||
</length_guidelines>
|
||||
|
||||
Now, create a summary of the following document:
|
||||
<document_to_summarize>
|
||||
{document}
|
||||
</document_to_summarize>
|
||||
</INSTRUCTIONS>
|
||||
"""
|
||||
|
||||
SUMMARY_PROMPT_TEMPLATE = PromptTemplate(
|
||||
input_variables=["document"],
|
||||
template=SUMMARY_PROMPT
|
||||
)
|
||||
0
surfsense_backend/app/retriver/__init__.py
Normal file
0
surfsense_backend/app/retriver/__init__.py
Normal file
243
surfsense_backend/app/retriver/chunks_hybrid_search.py
Normal file
243
surfsense_backend/app/retriver/chunks_hybrid_search.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
class ChucksHybridSearchRetriever:
|
||||
def __init__(self, db_session):
|
||||
"""
|
||||
Initialize the hybrid search retriever with a database session.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
|
||||
"""
|
||||
self.db_session = db_session
|
||||
|
||||
async def vector_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
"""
|
||||
Perform vector similarity search on chunks.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
Returns:
|
||||
List of chunks sorted by vector similarity
|
||||
"""
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
from app.config import config
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document).joinedload(Document.search_space))
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
)
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
# Add vector similarity ordering
|
||||
query = (
|
||||
query
|
||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
return chunks
|
||||
|
||||
async def full_text_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None) -> list:
|
||||
"""
|
||||
Perform full-text keyword search on chunks.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
|
||||
Returns:
|
||||
List of chunks sorted by text relevance
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Chunk.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
# Build the base query with user ownership check
|
||||
query = (
|
||||
select(Chunk)
|
||||
.options(joinedload(Chunk.document).joinedload(Document.search_space))
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(SearchSpace.user_id == user_id)
|
||||
.where(tsvector.op("@@")(tsquery)) # Only include results that match the query
|
||||
)
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
query = query.where(Document.search_space_id == search_space_id)
|
||||
|
||||
# Add text search ranking
|
||||
query = (
|
||||
query
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(query)
|
||||
chunks = result.scalars().all()
|
||||
|
||||
return chunks
|
||||
|
||||
async def hybrid_search(self, query_text: str, top_k: int, user_id: str, search_space_id: int = None, document_type: str = None) -> list:
|
||||
"""
|
||||
Combine vector similarity and full-text search results using Reciprocal Rank Fusion.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
top_k: Number of results to return
|
||||
user_id: The ID of the user performing the search
|
||||
search_space_id: Optional search space ID to filter results
|
||||
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
||||
|
||||
Returns:
|
||||
List of dictionaries containing chunk data and relevance scores
|
||||
"""
|
||||
from sqlalchemy import select, func, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.db import Chunk, Document, SearchSpace, DocumentType
|
||||
from app.config import config
|
||||
|
||||
# Get embedding for the query
|
||||
embedding_model = config.embedding_model_instance
|
||||
query_embedding = embedding_model.embed(query_text)
|
||||
|
||||
# Constants for RRF calculation
|
||||
k = 60 # Constant for RRF calculation
|
||||
n_results = top_k * 2 # Get more results for better fusion
|
||||
|
||||
# Create tsvector and tsquery for PostgreSQL full-text search
|
||||
tsvector = func.to_tsvector('english', Chunk.content)
|
||||
tsquery = func.plainto_tsquery('english', query_text)
|
||||
|
||||
# Base conditions for document filtering
|
||||
base_conditions = [SearchSpace.user_id == user_id]
|
||||
|
||||
# Add search space filter if provided
|
||||
if search_space_id is not None:
|
||||
base_conditions.append(Document.search_space_id == search_space_id)
|
||||
|
||||
# Add document type filter if provided
|
||||
if document_type is not None:
|
||||
# Convert string to enum value if needed
|
||||
if isinstance(document_type, str):
|
||||
try:
|
||||
doc_type_enum = DocumentType[document_type]
|
||||
base_conditions.append(Document.document_type == doc_type_enum)
|
||||
except KeyError:
|
||||
# If the document type doesn't exist in the enum, return empty results
|
||||
return []
|
||||
else:
|
||||
base_conditions.append(Document.document_type == document_type)
|
||||
|
||||
# CTE for semantic search with user ownership check
|
||||
semantic_search_cte = (
|
||||
select(
|
||||
Chunk.id,
|
||||
func.rank().over(order_by=Chunk.embedding.op("<=>")(query_embedding)).label("rank")
|
||||
)
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
)
|
||||
|
||||
semantic_search_cte = (
|
||||
semantic_search_cte
|
||||
.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
||||
.limit(n_results)
|
||||
.cte("semantic_search")
|
||||
)
|
||||
|
||||
# CTE for keyword search with user ownership check
|
||||
keyword_search_cte = (
|
||||
select(
|
||||
Chunk.id,
|
||||
func.rank().over(order_by=func.ts_rank_cd(tsvector, tsquery).desc()).label("rank")
|
||||
)
|
||||
.join(Document, Chunk.document_id == Document.id)
|
||||
.join(SearchSpace, Document.search_space_id == SearchSpace.id)
|
||||
.where(*base_conditions)
|
||||
.where(tsvector.op("@@")(tsquery))
|
||||
)
|
||||
|
||||
keyword_search_cte = (
|
||||
keyword_search_cte
|
||||
.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
||||
.limit(n_results)
|
||||
.cte("keyword_search")
|
||||
)
|
||||
|
||||
# Final combined query using a FULL OUTER JOIN with RRF scoring
|
||||
final_query = (
|
||||
select(
|
||||
Chunk,
|
||||
(
|
||||
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0) +
|
||||
func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
||||
).label("score")
|
||||
)
|
||||
.select_from(
|
||||
semantic_search_cte.outerjoin(
|
||||
keyword_search_cte,
|
||||
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
||||
full=True
|
||||
)
|
||||
)
|
||||
.join(
|
||||
Chunk,
|
||||
Chunk.id == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id)
|
||||
)
|
||||
.options(joinedload(Chunk.document))
|
||||
.order_by(text("score DESC"))
|
||||
.limit(top_k)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
result = await self.db_session.execute(final_query)
|
||||
chunks_with_scores = result.all()
|
||||
|
||||
# If no results were found, return an empty list
|
||||
if not chunks_with_scores:
|
||||
return []
|
||||
|
||||
# Convert to serializable dictionaries if no reranker is available or if reranking failed
|
||||
serialized_results = []
|
||||
for chunk, score in chunks_with_scores:
|
||||
serialized_results.append({
|
||||
"chunk_id": chunk.id,
|
||||
"content": chunk.content,
|
||||
"score": float(score), # Ensure score is a Python float
|
||||
"document": {
|
||||
"id": chunk.document.id,
|
||||
"title": chunk.document.title,
|
||||
"document_type": chunk.document.document_type.value if hasattr(chunk.document, 'document_type') else None,
|
||||
"metadata": chunk.document.document_metadata
|
||||
}
|
||||
})
|
||||
|
||||
return serialized_results
|
||||
14
surfsense_backend/app/routes/__init__.py
Normal file
14
surfsense_backend/app/routes/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from fastapi import APIRouter
|
||||
from .search_spaces_routes import router as search_spaces_router
|
||||
from .documents_routes import router as documents_router
|
||||
from .podcasts_routes import router as podcasts_router
|
||||
from .chats_routes import router as chats_router
|
||||
from .search_source_connectors_routes import router as search_source_connectors_router
|
||||
|
||||
router = APIRouter()

# Mount each feature router on the shared router, in registration order.
for _feature_router in (
    search_spaces_router,
    documents_router,
    podcasts_router,
    chats_router,
    search_source_connectors_router,
):
    router.include_router(_feature_router)
del _feature_router  # keep the module namespace identical to the unrolled form
|
||||
260
surfsense_backend/app/routes/chats_routes.py
Normal file
260
surfsense_backend/app/routes/chats_routes.py
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace, Chat
|
||||
from app.schemas import ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
|
||||
from app.tasks.stream_connector_search_results import stream_connector_search_results
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/chat")
async def handle_chat_data(
    request: AISDKChatRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Stream connector search results for the last user message of a chat request.

    Reads ``search_space_id``, ``research_mode`` and ``selected_connectors``
    from ``request.data``, verifies the search space belongs to the current
    user, and returns a ``StreamingResponse`` tagged with the
    ``x-vercel-ai-data-stream`` header.

    Raises:
        HTTPException: 400 if there are no messages, the last message is not a
            user message, or ``search_space_id`` is a non-numeric string;
            403 if the ownership check fails.
    """
    messages = request.messages
    # An empty list would make messages[-1] raise IndexError (surfacing as a
    # 500); treat it as the same client error as a non-user last message.
    if not messages or messages[-1].role != "user":
        raise HTTPException(
            status_code=400, detail="Last message must be a user message")

    user_query = messages[-1].content
    # NOTE(review): AISDKChatRequest is declared elsewhere; `or {}` only guards
    # against `data` being None/empty and changes nothing when it is a dict.
    data = request.data or {}
    search_space_id = data.get('search_space_id')
    research_mode: str = data.get('research_mode')
    selected_connectors: List[str] = data.get('selected_connectors')

    # Convert search_space_id to integer if it's a string
    if search_space_id and isinstance(search_space_id, str):
        try:
            search_space_id = int(search_space_id)
        except ValueError:
            raise HTTPException(
                status_code=400, detail="Invalid search_space_id format")

    # Check if the search space belongs to the current user; any HTTPException
    # from the helper is reported uniformly as 403.
    try:
        await check_ownership(session, SearchSpace, search_space_id, user)
    except HTTPException:
        raise HTTPException(
            status_code=403, detail="You don't have access to this search space")

    response = StreamingResponse(stream_connector_search_results(
        user_query,
        user.id,
        search_space_id,
        session,
        research_mode,
        selected_connectors
    ))
    response.headers['x-vercel-ai-data-stream'] = 'v1'
    return response
|
||||
|
||||
|
||||
@router.post("/chats/", response_model=ChatRead)
async def create_chat(
    chat: ChatCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Create a chat in a search space that belongs to the current user.

    Raises:
        HTTPException: whatever ``check_ownership`` raises is propagated as is;
            400 on ``IntegrityError``, 503 on ``OperationalError``, 500 on any
            other error. The session is rolled back in the last three cases.
    """
    try:
        await check_ownership(session, SearchSpace, chat.search_space_id, user)
        new_chat = Chat(**chat.model_dump())
        session.add(new_chat)
        await session.commit()
        # Reload so database-generated fields are present on the response.
        await session.refresh(new_chat)
        return new_chat
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        )
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503,
            detail="Database operation failed. Please try again later.",
        )
    except Exception:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while creating the chat.",
        )
|
||||
|
||||
|
||||
@router.get("/chats/", response_model=List[ChatRead])
async def read_chats(
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Return one page of the current user's chats, ordered by chat id.

    Args:
        skip: Number of chats to skip (OFFSET).
        limit: Maximum number of chats to return (LIMIT).

    Raises:
        HTTPException: 503 on ``OperationalError``, 500 on any other error.
    """
    try:
        result = await session.execute(
            select(Chat)
            .join(SearchSpace)
            .filter(SearchSpace.user_id == user.id)
            # OFFSET/LIMIT without ORDER BY has no guaranteed row order, so
            # pages could overlap or skip rows; order by id (the same order
            # SearchSpace.chats uses) to make pagination stable.
            .order_by(Chat.id)
            .offset(skip)
            .limit(limit)
        )
        return result.scalars().all()
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later.")
    except Exception:
        raise HTTPException(
            status_code=500, detail="An unexpected error occurred while fetching chats.")
|
||||
|
||||
|
||||
@router.get("/chats/{chat_id}", response_model=ChatRead)
async def read_chat(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Return the chat with ``chat_id`` if it lives in one of the user's search spaces.

    Raises:
        HTTPException: 404 if no such chat is visible to the user, 503 on
            ``OperationalError``, 500 on any other error.
    """
    try:
        result = await session.execute(
            select(Chat)
            .join(SearchSpace)
            .filter(Chat.id == chat_id, SearchSpace.user_id == user.id)
        )
        chat = result.scalars().first()
        if not chat:
            raise HTTPException(
                status_code=404, detail="Chat not found or you don't have permission to access it")
        return chat
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the 404
        # above would be caught by the generic handler and reported as a 500.
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later.")
    except Exception:
        raise HTTPException(
            status_code=500, detail="An unexpected error occurred while fetching the chat.")
|
||||
|
||||
|
||||
@router.put("/chats/{chat_id}", response_model=ChatRead)
async def update_chat(
    chat_id: int,
    chat_update: ChatUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Apply the fields set in ``chat_update`` to one of the current user's chats.

    The chat is loaded through ``read_chat``, so its lookup errors propagate
    unchanged. Only fields explicitly present in the request body are written.

    Raises:
        HTTPException: errors from ``read_chat`` / ``check_ownership`` are
            propagated; 400 on ``IntegrityError``, 503 on ``OperationalError``,
            500 on any other error. The session is rolled back in the last
            three cases.
    """
    try:
        db_chat = await read_chat(chat_id, session, user)
        update_data = chat_update.model_dump(exclude_unset=True)

        # Every setattr below is driven by the request body. If the body
        # re-parents the chat, the target search space must also belong to the
        # caller, otherwise a chat could be moved into someone else's space.
        # NOTE(review): ChatUpdate is declared elsewhere; this is a no-op when
        # it has no search_space_id field.
        target_space_id = update_data.get("search_space_id")
        if target_space_id is not None and target_space_id != db_chat.search_space_id:
            await check_ownership(session, SearchSpace, target_space_id, user)

        for key, value in update_data.items():
            setattr(db_chat, key, value)
        await session.commit()
        await session.refresh(db_chat)
        return db_chat
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400, detail="Database constraint violation. Please check your input data.")
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later.")
    except Exception:
        await session.rollback()
        raise HTTPException(
            status_code=500, detail="An unexpected error occurred while updating the chat.")
|
||||
|
||||
|
||||
@router.delete("/chats/{chat_id}", response_model=dict)
async def delete_chat(
    chat_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Delete one of the current user's chats and confirm with a message.

    The chat is loaded through ``read_chat``, so its lookup errors propagate
    unchanged.

    Raises:
        HTTPException: errors from ``read_chat`` are propagated; 400 on
            ``IntegrityError``, 503 on ``OperationalError``, 500 on any other
            error. The session is rolled back in the last three cases.
    """
    try:
        chat_to_delete = await read_chat(chat_id, session, user)
        await session.delete(chat_to_delete)
        await session.commit()
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Cannot delete chat due to existing dependencies.",
        )
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503,
            detail="Database operation failed. Please try again later.",
        )
    except Exception:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while deleting the chat.",
        )
    else:
        return {"message": "Chat deleted successfully"}
|
||||
|
||||
|
||||
# test_data = [
|
||||
# {
|
||||
# "type": "TERMINAL_INFO",
|
||||
# "content": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "text": "Starting to search for crawled URLs...",
|
||||
# "type": "info"
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "text": "Found 2 relevant crawled URLs",
|
||||
# "type": "success"
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "type": "SOURCES",
|
||||
# "content": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "name": "Crawled URLs",
|
||||
# "type": "CRAWLED_URL",
|
||||
# "sources": [
|
||||
# {
|
||||
# "id": 1,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://jsoneditoronline.org/"
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://www.google.com/"
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "id": 2,
|
||||
# "name": "Files",
|
||||
# "type": "FILE",
|
||||
# "sources": [
|
||||
# {
|
||||
# "id": 3,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://jsoneditoronline.org/"
|
||||
# },
|
||||
# {
|
||||
# "id": 4,
|
||||
# "title": "Webpage Title",
|
||||
# "description": "Webpage Dec",
|
||||
# "url": "https://www.google.com/"
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
# },
|
||||
# {
|
||||
# "type": "ANSWER",
|
||||
# "content": [
|
||||
# "## SurfSense Introduction",
|
||||
# "Surfsense is A Personal NotebookLM and Perplexity-like AI Assistant for Everyone. Research and Never forget Anything. [1] [3]"
|
||||
# ]
|
||||
# }
|
||||
# ]
|
||||
262
surfsense_backend/app/routes/documents_routes.py
Normal file
262
surfsense_backend/app/routes/documents_routes.py
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, Form, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace, Document, DocumentType
|
||||
from app.schemas import DocumentsCreate, DocumentUpdate, DocumentRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from app.tasks.background_tasks import add_extension_received_document, add_received_file_document, add_crawled_url_document
|
||||
from langchain_unstructured import UnstructuredLoader
|
||||
from app.config import config
|
||||
import json
|
||||
|
||||
# Router that the document endpoints below register themselves on.
router = APIRouter()
|
||||
|
||||
@router.post("/documents/")
async def create_documents(
    request: DocumentsCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
    fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Schedule background ingestion for every item in ``request.content``.

    EXTENSION payloads are handed to ``add_extension_received_document`` and
    CRAWLED_URL payloads to ``add_crawled_url_document``; one background task
    is queued per item.

    Raises:
        HTTPException: 400 when the document type is neither EXTENSION nor
            CRAWLED_URL; 500 (after a rollback) for any other failure.
            HTTPExceptions raised by ``check_ownership`` propagate unchanged.
    """
    try:
        # The caller must own the target search space
        # (check_ownership is defined elsewhere; presumably it raises on failure).
        await check_ownership(session, SearchSpace, request.search_space_id, user)

        # Pick the ingestion task once, then fan out over the payload.
        if request.document_type == DocumentType.EXTENSION:
            ingest_task = add_extension_received_document
        elif request.document_type == DocumentType.CRAWLED_URL:
            ingest_task = add_crawled_url_document
        else:
            raise HTTPException(
                status_code=400,
                detail="Invalid document type"
            )

        for item in request.content:
            fastapi_background_tasks.add_task(
                ingest_task,
                session,
                item,
                request.search_space_id
            )

        await session.commit()
        return {"message": "Documents processed successfully"}
    except HTTPException:
        # Already a well-formed HTTP error; let FastAPI render it as-is.
        raise
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process documents: {exc!s}"
        )
|
||||
|
||||
@router.post("/documents/fileupload")
async def create_documents(
    files: list[UploadFile],
    search_space_id: int = Form(...),
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
    fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Partition uploaded files via the Unstructured API and queue them for storage.

    Each file is loaded with ``UnstructuredLoader`` and the resulting elements
    are handed to ``add_received_file_document`` as a background task.

    NOTE(review): this function reuses the name ``create_documents`` from the
    ``POST /documents/`` handler defined earlier in the module, so the module
    attribute ends up bound to this one. Both routes are still registered by
    their decorators.

    Raises:
        HTTPException: 400 when no files are supplied; 422 when loading a file
            fails; 500 (after a rollback) for any other failure.
    """
    try:
        await check_ownership(session, SearchSpace, search_space_id, user)

        if not files:
            raise HTTPException(status_code=400, detail="No files provided")

        for upload in files:
            try:
                loader = UnstructuredLoader(
                    file=upload.file,
                    api_key=config.UNSTRUCTURED_API_KEY,
                    partition_via_api=True,
                    languages=["eng"],
                    include_orig_elements=False,
                    strategy="fast",
                )
                elements = await loader.aload()

                fastapi_background_tasks.add_task(
                    add_received_file_document,
                    session,
                    upload.filename,
                    elements,
                    search_space_id
                )
            except Exception as exc:
                # Any per-file failure aborts the whole request with a 422.
                raise HTTPException(
                    status_code=422,
                    detail=f"Failed to process file {upload.filename}: {exc!s}"
                )

        await session.commit()
        return {"message": "Files added for processing successfully"}
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process documents: {exc!s}"
        )
|
||||
|
||||
@router.get("/documents/", response_model=List[DocumentRead])
async def read_documents(
    skip: int = 0,
    limit: int = 300,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Return a page of documents from search spaces owned by the current user.

    Args:
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Raises:
        HTTPException: 500 when the query or the conversion fails.
    """
    try:
        query = (
            select(Document)
            .join(SearchSpace)
            .filter(SearchSpace.user_id == user.id)
            .offset(skip)
            .limit(limit)
        )
        rows = (await session.execute(query)).scalars().all()

        # Map ORM rows onto the response schema explicitly.
        return [
            DocumentRead(
                id=row.id,
                title=row.title,
                document_type=row.document_type,
                document_metadata=row.document_metadata,
                content=row.content,
                created_at=row.created_at,
                search_space_id=row.search_space_id
            )
            for row in rows
        ]
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch documents: {exc!s}"
        )
|
||||
|
||||
@router.get("/documents/{document_id}", response_model=DocumentRead)
async def read_document(
    document_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Fetch one document, restricted to search spaces owned by the current user.

    Args:
        document_id: Primary key of the document to fetch.

    Returns:
        The document converted to ``DocumentRead``.

    Raises:
        HTTPException: 404 when no matching document is visible to the user;
            500 for any other failure.
    """
    try:
        result = await session.execute(
            select(Document)
            .join(SearchSpace)
            .filter(Document.id == document_id, SearchSpace.user_id == user.id)
        )
        document = result.scalars().first()

        if not document:
            raise HTTPException(
                status_code=404,
                detail=f"Document with id {document_id} not found"
            )

        # Convert database object to API-friendly format
        return DocumentRead(
            id=document.id,
            title=document.title,
            document_type=document.document_type,
            document_metadata=document.document_metadata,
            content=document.content,
            created_at=document.created_at,
            search_space_id=document.search_space_id
        )
    except HTTPException:
        # Without this pass-through the 404 above is caught by the generic
        # handler below and reported to the client as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch document: {str(e)}"
        )
|
||||
|
||||
@router.put("/documents/{document_id}", response_model=DocumentRead)
async def update_document(
    document_id: int,
    document_update: DocumentUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Apply the explicitly-set fields of ``document_update`` to a document.

    The document is looked up through a join on the user's search spaces, so
    documents in other users' spaces are reported as not found.

    Raises:
        HTTPException: 404 when the document is not visible to the user;
            500 (after a rollback) for any other failure.
    """
    try:
        # Query the row directly (not via read_document) so we keep the ORM object.
        lookup = (
            select(Document)
            .join(SearchSpace)
            .filter(Document.id == document_id, SearchSpace.user_id == user.id)
        )
        target = (await session.execute(lookup)).scalars().first()

        if not target:
            raise HTTPException(
                status_code=404,
                detail=f"Document with id {document_id} not found"
            )

        # Only fields the client actually sent are written.
        changes = document_update.model_dump(exclude_unset=True)
        for field_name, new_value in changes.items():
            setattr(target, field_name, new_value)
        await session.commit()
        await session.refresh(target)

        return DocumentRead(
            id=target.id,
            title=target.title,
            document_type=target.document_type,
            document_metadata=target.document_metadata,
            content=target.content,
            created_at=target.created_at,
            search_space_id=target.search_space_id
        )
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update document: {exc!s}"
        )
|
||||
|
||||
@router.delete("/documents/{document_id}", response_model=dict)
async def delete_document(
    document_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Delete a document that lives in one of the current user's search spaces.

    Raises:
        HTTPException: 404 when the document is not visible to the user;
            500 (after a rollback) for any other failure.
    """
    try:
        lookup = (
            select(Document)
            .join(SearchSpace)
            .filter(Document.id == document_id, SearchSpace.user_id == user.id)
        )
        doomed = (await session.execute(lookup)).scalars().first()

        if not doomed:
            raise HTTPException(
                status_code=404,
                detail=f"Document with id {document_id} not found"
            )

        await session.delete(doomed)
        await session.commit()
        return {"message": "Document deleted successfully"}
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete document: {exc!s}"
        )
|
||||
122
surfsense_backend/app/routes/podcasts_routes.py
Normal file
122
surfsense_backend/app/routes/podcasts_routes.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace, Podcast
|
||||
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
|
||||
# Router that the podcast endpoints below register themselves on.
router = APIRouter()
|
||||
|
||||
@router.post("/podcasts/", response_model=PodcastRead)
async def create_podcast(
    podcast: PodcastCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Persist a new podcast in a search space after an ownership check.

    Returns:
        The newly created row, refreshed from the database.

    Raises:
        HTTPException: 400 on an integrity (constraint) error; 500 on any
            other database or unexpected error. HTTPExceptions raised by
            ``check_ownership`` propagate unchanged.
    """
    try:
        await check_ownership(session, SearchSpace, podcast.search_space_id, user)

        new_row = Podcast(**podcast.model_dump())
        session.add(new_row)
        await session.commit()
        await session.refresh(new_row)
        return new_row
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(status_code=400, detail="Podcast creation failed due to constraint violation")
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(status_code=500, detail="Database error occurred while creating podcast")
    except Exception:
        await session.rollback()
        raise HTTPException(status_code=500, detail="An unexpected error occurred")
|
||||
|
||||
@router.get("/podcasts/", response_model=List[PodcastRead])
async def read_podcasts(
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Return a page of podcasts from search spaces owned by the current user.

    Args:
        skip: Number of rows to skip; must be >= 0.
        limit: Maximum number of rows to return; must be >= 1.

    Raises:
        HTTPException: 400 for invalid pagination values; 500 on a database
            error.
    """
    # Validate pagination up front, outside the database try-block.
    if skip < 0 or limit < 1:
        raise HTTPException(status_code=400, detail="Invalid pagination parameters")

    try:
        page_query = (
            select(Podcast)
            .join(SearchSpace)
            .filter(SearchSpace.user_id == user.id)
            .offset(skip)
            .limit(limit)
        )
        page = await session.execute(page_query)
        return page.scalars().all()
    except SQLAlchemyError:
        raise HTTPException(status_code=500, detail="Database error occurred while fetching podcasts")
|
||||
|
||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
async def read_podcast(
    podcast_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Fetch one podcast, restricted to search spaces owned by the current user.

    Raises:
        HTTPException: 404 when no matching podcast is visible to the user;
            500 on a database error.
    """
    try:
        lookup = (
            select(Podcast)
            .join(SearchSpace)
            .filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id)
        )
        found = (await session.execute(lookup)).scalars().first()
        if found:
            return found

        # Missing and not-owned are deliberately indistinguishable to the caller.
        raise HTTPException(
            status_code=404,
            detail="Podcast not found or you don't have permission to access it"
        )
    except HTTPException:
        raise
    except SQLAlchemyError:
        raise HTTPException(status_code=500, detail="Database error occurred while fetching podcast")
|
||||
|
||||
@router.put("/podcasts/{podcast_id}", response_model=PodcastRead)
async def update_podcast(
    podcast_id: int,
    podcast_update: PodcastUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Apply the explicitly-set fields of ``podcast_update`` to a podcast.

    The row is obtained through ``read_podcast``, so its 404 handling applies
    here as well.

    Raises:
        HTTPException: 400 on an integrity (constraint) error; 500 on any
            other database error; whatever ``read_podcast`` raises.
    """
    try:
        target = await read_podcast(podcast_id, session, user)

        # Only fields the client actually sent are written.
        changes = podcast_update.model_dump(exclude_unset=True)
        for field_name, new_value in changes.items():
            setattr(target, field_name, new_value)

        await session.commit()
        await session.refresh(target)
        return target
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(status_code=400, detail="Update failed due to constraint violation")
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(status_code=500, detail="Database error occurred while updating podcast")
|
||||
|
||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
async def delete_podcast(
    podcast_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Delete a podcast located through ``read_podcast``.

    Raises:
        HTTPException: 500 on a database error; whatever ``read_podcast``
            raises (including its 404).
    """
    try:
        doomed = await read_podcast(podcast_id, session, user)
        await session.delete(doomed)
        await session.commit()
        return {"message": "Podcast deleted successfully"}
    except HTTPException:
        raise
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")
|
||||
418
surfsense_backend/app/routes/search_source_connectors_routes.py
Normal file
418
surfsense_backend/app/routes/search_source_connectors_routes.py
Normal file
|
|
@ -0,0 +1,418 @@
|
|||
"""
|
||||
SearchSourceConnector routes for CRUD operations:
|
||||
POST /search-source-connectors/ - Create a new connector
|
||||
GET /search-source-connectors/ - List all connectors for the current user
|
||||
GET /search-source-connectors/{connector_id} - Get a specific connector
|
||||
PUT /search-source-connectors/{connector_id} - Update a specific connector
|
||||
DELETE /search-source-connectors/{connector_id} - Delete a specific connector
|
||||
POST /search-source-connectors/{connector_id}/index - Index content from a connector to a search space
|
||||
|
||||
Note: Each user can have only one connector of each type (SERPER_API, TAVILY_API, SLACK_CONNECTOR, NOTION_CONNECTOR).
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from typing import List, Dict, Any
|
||||
from app.db import get_async_session, User, SearchSourceConnector, SearchSourceConnectorType, SearchSpace
|
||||
from app.schemas import SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from pydantic import ValidationError
|
||||
from app.tasks.connectors_indexing_tasks import index_slack_messages, index_notion_pages
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Router that the search-source-connector endpoints below register themselves on.
router = APIRouter()
|
||||
|
||||
@router.post("/search-source-connectors/", response_model=SearchSourceConnectorRead)
async def create_search_source_connector(
    connector: SearchSourceConnectorCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """
    Create a search source connector for the current user.

    A user may hold at most one connector per connector type; a second one of
    the same type is rejected with 409 before any insert is attempted.

    Raises:
        HTTPException: 409 when a connector of this type already exists (or
            the insert hits an integrity error); 422 on a pydantic
            validation error; 500 for any other failure. The session is
            rolled back in every error path.
    """
    try:
        # Look for an existing connector of the same type for this user.
        duplicate_query = select(SearchSourceConnector).filter(
            SearchSourceConnector.user_id == user.id,
            SearchSourceConnector.connector_type == connector.connector_type
        )
        duplicate = (await session.execute(duplicate_query)).scalars().first()

        if duplicate:
            raise HTTPException(
                status_code=409,
                detail=f"A connector with type {connector.connector_type} already exists. Each user can have only one connector of each type."
            )

        new_row = SearchSourceConnector(**connector.model_dump(), user_id=user.id)
        session.add(new_row)
        await session.commit()
        await session.refresh(new_row)
        return new_row
    except ValidationError as exc:
        await session.rollback()
        raise HTTPException(
            status_code=422,
            detail=f"Validation error: {exc!s}"
        )
    except IntegrityError as exc:
        await session.rollback()
        raise HTTPException(
            status_code=409,
            detail=f"Integrity error: A connector with this type already exists. {exc!s}"
        )
    except HTTPException:
        await session.rollback()
        raise
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create search source connector: {exc!s}"
        )
|
||||
|
||||
@router.get("/search-source-connectors/", response_model=List[SearchSourceConnectorRead])
async def read_search_source_connectors(
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Return a page of the current user's search source connectors.

    Args:
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Raises:
        HTTPException: 500 when the query fails.
    """
    try:
        page_query = (
            select(SearchSourceConnector)
            .filter(SearchSourceConnector.user_id == user.id)
            .offset(skip)
            .limit(limit)
        )
        page = await session.execute(page_query)
        return page.scalars().all()
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch search source connectors: {exc!s}"
        )
|
||||
|
||||
@router.get("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
async def read_search_source_connector(
    connector_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Return a single connector, as resolved by ``check_ownership``.

    Raises:
        HTTPException: whatever ``check_ownership`` raises, unchanged; 500
            for any other failure.
    """
    try:
        owned_connector = await check_ownership(session, SearchSourceConnector, connector_id, user)
        return owned_connector
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch search source connector: {exc!s}"
        )
|
||||
|
||||
@router.put("/search-source-connectors/{connector_id}", response_model=SearchSourceConnectorRead)
async def update_search_source_connector(
    connector_id: int,
    connector_update: SearchSourceConnectorUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """
    Update one of the current user's search source connectors.

    When the requested connector type differs from the stored one, the update
    is rejected with 409 if the user already has another connector of the
    requested type. Only fields explicitly set on ``connector_update`` are
    written.

    Raises:
        HTTPException: 409 on a type clash or integrity error; 422 on a
            pydantic validation error; 500 for any other failure. The
            session is rolled back in every error path.
    """
    try:
        target = await check_ownership(session, SearchSourceConnector, connector_id, user)

        # A type change must not collide with another connector of this user.
        requested_type = connector_update.connector_type
        if requested_type != target.connector_type:
            clash_query = select(SearchSourceConnector).filter(
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.connector_type == requested_type,
                SearchSourceConnector.id != connector_id
            )
            clash = (await session.execute(clash_query)).scalars().first()

            if clash:
                raise HTTPException(
                    status_code=409,
                    detail=f"A connector with type {requested_type} already exists. Each user can have only one connector of each type."
                )

        changes = connector_update.model_dump(exclude_unset=True)
        for field_name, new_value in changes.items():
            setattr(target, field_name, new_value)
        await session.commit()
        await session.refresh(target)
        return target
    except ValidationError as exc:
        await session.rollback()
        raise HTTPException(
            status_code=422,
            detail=f"Validation error: {exc!s}"
        )
    except IntegrityError as exc:
        await session.rollback()
        raise HTTPException(
            status_code=409,
            detail=f"Integrity error: A connector with this type already exists. {exc!s}"
        )
    except HTTPException:
        await session.rollback()
        raise
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update search source connector: {exc!s}"
        )
|
||||
|
||||
@router.delete("/search-source-connectors/{connector_id}", response_model=dict)
async def delete_search_source_connector(
    connector_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Delete one of the current user's search source connectors.

    Raises:
        HTTPException: whatever ``check_ownership`` raises, unchanged; 500
            (after a rollback) for any other failure.
    """
    try:
        doomed = await check_ownership(session, SearchSourceConnector, connector_id, user)
        await session.delete(doomed)
        await session.commit()
        return {"message": "Search source connector deleted successfully"}
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete search source connector: {exc!s}"
        )
|
||||
|
||||
def _indexing_start_date(last_indexed_at) -> str:
    """Return the start of the window reported for the next indexing run.

    Args:
        last_indexed_at: The connector's last indexing timestamp, or a falsy
            value when it has never been indexed.

    Returns:
        ``"365 days ago"`` when never indexed; yesterday's date when the last
        run was today; otherwise the date of the last run (``YYYY-MM-DD``).
    """
    # Function-scope import: the module imports only the ``datetime`` class,
    # and ``datetime.timedelta`` on that class raises AttributeError.
    from datetime import timedelta

    if not last_indexed_at:
        return "365 days ago"

    today = datetime.now().date()
    if last_indexed_at.date() == today:
        # If last indexed today, go back 1 day to ensure we don't miss anything
        return (today - timedelta(days=1)).strftime("%Y-%m-%d")
    return last_indexed_at.strftime("%Y-%m-%d")


@router.post("/search-source-connectors/{connector_id}/index", response_model=Dict[str, Any])
async def index_connector_content(
    connector_id: int,
    search_space_id: int = Query(..., description="ID of the search space to store indexed content"),
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
    background_tasks: BackgroundTasks = None
):
    """
    Index content from a connector to a search space.

    Currently supports:
    - SLACK_CONNECTOR: Indexes messages from all accessible Slack channels since the last indexing
      (or the last 365 days if never indexed before)
    - NOTION_CONNECTOR: Indexes pages from all accessible Notion pages since the last indexing
      (or the last 365 days if never indexed before)

    Args:
        connector_id: ID of the connector to use
        search_space_id: ID of the search space to store indexed content
        background_tasks: FastAPI background tasks

    Returns:
        Dictionary with indexing status

    Raises:
        HTTPException: 400 for a connector type without indexing support;
            500 for any unexpected failure. HTTPExceptions raised by
            ``check_ownership`` propagate unchanged.
    """
    try:
        # Check if the connector belongs to the user
        connector = await check_ownership(session, SearchSourceConnector, connector_id, user)

        # Check if the search space belongs to the user
        search_space = await check_ownership(session, SearchSpace, search_space_id, user)

        # Resolve the background task here, at call time, because the
        # run_* helpers are defined further down in this module.
        if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
            indexing_task, source_label = run_slack_indexing, "Slack"
        elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
            indexing_task, source_label = run_notion_indexing, "Notion"
        else:
            raise HTTPException(
                status_code=400,
                detail=f"Indexing not supported for connector type: {connector.connector_type}"
            )

        # Determine the time range that will be indexed
        start_date = _indexing_start_date(connector.last_indexed_at)

        if not background_tasks:
            # For testing or if background tasks are not available
            return {
                "success": False,
                "message": "Background tasks not available",
                "connector_type": connector.connector_type
            }

        # Add the indexing task to background tasks
        background_tasks.add_task(
            indexing_task,
            session,
            connector_id,
            search_space_id
        )

        return {
            "success": True,
            "message": f"{source_label} indexing started in the background",
            "connector_type": connector.connector_type,
            "search_space": search_space.name,
            "indexing_from": start_date,
            "indexing_to": datetime.now().strftime("%Y-%m-%d")
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to start indexing: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start indexing: {str(e)}"
        )
|
||||
|
||||
|
||||
async def update_connector_last_indexed(
    session: AsyncSession,
    connector_id: int
):
    """
    Stamp a connector's ``last_indexed_at`` with the current local time.

    Does nothing when no connector has the given id. Failures are logged and
    the session is rolled back; no exception is propagated to the caller.

    Args:
        session: Database session
        connector_id: ID of the connector to update
    """
    try:
        lookup = select(SearchSourceConnector).filter(SearchSourceConnector.id == connector_id)
        row = (await session.execute(lookup)).scalars().first()

        if row:
            row.last_indexed_at = datetime.now()
            await session.commit()
            logger.info(f"Updated last_indexed_at for connector {connector_id}")
    except Exception as exc:
        logger.error(f"Failed to update last_indexed_at for connector {connector_id}: {exc!s}")
        await session.rollback()
|
||||
|
||||
async def run_slack_indexing(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int
):
    """
    Background task that indexes Slack messages and then records the run.

    ``last_indexed_at`` is advanced only when at least one document was
    indexed and the second return value of ``index_slack_messages`` is either
    ``None`` or contains the text "Indexed". All exceptions are logged and
    swallowed.

    Args:
        session: Database session
        connector_id: ID of the Slack connector
        search_space_id: ID of the search space
    """
    try:
        # The indexer is told not to touch last_indexed_at; that happens below.
        indexed_count, note = await index_slack_messages(
            session=session,
            connector_id=connector_id,
            search_space_id=search_space_id,
            update_last_indexed=False
        )

        succeeded = indexed_count > 0 and (note is None or "Indexed" in note)
        if not succeeded:
            logger.error(f"Slack indexing failed or no documents indexed: {note}")
            return

        await update_connector_last_indexed(session, connector_id)
        logger.info(f"Slack indexing completed successfully: {indexed_count} documents indexed")
    except Exception as exc:
        logger.error(f"Error in background Slack indexing task: {exc!s}")
|
||||
|
||||
async def run_notion_indexing(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int
):
    """
    Background task that indexes Notion pages and then records the run.

    ``last_indexed_at`` is advanced only when at least one document was
    indexed and the second return value of ``index_notion_pages`` is either
    ``None`` or contains the text "Indexed". All exceptions are logged and
    swallowed.

    Args:
        session: Database session
        connector_id: ID of the Notion connector
        search_space_id: ID of the search space
    """
    try:
        # The indexer is told not to touch last_indexed_at; that happens below.
        indexed_count, note = await index_notion_pages(
            session=session,
            connector_id=connector_id,
            search_space_id=search_space_id,
            update_last_indexed=False
        )

        succeeded = indexed_count > 0 and (note is None or "Indexed" in note)
        if not succeeded:
            logger.error(f"Notion indexing failed or no documents indexed: {note}")
            return

        await update_connector_last_indexed(session, connector_id)
        logger.info(f"Notion indexing completed successfully: {indexed_count} documents indexed")
    except Exception as exc:
        logger.error(f"Error in background Notion indexing task: {exc!s}")
|
||||
115
surfsense_backend/app/routes/search_spaces_routes.py
Normal file
115
surfsense_backend/app/routes/search_spaces_routes.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from typing import List
|
||||
from app.db import get_async_session, User, SearchSpace
|
||||
from app.schemas import SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.check_ownership import check_ownership
|
||||
from fastapi import HTTPException
|
||||
|
||||
# Router that the search-space endpoints below register themselves on.
router = APIRouter()
|
||||
|
||||
@router.post("/searchspaces/", response_model=SearchSpaceRead)
async def create_search_space(
    search_space: SearchSpaceCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Create a search space owned by the current user.

    Returns:
        The newly created row, refreshed from the database.

    Raises:
        HTTPException: 500 (after a rollback) when creation fails.
    """
    try:
        new_space = SearchSpace(**search_space.model_dump(), user_id=user.id)
        session.add(new_space)
        await session.commit()
        await session.refresh(new_space)
        return new_space
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create search space: {exc!s}"
        )
|
||||
|
||||
@router.get("/searchspaces/", response_model=List[SearchSpaceRead])
async def read_search_spaces(
    skip: int = 0,
    limit: int = 200,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Return a page of search spaces owned by the current user.

    Args:
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Raises:
        HTTPException: 500 when the query fails.
    """
    try:
        page_query = (
            select(SearchSpace)
            .filter(SearchSpace.user_id == user.id)
            .offset(skip)
            .limit(limit)
        )
        page = await session.execute(page_query)
        return page.scalars().all()
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch search spaces: {exc!s}"
        )
|
||||
|
||||
@router.get("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def read_search_space(
    search_space_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Return a single search space, as resolved by ``check_ownership``.

    Raises:
        HTTPException: whatever ``check_ownership`` raises, unchanged; 500
            for any other failure.
    """
    try:
        return await check_ownership(session, SearchSpace, search_space_id, user)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to fetch search space: {exc!s}"
        )
|
||||
|
||||
@router.put("/searchspaces/{search_space_id}", response_model=SearchSpaceRead)
async def update_search_space(
    search_space_id: int,
    search_space_update: SearchSpaceUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Apply the fields set in the request to an owned search space.

    Only fields explicitly present in the payload (`exclude_unset=True`) are
    copied onto the row, which is then committed and refreshed.

    Raises:
        HTTPException: re-raised unchanged when raised by `check_ownership`;
            500 (after a rollback) for any other failure.
    """
    try:
        target = await check_ownership(session, SearchSpace, search_space_id, user)
        changes = search_space_update.model_dump(exclude_unset=True)
        for field_name in changes:
            setattr(target, field_name, changes[field_name])
        await session.commit()
        await session.refresh(target)
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        message = f"Failed to update search space: {str(exc)}"
        raise HTTPException(status_code=500, detail=message)
    else:
        return target
|
||||
|
||||
@router.delete("/searchspaces/{search_space_id}", response_model=dict)
async def delete_search_space(
    search_space_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Delete a search space after `check_ownership` has accepted it.

    Returns:
        A dict with a single `message` key confirming the deletion.

    Raises:
        HTTPException: re-raised unchanged when raised by `check_ownership`;
            500 (after a rollback) for any other failure.
    """
    try:
        doomed = await check_ownership(session, SearchSpace, search_space_id, user)
        await session.delete(doomed)
        await session.commit()
    except HTTPException:
        raise
    except Exception as exc:
        await session.rollback()
        message = f"Failed to delete search space: {str(exc)}"
        raise HTTPException(status_code=500, detail=message)
    else:
        return {"message": "Search space deleted successfully"}
|
||||
50
surfsense_backend/app/schemas/__init__.py
Normal file
50
surfsense_backend/app/schemas/__init__.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from .base import TimestampModel, IDModel
|
||||
from .users import UserRead, UserCreate, UserUpdate
|
||||
from .search_space import SearchSpaceBase, SearchSpaceCreate, SearchSpaceUpdate, SearchSpaceRead
|
||||
from .documents import (
|
||||
ExtensionDocumentMetadata,
|
||||
ExtensionDocumentContent,
|
||||
DocumentBase,
|
||||
DocumentsCreate,
|
||||
DocumentUpdate,
|
||||
DocumentRead,
|
||||
)
|
||||
from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
|
||||
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead
|
||||
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
|
||||
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
|
||||
|
||||
__all__ = [
|
||||
"AISDKChatRequest",
|
||||
"TimestampModel",
|
||||
"IDModel",
|
||||
"UserRead",
|
||||
"UserCreate",
|
||||
"UserUpdate",
|
||||
"SearchSpaceBase",
|
||||
"SearchSpaceCreate",
|
||||
"SearchSpaceUpdate",
|
||||
"SearchSpaceRead",
|
||||
"ExtensionDocumentMetadata",
|
||||
"ExtensionDocumentContent",
|
||||
"DocumentBase",
|
||||
"DocumentsCreate",
|
||||
"DocumentUpdate",
|
||||
"DocumentRead",
|
||||
"ChunkBase",
|
||||
"ChunkCreate",
|
||||
"ChunkUpdate",
|
||||
"ChunkRead",
|
||||
"PodcastBase",
|
||||
"PodcastCreate",
|
||||
"PodcastUpdate",
|
||||
"PodcastRead",
|
||||
"ChatBase",
|
||||
"ChatCreate",
|
||||
"ChatUpdate",
|
||||
"ChatRead",
|
||||
"SearchSourceConnectorBase",
|
||||
"SearchSourceConnectorCreate",
|
||||
"SearchSourceConnectorUpdate",
|
||||
"SearchSourceConnectorRead",
|
||||
]
|
||||
8
surfsense_backend/app/schemas/base.py
Normal file
8
surfsense_backend/app/schemas/base.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TimestampModel(BaseModel):
    """Schema mixin contributing a required `created_at` datetime field."""
    created_at: datetime
|
||||
|
||||
class IDModel(BaseModel):
    """Schema mixin contributing a required integer `id` field."""
    id: int
|
||||
46
surfsense_backend/app/schemas/chats.py
Normal file
46
surfsense_backend/app/schemas/chats.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON
|
||||
from .base import IDModel, TimestampModel
|
||||
from app.db import ChatType
|
||||
|
||||
class ChatBase(BaseModel):
    """Fields common to the chat create, update and read schemas."""
    type: ChatType
    title: str
    # Optional; defaults to None when the client omits it.
    initial_connectors: Optional[List[str]] = None
    # Items are typed as Any, so their structure is not validated here.
    messages: List[Any]
    search_space_id: int
|
||||
|
||||
|
||||
class ClientAttachment(BaseModel):
    """Attachment entry carried in `ClientMessage.experimental_attachments`."""
    name: str
    contentType: str
    url: str
|
||||
|
||||
|
||||
class ToolInvocation(BaseModel):
    """Tool-call record carried in `ClientMessage.toolInvocations`.

    All four fields are required, including `result`.
    """
    toolCallId: str
    toolName: str
    args: dict
    result: dict
|
||||
|
||||
|
||||
class ClientMessage(BaseModel):
    """One message inside an `AISDKChatRequest`."""
    role: str
    content: str
    # Both lists are optional and default to None.
    experimental_attachments: Optional[List[ClientAttachment]] = None
    toolInvocations: Optional[List[ToolInvocation]] = None
|
||||
|
||||
class AISDKChatRequest(BaseModel):
    """Chat request body: a list of client messages plus optional extra data."""
    messages: List[ClientMessage]
    # Free-form dictionary; its keys are not validated by this schema.
    data: Optional[Dict[str, Any]] = None
|
||||
|
||||
class ChatCreate(ChatBase):
    """Schema for creating a chat; adds nothing to `ChatBase`."""
    pass
|
||||
|
||||
class ChatUpdate(ChatBase):
    """Schema for updating a chat; adds nothing to `ChatBase`.

    Because it inherits `ChatBase` unchanged, every required base field must
    be supplied on update as well.
    """
    pass
|
||||
|
||||
class ChatRead(ChatBase, IDModel, TimestampModel):
    """Chat schema for responses: `ChatBase` fields plus `id` and `created_at`."""
    # Pydantic v2 deprecates the nested `class Config`; `model_config` is the
    # supported spelling. `from_attributes` allows validating from objects
    # that expose the fields as attributes rather than as dict keys.
    model_config = {"from_attributes": True}
|
||||
16
surfsense_backend/app/schemas/chunks.py
Normal file
16
surfsense_backend/app/schemas/chunks.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from pydantic import BaseModel
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
class ChunkBase(BaseModel):
    """Fields common to the chunk create, update and read schemas."""
    content: str
    document_id: int
|
||||
|
||||
class ChunkCreate(ChunkBase):
    """Schema for creating a chunk; adds nothing to `ChunkBase`."""
    pass
|
||||
|
||||
class ChunkUpdate(ChunkBase):
    """Schema for updating a chunk; adds nothing to `ChunkBase`."""
    pass
|
||||
|
||||
class ChunkRead(ChunkBase, IDModel, TimestampModel):
    """Chunk schema for responses: `ChunkBase` fields plus `id` and `created_at`."""
    # Pydantic v2 deprecates the nested `class Config`; `model_config` is the
    # supported spelling. `from_attributes` allows validating from objects
    # that expose the fields as attributes rather than as dict keys.
    model_config = {"from_attributes": True}
|
||||
42
surfsense_backend/app/schemas/documents.py
Normal file
42
surfsense_backend/app/schemas/documents.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from typing import List, Any
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import JSON
|
||||
from .base import IDModel, TimestampModel
|
||||
from app.db import DocumentType
|
||||
from datetime import datetime
|
||||
|
||||
class ExtensionDocumentMetadata(BaseModel):
    """Metadata attached to one `ExtensionDocumentContent` item.

    Every field is a required string, including the timestamp and the
    visit duration.
    """
    BrowsingSessionId: str
    VisitedWebPageURL: str
    VisitedWebPageTitle: str
    VisitedWebPageDateWithTimeInISOString: str
    # NOTE(review): "Refferer" is misspelled, but the spelling is the field
    # name on the wire -- presumably clients send it this way; confirm before renaming.
    VisitedWebPageReffererURL: str
    VisitedWebPageVisitDurationInMilliseconds: str
|
||||
|
||||
class ExtensionDocumentContent(BaseModel):
    """One captured page: its metadata plus the page content as a string."""
    metadata: ExtensionDocumentMetadata
    pageContent: str
|
||||
|
||||
class DocumentBase(BaseModel):
    """Fields common to the document create and update schemas."""
    document_type: DocumentType
    # Accepts a list of extension captures, a list of strings, or one string.
    content: List[ExtensionDocumentContent] | List[str] | str  # Updated to allow string content
    search_space_id: int
|
||||
|
||||
class DocumentsCreate(DocumentBase):
    """Schema for creating documents; adds nothing to `DocumentBase`."""
    pass
|
||||
|
||||
class DocumentUpdate(DocumentBase):
    """Schema for updating a document; adds nothing to `DocumentBase`."""
    pass
|
||||
|
||||
class DocumentRead(BaseModel):
    """Document schema for responses.

    Unlike the create/update schemas this does not derive from
    `DocumentBase`: `content` is always a single string here.
    """
    id: int
    title: str
    document_type: DocumentType
    document_metadata: dict
    content: str  # Changed to string to match frontend
    created_at: datetime
    search_space_id: int

    # Pydantic v2 deprecates the nested `class Config`; `model_config` is the
    # supported spelling. `from_attributes` allows validating from objects
    # that expose the fields as attributes rather than as dict keys.
    model_config = {"from_attributes": True}
|
||||
|
||||
19
surfsense_backend/app/schemas/podcasts.py
Normal file
19
surfsense_backend/app/schemas/podcasts.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from pydantic import BaseModel
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
class PodcastBase(BaseModel):
    """Fields common to the podcast create, update and read schemas.

    Only `title` and `search_space_id` are required; the remaining fields
    fall back to the defaults shown.
    """
    title: str
    is_generated: bool = False
    podcast_content: str = ""
    file_location: str = ""
    search_space_id: int
|
||||
|
||||
class PodcastCreate(PodcastBase):
    """Schema for creating a podcast; adds nothing to `PodcastBase`."""
    pass
|
||||
|
||||
class PodcastUpdate(PodcastBase):
    """Schema for updating a podcast; adds nothing to `PodcastBase`."""
    pass
|
||||
|
||||
class PodcastRead(PodcastBase, IDModel, TimestampModel):
    """Podcast schema for responses: `PodcastBase` fields plus `id` and `created_at`."""
    # Pydantic v2 deprecates the nested `class Config`; `model_config` is the
    # supported spelling. `from_attributes` allows validating from objects
    # that expose the fields as attributes rather than as dict keys.
    model_config = {"from_attributes": True}
|
||||
73
surfsense_backend/app/schemas/search_source_connector.py
Normal file
73
surfsense_backend/app/schemas/search_source_connector.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel, field_validator
|
||||
from .base import IDModel, TimestampModel
|
||||
from app.db import SearchSourceConnectorType
|
||||
from fastapi import HTTPException
|
||||
|
||||
class SearchSourceConnectorBase(BaseModel):
    """Fields common to the search-source-connector schemas.

    `config` is validated against `connector_type`: for the four connector
    types handled below it must contain exactly one credential key with a
    truthy value.
    """
    name: str
    connector_type: SearchSourceConnectorType
    is_indexable: bool
    last_indexed_at: datetime | None
    config: Dict[str, Any]

    @field_validator('config')
    @classmethod
    def validate_config_for_connector_type(cls, config: Dict[str, Any], values: Dict[str, Any]) -> Dict[str, Any]:
        """Check that `config` holds exactly the credential its connector type needs.

        NOTE(review): `values` is annotated as a dict but is used through
        `values.data`, i.e. it is the validation-info object Pydantic passes
        to field validators; the annotation is left as found.

        Connector types without an entry in the table (or a missing
        `connector_type`) are accepted without any check.

        Raises:
            ValueError: if the key set differs from the single expected key,
                or if that key's value is falsy.
        """
        # connector type -> (label used in the error text, the one allowed key)
        requirements = {
            SearchSourceConnectorType.SERPER_API: ("SERPER_API", "SERPER_API_KEY"),
            SearchSourceConnectorType.TAVILY_API: ("TAVILY_API", "TAVILY_API_KEY"),
            SearchSourceConnectorType.SLACK_CONNECTOR: ("SLACK_CONNECTOR", "SLACK_BOT_TOKEN"),
            SearchSourceConnectorType.NOTION_CONNECTOR: ("NOTION_CONNECTOR", "NOTION_INTEGRATION_TOKEN"),
        }

        rule = requirements.get(values.data.get('connector_type'))
        if rule is None:
            return config

        label, required_key = rule
        allowed_keys = [required_key]

        # The config must contain the required key and nothing else.
        if set(config.keys()) != set(allowed_keys):
            raise ValueError(f"For {label} connector type, config must only contain these keys: {allowed_keys}")

        # Ensure the credential is not empty
        if not config.get(required_key):
            raise ValueError(f"{required_key} cannot be empty")

        return config
|
||||
|
||||
class SearchSourceConnectorCreate(SearchSourceConnectorBase):
    """Schema for creating a connector; adds nothing to the base schema."""
    pass
|
||||
|
||||
class SearchSourceConnectorUpdate(SearchSourceConnectorBase):
    """Schema for updating a connector; adds nothing to the base schema."""
    pass
|
||||
|
||||
class SearchSourceConnectorRead(SearchSourceConnectorBase, IDModel, TimestampModel):
    """Connector schema for responses: base fields plus `id`, `created_at` and `user_id`."""
    user_id: uuid.UUID

    # Pydantic v2 deprecates the nested `class Config`; `model_config` is the
    # supported spelling. `from_attributes` allows validating from objects
    # that expose the fields as attributes rather than as dict keys.
    model_config = {"from_attributes": True}
|
||||
23
surfsense_backend/app/schemas/search_space.py
Normal file
23
surfsense_backend/app/schemas/search_space.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from .base import IDModel, TimestampModel
|
||||
|
||||
class SearchSpaceBase(BaseModel):
    """Fields common to the search-space create, update and read schemas."""
    name: str
    # Optional; defaults to None when omitted.
    description: Optional[str] = None
|
||||
|
||||
class SearchSpaceCreate(SearchSpaceBase):
    """Schema for creating a search space; adds nothing to `SearchSpaceBase`."""
    pass
|
||||
|
||||
class SearchSpaceUpdate(SearchSpaceBase):
    """Schema for updating a search space; adds nothing to `SearchSpaceBase`.

    `name` therefore stays required on update.
    """
    pass
|
||||
|
||||
class SearchSpaceRead(SearchSpaceBase, IDModel, TimestampModel):
    """Search-space schema for responses, including the owning `user_id`."""
    # `id` and `created_at` are also declared by IDModel / TimestampModel;
    # they are re-declared here with the same types.
    id: int
    created_at: datetime
    user_id: uuid.UUID

    # Pydantic v2 deprecates the nested `class Config`; `model_config` is the
    # supported spelling. `from_attributes` allows validating from objects
    # that expose the fields as attributes rather than as dict keys.
    model_config = {"from_attributes": True}
|
||||
11
surfsense_backend/app/schemas/users.py
Normal file
11
surfsense_backend/app/schemas/users.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import uuid
|
||||
from fastapi_users import schemas
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
    """User read schema: fastapi-users' `BaseUser` with a UUID id, no extra fields."""
    pass
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
    """User creation schema: fastapi-users' `BaseUserCreate`, no extra fields."""
    pass
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
    """User update schema: fastapi-users' `BaseUserUpdate`, no extra fields."""
    pass
|
||||
0
surfsense_backend/app/tasks/__init__.py
Normal file
0
surfsense_backend/app/tasks/__init__.py
Normal file
246
surfsense_backend/app/tasks/background_tasks.py
Normal file
246
surfsense_backend/app/tasks/background_tasks.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
from typing import Optional, List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from app.db import Document, DocumentType, Chunk
|
||||
from app.schemas import ExtensionDocumentContent
|
||||
from app.config import config
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from datetime import datetime
|
||||
from app.utils.document_converters import convert_document_to_markdown
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from langchain_community.document_loaders import FireCrawlLoader, AsyncChromiumLoader
|
||||
from langchain_community.document_transformers import MarkdownifyTransformer
|
||||
import validators
|
||||
|
||||
md = MarkdownifyTransformer()
|
||||
|
||||
|
||||
async def add_crawled_url_document(
    session: AsyncSession,
    url: str,
    search_space_id: int
) -> Optional[Document]:
    """Crawl a URL, summarise and chunk its content, and store it as a Document.

    FireCrawl is used when `config.FIRECRAWL_API_KEY` is set; otherwise the
    page is loaded with headless Chromium and converted to markdown.

    Args:
        session: Database session used to persist the document.
        url: Address of the page to crawl.
        search_space_id: ID of the search space the document belongs to.

    Returns:
        The stored, refreshed Document.

    Raises:
        SQLAlchemyError: re-raised after a rollback on database failures.
        RuntimeError: for any other failure (invalid URL, crawl or LLM
            error), after a rollback; the original exception is chained.
    """
    try:
        if not validators.url(url):
            raise ValueError(f"Url {url} is not a valid URL address")

        # Decide the loader once; the same flag later selects how the content
        # and the title are extracted.
        use_firecrawl = bool(config.FIRECRAWL_API_KEY)
        if use_firecrawl:
            crawl_loader = FireCrawlLoader(
                url=url,
                api_key=config.FIRECRAWL_API_KEY,
                mode="scrape",
                params={
                    "formats": ["markdown"],
                    "excludeTags": ["a"],
                }
            )
        else:
            crawl_loader = AsyncChromiumLoader(urls=[url], headless=True)

        url_crawled = await crawl_loader.aload()
        if not url_crawled:
            # Fail with a clear message instead of an IndexError further down.
            raise ValueError(f"No content could be crawled from {url}")

        page = url_crawled[0]
        if use_firecrawl:
            # FireCrawl was asked for markdown directly.
            content_in_markdown = page.page_content
        else:
            # Chromium output is converted to markdown first.
            content_in_markdown = md.transform_documents(url_crawled)[0].page_content

        # Assemble the tagged document string handed to the summary prompt.
        document_parts = [
            "<DOCUMENT>",
            "<METADATA>",
            *[f"{key.upper()}: {value}" for key, value in page.metadata.items()],
            "</METADATA>",
            "<CONTENT>",
            "FORMAT: markdown",
            "TEXT_START",
            content_in_markdown,
            "TEXT_END",
            "</CONTENT>",
            "</DOCUMENT>",
        ]
        combined_document_string = '\n'.join(document_parts)

        # Generate summary
        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
        summary_result = await summary_chain.ainvoke({"document": combined_document_string})
        summary_content = summary_result.content
        summary_embedding = config.embedding_model_instance.embed(summary_content)

        # Chunks are built from the raw markdown, not from the summary.
        chunks = [
            Chunk(content=chunk.text, embedding=chunk.embedding)
            for chunk in config.chunker_instance.chunk(content_in_markdown)
        ]

        # The title comes from a different metadata key depending on the loader.
        title = page.metadata['title'] if use_firecrawl else page.metadata['source']

        document = Document(
            search_space_id=search_space_id,
            title=title,
            document_type=DocumentType.CRAWLED_URL,
            document_metadata=page.metadata,
            content=summary_content,
            embedding=summary_embedding,
            chunks=chunks
        )

        session.add(document)
        await session.commit()
        await session.refresh(document)

        return document

    except SQLAlchemyError:
        await session.rollback()
        raise
    except Exception as e:
        await session.rollback()
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Failed to crawl URL: {str(e)}") from e
|
||||
|
||||
|
||||
async def add_extension_received_document(
    session: AsyncSession,
    content: ExtensionDocumentContent,
    search_space_id: int
) -> Optional[Document]:
    """
    Process and store document content received from the SurfSense Extension.

    Args:
        session: Database session
        content: Document content from extension
        search_space_id: ID of the search space

    Returns:
        The stored, refreshed Document.

    Raises:
        SQLAlchemyError: re-raised after a rollback on database failures.
        RuntimeError: for any other failure, after a rollback; the original
            exception is chained.
    """
    try:
        meta = content.metadata

        # Assemble the tagged document string handed to the summary prompt.
        document_parts = [
            "<DOCUMENT>",
            "<METADATA>",
            f"SESSION_ID: {meta.BrowsingSessionId}",
            f"URL: {meta.VisitedWebPageURL}",
            f"TITLE: {meta.VisitedWebPageTitle}",
            f"REFERRER: {meta.VisitedWebPageReffererURL}",
            f"TIMESTAMP: {meta.VisitedWebPageDateWithTimeInISOString}",
            f"DURATION_MS: {meta.VisitedWebPageVisitDurationInMilliseconds}",
            "</METADATA>",
            "<CONTENT>",
            "FORMAT: markdown",
            "TEXT_START",
            content.pageContent,
            "TEXT_END",
            "</CONTENT>",
            "</DOCUMENT>",
        ]
        combined_document_string = '\n'.join(document_parts)

        # Generate summary
        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
        summary_result = await summary_chain.ainvoke({"document": combined_document_string})
        summary_content = summary_result.content
        summary_embedding = config.embedding_model_instance.embed(summary_content)

        # Chunks are built from the raw page content, not from the summary.
        chunks = [
            Chunk(content=chunk.text, embedding=chunk.embedding)
            for chunk in config.chunker_instance.chunk(content.pageContent)
        ]

        # Create and store document
        document = Document(
            search_space_id=search_space_id,
            title=meta.VisitedWebPageTitle,
            document_type=DocumentType.EXTENSION,
            document_metadata=meta.model_dump(),
            content=summary_content,
            embedding=summary_embedding,
            chunks=chunks
        )

        session.add(document)
        await session.commit()
        await session.refresh(document)

        return document

    except SQLAlchemyError:
        await session.rollback()
        raise
    except Exception as e:
        await session.rollback()
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Failed to process extension document: {str(e)}") from e
|
||||
|
||||
|
||||
async def add_received_file_document(
    session: AsyncSession,
    file_name: str,
    unstructured_processed_elements: List[LangChainDocument],
    search_space_id: int
) -> Optional[Document]:
    """Convert parsed file elements to markdown, summarise, chunk and store them.

    Args:
        session: Database session used to persist the document.
        file_name: Name of the file; used as the title and in the metadata.
        unstructured_processed_elements: Parsed elements of the file, passed
            to `convert_document_to_markdown`.
        search_space_id: ID of the search space the document belongs to.

    Returns:
        The stored, refreshed Document.

    Raises:
        SQLAlchemyError: re-raised after a rollback on database failures.
        RuntimeError: for any other failure, after a rollback; the original
            exception is chained.
    """
    try:
        file_in_markdown = await convert_document_to_markdown(unstructured_processed_elements)

        # TODO: Check if file_markdown exceeds token limit of embedding model

        # Generate summary
        summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
        summary_result = await summary_chain.ainvoke({"document": file_in_markdown})
        summary_content = summary_result.content
        summary_embedding = config.embedding_model_instance.embed(summary_content)

        # Chunks are built from the full markdown, not from the summary.
        chunks = [
            Chunk(content=chunk.text, embedding=chunk.embedding)
            for chunk in config.chunker_instance.chunk(file_in_markdown)
        ]

        # Create and store document
        document = Document(
            search_space_id=search_space_id,
            title=file_name,
            document_type=DocumentType.FILE,
            document_metadata={
                "FILE_NAME": file_name,
                # Local wall-clock time, formatted without timezone information.
                "SAVED_AT": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            },
            content=summary_content,
            embedding=summary_embedding,
            chunks=chunks
        )

        session.add(document)
        await session.commit()
        await session.refresh(document)

        return document
    except SQLAlchemyError:
        await session.rollback()
        raise
    except Exception as e:
        await session.rollback()
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Failed to process file document: {str(e)}") from e
|
||||
486
surfsense_backend/app/tasks/connectors_indexing_tasks.py
Normal file
486
surfsense_backend/app/tasks/connectors_indexing_tasks.py
Normal file
|
|
@ -0,0 +1,486 @@
|
|||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.future import select
|
||||
from datetime import datetime, timedelta
|
||||
from app.db import Document, DocumentType, Chunk, SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.config import config
|
||||
from app.prompts import SUMMARY_PROMPT_TEMPLATE
|
||||
from app.connectors.slack_history import SlackHistory
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from slack_sdk.errors import SlackApiError
|
||||
import logging
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def index_slack_messages(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    update_last_indexed: bool = True
) -> Tuple[int, Optional[str]]:
    """
    Index Slack messages from all accessible channels.

    One Document (LLM summary plus embedded chunks of the channel markdown)
    is added per channel that has at least one message left after filtering.
    A channel that fails is skipped and reported; it does not abort the run.
    All documents are committed together at the end.

    Args:
        session: Database session
        connector_id: ID of the Slack connector
        search_space_id: ID of the search space to store documents in
        update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)

    Returns:
        Tuple containing (number of documents indexed, error message or None).
        When channels were skipped, the message lists them even though the
        run itself succeeded. Failures are reported through the tuple rather
        than raised.
    """
    try:
        # Get the connector
        result = await session.execute(
            select(SearchSourceConnector)
            .filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR
            )
        )
        connector = result.scalars().first()

        if not connector:
            return 0, f"Connector with ID {connector_id} not found or is not a Slack connector"

        # Get the Slack token from the connector config
        slack_token = connector.config.get("SLACK_BOT_TOKEN")
        if not slack_token:
            return 0, "Slack token not found in connector config"

        # Initialize Slack client
        slack_client = SlackHistory(token=slack_token)

        # Calculate date range
        end_date = datetime.now()

        # Use last_indexed_at as start date if available, otherwise use 365 days ago
        if connector.last_indexed_at:
            if connector.last_indexed_at.date() == end_date.date():
                # Already indexed today: re-scan the last 7 days so nothing is missed.
                # NOTE(review): the earlier comment said "1 day" while the code
                # used 7; the value is kept -- confirm which is intended.
                start_date = end_date - timedelta(days=7)
            else:
                start_date = connector.last_indexed_at
        else:
            start_date = end_date - timedelta(days=365)

        # Format dates for Slack API
        start_date_str = start_date.strftime("%Y-%m-%d")
        end_date_str = end_date.strftime("%Y-%m-%d")

        # Get all channels
        try:
            channels = slack_client.get_all_channels()
        except Exception as e:
            return 0, f"Failed to get Slack channels: {str(e)}"

        if not channels:
            return 0, "No Slack channels found"

        # Track the number of documents indexed
        documents_indexed = 0
        skipped_channels = []

        # Process each channel
        for channel_name, channel_id in channels.items():
            try:
                # Check if the bot is a member of the channel
                try:
                    channel_info = slack_client.client.conversations_info(channel=channel_id)

                    # For private channels, the bot needs to be a member
                    if channel_info.get("channel", {}).get("is_private", False):
                        is_member = channel_info.get("channel", {}).get("is_member", False)
                        if not is_member:
                            logger.warning(f"Bot is not a member of private channel {channel_name} ({channel_id}). Skipping.")
                            skipped_channels.append(f"{channel_name} (private, bot not a member)")
                            continue
                except SlackApiError as e:
                    if "not_in_channel" in str(e) or "channel_not_found" in str(e):
                        logger.warning(f"Bot cannot access channel {channel_name} ({channel_id}). Skipping.")
                        skipped_channels.append(f"{channel_name} (access error)")
                        continue
                    # Any other Slack error is handled by the outer handler below.
                    raise

                # Get messages for this channel
                messages, error = slack_client.get_history_by_date_range(
                    channel_id=channel_id,
                    start_date=start_date_str,
                    end_date=end_date_str,
                    limit=1000  # Limit to 1000 messages per channel
                )

                if error:
                    logger.warning(f"Error getting messages from channel {channel_name}: {error}")
                    skipped_channels.append(f"{channel_name} (error: {error})")
                    continue  # Skip this channel if there's an error

                if not messages:
                    logger.info(f"No messages found in channel {channel_name} for the specified date range.")
                    continue  # Skip if no messages

                # Format messages with user info, dropping bot and system messages
                formatted_messages = [
                    slack_client.format_message(msg, include_user_info=True)
                    for msg in messages
                    if msg.get("subtype") not in ("bot_message", "channel_join", "channel_leave")
                ]

                if not formatted_messages:
                    logger.info(f"No valid messages found in channel {channel_name} after filtering.")
                    continue  # Skip if no valid messages after filtering

                # Convert messages to markdown; join once instead of growing a
                # string with += for every message.
                channel_content = f"# Slack Channel: {channel_name}\n\n" + "".join(
                    f"## {msg.get('user_name', 'Unknown User')} "
                    f"({msg.get('datetime', 'Unknown Time')})\n\n"
                    f"{msg.get('text', '')}\n\n---\n\n"
                    for msg in formatted_messages
                )

                # Assemble the tagged document string handed to the summary prompt.
                document_parts = [
                    "<DOCUMENT>",
                    "<METADATA>",
                    f"CHANNEL_NAME: {channel_name}",
                    f"CHANNEL_ID: {channel_id}",
                    f"START_DATE: {start_date_str}",
                    f"END_DATE: {end_date_str}",
                    f"MESSAGE_COUNT: {len(formatted_messages)}",
                    f"INDEXED_AT: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
                    "</METADATA>",
                    "<CONTENT>",
                    "FORMAT: markdown",
                    "TEXT_START",
                    channel_content,
                    "TEXT_END",
                    "</CONTENT>",
                    "</DOCUMENT>",
                ]
                combined_document_string = '\n'.join(document_parts)

                # Generate summary
                summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
                summary_result = await summary_chain.ainvoke({"document": combined_document_string})
                summary_content = summary_result.content
                summary_embedding = config.embedding_model_instance.embed(summary_content)

                # Chunks are built from the channel markdown, not from the summary.
                chunks = [
                    Chunk(content=chunk.text, embedding=chunk.embedding)
                    for chunk in config.chunker_instance.chunk(channel_content)
                ]

                # Create and store document
                document = Document(
                    search_space_id=search_space_id,
                    title=f"Slack - {channel_name}",
                    document_type=DocumentType.SLACK_CONNECTOR,
                    document_metadata={
                        "channel_name": channel_name,
                        "channel_id": channel_id,
                        "start_date": start_date_str,
                        "end_date": end_date_str,
                        "message_count": len(formatted_messages),
                        "indexed_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    },
                    content=summary_content,
                    embedding=summary_embedding,
                    chunks=chunks
                )

                session.add(document)
                documents_indexed += 1
                logger.info(f"Successfully indexed channel {channel_name} with {len(formatted_messages)} messages")

            except SlackApiError as slack_error:
                logger.error(f"Slack API error for channel {channel_name}: {str(slack_error)}")
                skipped_channels.append(f"{channel_name} (Slack API error)")
                continue  # Skip this channel and continue with others
            except Exception as e:
                # exc_info keeps the traceback, which the message alone loses.
                logger.error(f"Error processing channel {channel_name}: {str(e)}", exc_info=True)
                skipped_channels.append(f"{channel_name} (processing error)")
                continue  # Skip this channel and continue with others

        # Update the last_indexed_at timestamp for the connector only if requested
        # and if we successfully indexed at least one channel
        if update_last_indexed and documents_indexed > 0:
            connector.last_indexed_at = datetime.now()

        # Commit all changes
        await session.commit()

        # Prepare result message
        result_message = None
        if skipped_channels:
            result_message = f"Indexed {documents_indexed} channels. Skipped {len(skipped_channels)} channels: {', '.join(skipped_channels)}"

        return documents_indexed, result_message

    except SQLAlchemyError as db_error:
        await session.rollback()
        logger.error(f"Database error: {str(db_error)}", exc_info=True)
        return 0, f"Database error: {str(db_error)}"
    except Exception as e:
        await session.rollback()
        logger.error(f"Failed to index Slack messages: {str(e)}", exc_info=True)
        return 0, f"Failed to index Slack messages: {str(e)}"
|
||||
|
||||
# Markdown template for each explicitly handled Notion block type. Block types
# that are not listed fall back to plain text when the block has content.
_NOTION_BLOCK_TEMPLATES = {
    "paragraph": "{indent}{content}\n\n",
    "text": "{indent}{content}\n\n",
    "heading_1": "{indent}# {content}\n\n",
    "header": "{indent}# {content}\n\n",
    "heading_2": "{indent}## {content}\n\n",
    "heading_3": "{indent}### {content}\n\n",
    "bulleted_list_item": "{indent}* {content}\n",
    "numbered_list_item": "{indent}1. {content}\n",
    "to_do": "{indent}- [ ] {content}\n",
    "toggle": "{indent}> {content}\n",
    "code": "{indent}```\n{content}\n```\n\n",
    "quote": "{indent}> {content}\n\n",
    "callout": "{indent}> **Note:** {content}\n\n",
    # NOTE(review): image blocks emit only the indent and a blank line; the
    # block content is not included. Confirm whether that is intentional.
    "image": "{indent}\n\n",
}


def _notion_blocks_to_markdown(blocks, level=0):
    """Render a list of Notion block dicts, and their children, as markdown text.

    Args:
        blocks: Iterable of dicts read through the "type", "content" and
            "children" keys.
        level: Nesting depth; each level adds one unit of indentation.

    Returns:
        The concatenated markdown for all blocks in order, children included.
    """
    indent = " " * level
    parts = []
    for block in blocks:
        block_content = block.get("content", "")
        template = _NOTION_BLOCK_TEMPLATES.get(block.get("type"))
        if template is not None:
            parts.append(template.format(indent=indent, content=block_content))
        elif block_content:
            # Unknown block type: keep its text so nothing readable is lost.
            parts.append(f"{indent}{block_content}\n\n")

        children = block.get("children", [])
        if children:
            parts.append(_notion_blocks_to_markdown(children, level + 1))
    return "".join(parts)


async def index_notion_pages(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    update_last_indexed: bool = True
) -> Tuple[int, Optional[str]]:
    """
    Index Notion pages from all accessible pages.

    Looks up the Notion connector, fetches the pages for the computed date
    range, converts each page to markdown, summarises and chunks it, and adds
    one Document per page to the session. A page that fails is recorded as
    skipped and does not abort the run.

    Args:
        session: Database session
        connector_id: ID of the Notion connector
        search_space_id: ID of the search space to store documents in
        update_last_indexed: Whether to update the last_indexed_at timestamp (default: True)

    Returns:
        Tuple containing (number of documents indexed, error message or None).
        When pages were skipped, the message lists them even though the run
        itself succeeded. On a database or unexpected error the session is
        rolled back and (0, message) is returned.
    """
    # Local import: only this function needs timezone, and the file-level
    # import block is left untouched.
    from datetime import timezone

    try:
        # Get the connector
        result = await session.execute(
            select(SearchSourceConnector)
            .filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR
            )
        )
        connector = result.scalars().first()

        if not connector:
            return 0, f"Connector with ID {connector_id} not found or is not a Notion connector"

        # Get the Notion token from the connector config
        notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN")
        if not notion_token:
            return 0, "Notion integration token not found in connector config"

        # Initialize Notion client
        logger.info(f"Initializing Notion client for connector {connector_id}")
        notion_client = NotionHistoryConnector(token=notion_token)

        # Calculate date range
        end_date = datetime.now()

        # Use last_indexed_at as start date if available, otherwise use 365 days ago
        if connector.last_indexed_at:
            today = datetime.now().date()
            if connector.last_indexed_at.date() == today:
                # If last indexed today, go back 1 day to ensure we don't miss anything
                start_date = end_date - timedelta(days=1)
            else:
                start_date = connector.last_indexed_at
        else:
            start_date = end_date - timedelta(days=365)

        # The trailing "Z" declares UTC, so convert before formatting.
        # astimezone() treats a naive datetime as local time, which matches
        # how datetime.now() produced it.
        iso_utc_format = "%Y-%m-%dT%H:%M:%SZ"
        start_date_str = start_date.astimezone(timezone.utc).strftime(iso_utc_format)
        end_date_str = end_date.astimezone(timezone.utc).strftime(iso_utc_format)

        logger.info(f"Fetching Notion pages from {start_date_str} to {end_date_str}")

        # Get all pages
        try:
            pages = notion_client.get_all_pages(start_date=start_date_str, end_date=end_date_str)
            logger.info(f"Found {len(pages)} Notion pages")
        except Exception as e:
            logger.error(f"Error fetching Notion pages: {str(e)}", exc_info=True)
            return 0, f"Failed to get Notion pages: {str(e)}"

        if not pages:
            logger.info("No Notion pages found to index")
            return 0, "No Notion pages found"

        # Track the number of documents indexed
        documents_indexed = 0
        skipped_pages = []

        # Process each page
        for page in pages:
            try:
                page_id = page.get("page_id")
                page_title = page.get("title", f"Untitled page ({page_id})")
                page_content = page.get("content", [])

                logger.info(f"Processing Notion page: {page_title} ({page_id})")

                if not page_content:
                    logger.info(f"No content found in page {page_title}. Skipping.")
                    skipped_pages.append(f"{page_title} (no content)")
                    continue

                # Convert page content to markdown format
                logger.debug(f"Converting {len(page_content)} blocks to markdown for page {page_title}")
                markdown_content = (
                    f"# Notion Page: {page_title}\n\n"
                    + _notion_blocks_to_markdown(page_content)
                )

                # One timestamp per page, so the summary input and the stored
                # metadata always agree.
                indexed_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

                # Build the document string handed to the summary prompt
                combined_document_string = '\n'.join([
                    "<DOCUMENT>",
                    "<METADATA>",
                    f"PAGE_TITLE: {page_title}",
                    f"PAGE_ID: {page_id}",
                    f"INDEXED_AT: {indexed_at}",
                    "</METADATA>",
                    "<CONTENT>",
                    "FORMAT: markdown",
                    "TEXT_START",
                    markdown_content,
                    "TEXT_END",
                    "</CONTENT>",
                    "</DOCUMENT>",
                ])

                # Generate summary
                logger.debug(f"Generating summary for page {page_title}")
                summary_chain = SUMMARY_PROMPT_TEMPLATE | config.long_context_llm_instance
                summary_result = await summary_chain.ainvoke({"document": combined_document_string})
                summary_content = summary_result.content
                summary_embedding = config.embedding_model_instance.embed(summary_content)

                # Process chunks
                logger.debug(f"Chunking content for page {page_title}")
                chunks = [
                    Chunk(content=chunk.text, embedding=chunk.embedding)
                    for chunk in config.chunker_instance.chunk(markdown_content)
                ]

                # Create and store document
                document = Document(
                    search_space_id=search_space_id,
                    title=f"Notion - {page_title}",
                    document_type=DocumentType.NOTION_CONNECTOR,
                    document_metadata={
                        "page_title": page_title,
                        "page_id": page_id,
                        "indexed_at": indexed_at
                    },
                    content=summary_content,
                    embedding=summary_embedding,
                    chunks=chunks
                )

                session.add(document)
                documents_indexed += 1
                logger.info(f"Successfully indexed Notion page: {page_title}")

            except Exception as e:
                logger.error(f"Error processing Notion page {page.get('title', 'Unknown')}: {str(e)}", exc_info=True)
                skipped_pages.append(f"{page.get('title', 'Unknown')} (processing error)")
                continue  # Skip this page and continue with others

        # Update the last_indexed_at timestamp for the connector only if requested
        # and if we successfully indexed at least one page
        if update_last_indexed and documents_indexed > 0:
            connector.last_indexed_at = datetime.now()
            logger.info(f"Updated last_indexed_at for connector {connector_id}")

        # Commit all changes
        await session.commit()

        # Prepare result message
        result_message = None
        if skipped_pages:
            result_message = f"Indexed {documents_indexed} pages. Skipped {len(skipped_pages)} pages: {', '.join(skipped_pages)}"

        logger.info(f"Notion indexing completed: {documents_indexed} pages indexed, {len(skipped_pages)} pages skipped")
        return documents_indexed, result_message

    except SQLAlchemyError as db_error:
        await session.rollback()
        logger.error(f"Database error during Notion indexing: {str(db_error)}", exc_info=True)
        return 0, f"Database error: {str(db_error)}"
    except Exception as e:
        await session.rollback()
        logger.error(f"Failed to index Notion pages: {str(e)}", exc_info=True)
        return 0, f"Failed to index Notion pages: {str(e)}"
|
||||
340
surfsense_backend/app/tasks/stream_connector_search_results.py
Normal file
340
surfsense_backend/app/tasks/stream_connector_search_results.py
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
import json
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, AsyncGenerator, Dict, Any
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from app.utils.connector_service import ConnectorService
|
||||
from app.utils.research_service import ResearchService
|
||||
from app.utils.streaming_service import StreamingService
|
||||
from app.utils.reranker_service import RerankerService
|
||||
from app.config import config
|
||||
from app.utils.document_converters import convert_chunks_to_langchain_documents
|
||||
|
||||
# Number of chunks requested from each connector, per research mode.
# Unknown modes use the GENERAL value.
_TOP_K_BY_RESEARCH_MODE = {"GENERAL": 20, "DEEP": 40, "DEEPER": 60}

# connector type -> (ConnectorService method name,
#                    whether the search takes search_space_id,
#                    terminal message sent before searching,
#                    label used in the "Found N ..." terminal message)
_CONNECTOR_SEARCHES = {
    "CRAWLED_URL": (
        "search_crawled_urls", True,
        "Starting to search for crawled URLs...", "relevant crawled URLs",
    ),
    "FILE": (
        "search_files", True,
        "Starting to search for files...", "relevant files",
    ),
    "TAVILY_API": (
        "search_tavily", False,
        "Starting to search with Tavily API...", "relevant results from Tavily",
    ),
    "SLACK_CONNECTOR": (
        "search_slack", True,
        "Starting to search for slack connector...", "relevant results from Slack",
    ),
    "NOTION_CONNECTOR": (
        "search_notion", True,
        "Starting to search for notion connector...", "relevant results from Notion",
    ),
}


async def stream_connector_search_results(
    user_query: str,
    user_id: int,
    search_space_id: int,
    session: AsyncSession,
    research_mode: str,
    selected_connectors: List[str]
) -> AsyncGenerator[str, None]:
    """
    Stream connector search results to the client

    Searches every selected connector, optionally reranks the collected
    chunks, then runs the research step in a background task and yields its
    progress as it arrives. A completion message is always yielded last.

    Args:
        user_query: The user's query
        user_id: The user's ID
        search_space_id: The search space ID
        session: The database session
        research_mode: The research mode ("GENERAL", "DEEP" or "DEEPER";
            anything else behaves like "GENERAL")
        selected_connectors: List of selected connectors; unrecognised
            entries are ignored

    Yields:
        str: Formatted response strings
    """
    # Initialize services
    connector_service = ConnectorService(session)
    streaming_service = StreamingService()
    reranker_service = RerankerService.get_reranker_instance(config)

    all_raw_documents = []  # Store all raw documents before reranking
    all_sources = []
    top_k = _TOP_K_BY_RESEARCH_MODE.get(research_mode, 20)

    # Process each selected connector
    for connector in selected_connectors:
        search_plan = _CONNECTOR_SEARCHES.get(connector)
        if search_plan is None:
            continue
        method_name, takes_search_space, start_message, found_label = search_plan

        # Send terminal message about starting search
        yield streaming_service.add_terminal_message(start_message)

        search_kwargs = {"user_query": user_query, "user_id": user_id, "top_k": top_k}
        if takes_search_space:
            search_kwargs["search_space_id"] = search_space_id
        result_object, connector_chunks = await getattr(connector_service, method_name)(**search_kwargs)

        # Send terminal message about search results
        yield streaming_service.add_terminal_message(
            f"Found {len(result_object['sources'])} {found_label}",
            "success"
        )

        # Update sources
        all_sources.append(result_object)
        yield streaming_service.update_sources(all_sources)

        # Add documents to collection
        all_raw_documents.extend(connector_chunks)

    # If we have documents to research
    if all_raw_documents:
        # Rerank all documents if reranker is available
        if reranker_service:
            yield streaming_service.add_terminal_message("Reranking documents for better relevance...", "info")

            # Convert documents to format expected by reranker
            reranker_input_docs = [
                {
                    "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
                    "content": doc.get("content", ""),
                    "score": doc.get("score", 0.0),
                    "document": {
                        "id": doc.get("document", {}).get("id", ""),
                        "title": doc.get("document", {}).get("title", ""),
                        "document_type": doc.get("document", {}).get("document_type", ""),
                        "metadata": doc.get("document", {}).get("metadata", {})
                    }
                } for i, doc in enumerate(all_raw_documents)
            ]

            # Rerank documents
            reranked_docs = reranker_service.rerank_documents(user_query, reranker_input_docs)

            # Sort by score in descending order
            reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)

            # Convert back to langchain documents format
            from langchain.schema import Document as LangchainDocument
            all_langchain_documents_to_research = [
                LangchainDocument(
                    page_content=f"""<document><metadata><source_id>{doc.get("document", {}).get("id", "")}</source_id></metadata><content>{doc.get("content", "")}</content></document>""",
                    metadata={
                        # Explicitly set source_id for citation purposes
                        "source_id": str(doc.get("document", {}).get("id", ""))
                    }
                ) for doc in reranked_docs
            ]

            yield streaming_service.add_terminal_message(f"Reranked {len(all_langchain_documents_to_research)} documents", "success")
        else:
            # Use raw documents if no reranker is available
            all_langchain_documents_to_research = convert_chunks_to_langchain_documents(all_raw_documents)

        # Send terminal message about starting research
        yield streaming_service.add_terminal_message("Starting to research...", "info")

        # Buffer collecting report content as it streams in
        report_buffer = []

        yield streaming_service.add_terminal_message("Generating report...", "info")

        class StreamHandler:
            """Turns research progress callbacks into queued, formatted stream messages."""

            def __init__(self):
                self.queue = asyncio.Queue()

            async def handle_progress(self, data):
                """Format one progress event, queue it if non-empty, and return it."""
                result = None
                if data.get("type") == "logs":
                    # Handle log messages
                    result = streaming_service.add_terminal_message(data.get("output", ""), "info")
                elif data.get("type") == "report":
                    # Handle report content
                    content = data.get("output", "")

                    # Fix incorrect citation formats using regex.
                    # Matches only numeric citations wrapped in a markdown link,
                    # e.g. ([1](https://github.com/...)), not general links
                    # such as ([Click here](https://...)); replaces with [1].
                    pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'
                    content = re.sub(pattern, r'[\1]', content)

                    # Also convert ([1]) to [1]; only when the bracket holds a number
                    content = re.sub(r'\(\[(\d+)\]\)', r'[\1]', content)

                    report_buffer.append(content)
                    # Update the answer with the accumulated content
                    result = streaming_service.update_answer(report_buffer)

                if result:
                    await self.queue.put(result)
                return result

            async def get_next(self):
                """Wait for and return the next queued message (None on error)."""
                try:
                    return await self.queue.get()
                except Exception as e:
                    print(f"Error getting next item from queue: {e}")
                    return None

            def task_done(self):
                """Mark the most recently fetched queue item as processed."""
                self.queue.task_done()

        # Create the stream handler
        stream_handler = StreamHandler()

        # Start the research process in a separate task
        research_task = asyncio.create_task(
            ResearchService.stream_research(
                user_query=user_query,
                documents=all_langchain_documents_to_research,
                on_progress=stream_handler.handle_progress,
                research_mode=research_mode
            )
        )

        try:
            # Stream results as they become available
            while not research_task.done() or not stream_handler.queue.empty():
                try:
                    # Get the next result with a timeout
                    result = await asyncio.wait_for(stream_handler.get_next(), timeout=0.1)
                    stream_handler.task_done()
                    yield result
                except asyncio.TimeoutError:
                    # No result available yet; finished once the task is done
                    # and the queue has been drained
                    if research_task.done():
                        if stream_handler.queue.empty():
                            break

            # Get the final report
            try:
                final_report = await research_task

                # Send terminal message about research completion
                yield streaming_service.add_terminal_message("Research completed", "success")

                # Update the answer with the final report
                final_report_lines = final_report.split('\n')
                yield streaming_service.update_answer(final_report_lines)
            except Exception as e:
                # Handle any exceptions
                yield streaming_service.add_terminal_message(f"Error during research: {str(e)}", "error")
        finally:
            # If the consumer stops iterating early (or an error escapes),
            # do not leave the research task running in the background.
            if not research_task.done():
                research_task.cancel()

    # Send completion message
    yield streaming_service.format_completion()
|
||||
95
surfsense_backend/app/users.py
Normal file
95
surfsense_backend/app/users.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends, Request, Response
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, models
|
||||
from fastapi_users.authentication import (
|
||||
AuthenticationBackend,
|
||||
BearerTransport,
|
||||
JWTStrategy,
|
||||
)
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from app.config import config
|
||||
from app.db import User, get_user_db
|
||||
from pydantic import BaseModel
|
||||
|
||||
class BearerResponse(BaseModel):
    """Pydantic model holding an access token and its token type."""

    # The token string handed to the client.
    access_token: str
    # Token type label; this file sets it to "bearer" when building the model.
    token_type: str
|
||||
|
||||
# Secret taken from configuration; reused below for the token secrets and the JWT strategy.
SECRET = config.SECRET_KEY

# Google OAuth2 client built from the configured client id and client secret.
google_oauth_client = GoogleOAuth2(config.GOOGLE_OAUTH_CLIENT_ID, config.GOOGLE_OAUTH_CLIENT_SECRET)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
    """User manager whose lifecycle hooks print a notice for each event.

    Both token-secret attributes are set to the module-level ``SECRET``.
    """

    reset_password_token_secret = SECRET
    verification_token_secret = SECRET

    async def on_after_register(self, user: User, request: Optional[Request] = None):
        """Print a notice that *user* has registered."""
        notice = f"User {user.id} has registered."
        print(notice)

    async def on_after_forgot_password(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        """Print a notice with the password-reset token issued for *user*."""
        # NOTE(review): the reset token is written to stdout - confirm this is
        # acceptable outside local development.
        notice = f"User {user.id} has forgot their password. Reset token: {token}"
        print(notice)

    async def on_after_request_verify(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        """Print a notice with the verification token issued for *user*."""
        # NOTE(review): the verification token is written to stdout - confirm
        # this is acceptable outside local development.
        notice = f"Verification requested for user {user.id}. Verification token: {token}"
        print(notice)
|
||||
|
||||
|
||||
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
    """Yield a ``UserManager`` bound to the injected user database."""
    manager = UserManager(user_db)
    yield manager
|
||||
|
||||
|
||||
def get_jwt_strategy() -> JWTStrategy[models.UP, models.ID]:
    """Return a JWT strategy signed with ``SECRET`` and a 24-hour lifetime."""
    one_day_in_seconds = 3600 * 24
    strategy = JWTStrategy(secret=SECRET, lifetime_seconds=one_day_in_seconds)
    return strategy
|
||||
|
||||
|
||||
# # COOKIE AUTH | Uncomment if you want to use cookie auth.
|
||||
# from fastapi_users.authentication import (
|
||||
# CookieTransport,
|
||||
# )
|
||||
# class CustomCookieTransport(CookieTransport):
|
||||
# async def get_login_response(self, token: str) -> Response:
|
||||
# response = RedirectResponse(config.OAUTH_REDIRECT_URL, status_code=302)
|
||||
# return self._set_login_cookie(response, token)
|
||||
|
||||
# cookie_transport = CustomCookieTransport(
|
||||
# cookie_max_age=3600,
|
||||
# )
|
||||
|
||||
# auth_backend = AuthenticationBackend(
|
||||
# name="jwt",
|
||||
# transport=cookie_transport,
|
||||
# get_strategy=get_jwt_strategy,
|
||||
# )
|
||||
|
||||
# BEARER AUTH CODE.
|
||||
class CustomBearerTransport(BearerTransport):
    """Bearer transport whose login response is a redirect to the frontend."""

    async def get_login_response(self, token: str) -> Response:
        """Return a 302 redirect to the frontend auth callback carrying *token*.

        The token is passed in the ``token`` query parameter of
        ``{NEXT_FRONTEND_URL}/auth/callback``.
        """
        payload = BearerResponse(access_token=token, token_type="bearer")
        callback_url = f"{config.NEXT_FRONTEND_URL}/auth/callback?token={payload.access_token}"
        return RedirectResponse(callback_url, status_code=302)
|
||||
|
||||
# Bearer transport configured with "auth/jwt/login" as its token URL.
bearer_transport = CustomBearerTransport(tokenUrl="auth/jwt/login")

# Authentication backend named "jwt": the bearer transport above plus the JWT strategy factory.
auth_backend = AuthenticationBackend(name="jwt", transport=bearer_transport, get_strategy=get_jwt_strategy)

# FastAPI Users instance for User records with UUID ids, using the single backend above.
fastapi_users = FastAPIUsers[User, uuid.UUID](get_user_manager, [auth_backend])

# Current-user dependency created with active=True.
current_active_user = fastapi_users.current_user(active=True)
|
||||
12
surfsense_backend/app/utils/check_ownership.py
Normal file
12
surfsense_backend/app/utils/check_ownership.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from fastapi import HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from app.db import User
|
||||
|
||||
# Helper function to check user ownership
|
||||
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
    """Return the *model* row with id *item_id* that belongs to *user*.

    Args:
        session: Database session used to run the query.
        model: Mapped class exposing ``id`` and ``user_id`` columns.
        item_id: Primary key of the row to fetch.
        user: User whose ``id`` must match the row's ``user_id``.

    Returns:
        The first matching row.

    Raises:
        HTTPException: 404 when no row matches both the id and the user.
    """
    ownership_query = select(model).filter(model.id == item_id, model.user_id == user.id)
    query_result = await session.execute(ownership_query)
    owned_item = query_result.scalars().first()
    if owned_item:
        return owned_item
    raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
|
||||
385
surfsense_backend/app/utils/connector_service.py
Normal file
385
surfsense_backend/app/utils/connector_service.py
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
import json
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from tavily import TavilyClient
|
||||
|
||||
|
||||
class ConnectorService:
|
||||
def __init__(self, session: AsyncSession):
    """Store the session, build the hybrid-search retriever, and start source ids at 1."""
    # Running id handed out to sources across all searches made through this instance.
    self.source_id_counter = 1
    self.session = session
    self.retriever = ChucksHybridSearchRetriever(session)
|
||||
|
||||
async def search_crawled_urls(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
    """
    Search for crawled URLs and return both the source information and the matching chunks

    Each chunk's document id is overwritten with a small running source id
    (the "fix for UI"). Sources are de-duplicated by URL, falling back to
    title; a chunk that maps to an already-listed source reuses that source's
    id, so every chunk id refers to an entry in the returned sources.

    Args:
        user_query: The user's query
        user_id: The user's ID
        search_space_id: The search space ID
        top_k: Maximum number of chunks to request from the retriever

    Returns:
        tuple: (result_object, chunks) where result_object is a dict with
        "id", "name", "type" and "sources", and chunks is the retriever
        output with rewritten document ids
    """
    crawled_urls_chunks = await self.retriever.hybrid_search(
        query_text=user_query,
        top_k=top_k,
        user_id=user_id,
        search_space_id=search_space_id,
        document_type="CRAWLED_URL"
    )

    # Map crawled URL chunks to the required format, keyed by URL or title
    mapped_sources = {}
    for chunk in crawled_urls_chunks:
        document = chunk['document']
        metadata = document.get('metadata', {})
        title = document.get('title', 'Untitled Document')
        url = metadata.get('url', '')

        # Use a unique identifier for tracking unique sources
        source_key = url or title

        known_source = mapped_sources.get(source_key) if source_key else None
        if known_source is not None:
            # Same source as an earlier chunk: point this chunk at the listed
            # source instead of an id that no source carries.
            document['id'] = known_source["id"]
            continue

        source_id = self.source_id_counter
        self.source_id_counter += 1

        # Fix for UI: expose the running source id as the document id
        document['id'] = source_id

        if source_key:
            mapped_sources[source_key] = {
                "id": source_id,
                "title": title,
                "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                "url": url
            }

    # Create result object
    result_object = {
        "id": 1,
        "name": "Crawled URLs",
        "type": "CRAWLED_URL",
        "sources": list(mapped_sources.values()),
    }

    return result_object, crawled_urls_chunks
|
||||
|
||||
async def search_files(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
    """
    Search for files and return both the source information and the matching chunks

    Each chunk's document id is overwritten with a small running source id
    (the "fix for UI"). Sources are de-duplicated by URL, falling back to
    title; a chunk that maps to an already-listed source reuses that source's
    id, so every chunk id refers to an entry in the returned sources.

    Args:
        user_query: The user's query
        user_id: The user's ID
        search_space_id: The search space ID
        top_k: Maximum number of chunks to request from the retriever

    Returns:
        tuple: (result_object, chunks) where result_object is a dict with
        "id", "name", "type" and "sources", and chunks is the retriever
        output with rewritten document ids
    """
    files_chunks = await self.retriever.hybrid_search(
        query_text=user_query,
        top_k=top_k,
        user_id=user_id,
        search_space_id=search_space_id,
        document_type="FILE"
    )

    # Map file chunks to the required format, keyed by URL or title
    mapped_sources = {}
    for chunk in files_chunks:
        document = chunk['document']
        metadata = document.get('metadata', {})
        title = document.get('title', 'Untitled Document')
        url = metadata.get('url', '')

        # Use a unique identifier for tracking unique sources
        source_key = url or title

        known_source = mapped_sources.get(source_key) if source_key else None
        if known_source is not None:
            # Same source as an earlier chunk: point this chunk at the listed
            # source instead of an id that no source carries.
            document['id'] = known_source["id"]
            continue

        source_id = self.source_id_counter
        self.source_id_counter += 1

        # Fix for UI: expose the running source id as the document id
        document['id'] = source_id

        if source_key:
            mapped_sources[source_key] = {
                "id": source_id,
                "title": title,
                "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                "url": url
            }

    # Create result object
    result_object = {
        "id": 2,
        "name": "Files",
        "type": "FILE",
        "sources": list(mapped_sources.values()),
    }

    return result_object, files_chunks
|
||||
|
||||
async def get_connector_by_type(self, user_id: int, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
    """
    Fetch the first connector of a given type owned by a user

    Args:
        user_id: The user's ID
        connector_type: The connector type to retrieve

    Returns:
        Optional[SearchSourceConnector]: The first matching connector, or None when there is none
    """
    connector_query = select(SearchSourceConnector).filter(
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type == connector_type
    )
    query_result = await self.session.execute(connector_query)
    return query_result.scalars().first()
|
||||
|
||||
async def search_tavily(self, user_query: str, user_id: int, top_k: int = 20) -> tuple:
    """
    Search using Tavily API and return both the source information and documents

    Args:
        user_query: The user's query
        user_id: The user's ID
        top_k: Maximum number of results to return

    Returns:
        tuple: (sources_info, documents). When the user has no Tavily connector,
        or anything goes wrong while talking to Tavily, sources_info carries an
        empty "sources" list and documents is an empty list.
    """
    # Local import so this method stays self-contained.
    import asyncio

    def empty_result() -> tuple:
        # Single definition of the "nothing found" payload.
        return {
            "id": 3,
            "name": "Tavily Search",
            "type": "TAVILY_API",
            "sources": [],
        }, []

    # Get Tavily connector configuration
    tavily_connector = await self.get_connector_by_type(user_id, SearchSourceConnectorType.TAVILY_API)

    if not tavily_connector:
        # Return empty results if no Tavily connector is configured
        return empty_result()

    try:
        # Client construction lives inside the try block so that a missing or
        # malformed API key degrades to empty results like any other failure.
        tavily_api_key = tavily_connector.config.get("TAVILY_API_KEY")
        tavily_client = TavilyClient(api_key=tavily_api_key)

        # The client call is synchronous; run it in a worker thread so the
        # coroutine does not stall while waiting for the response.
        response = await asyncio.to_thread(
            tavily_client.search,
            query=user_query,
            max_results=top_k,
            search_depth="advanced",  # Use advanced search for better results
        )

        # Extract results from Tavily response
        tavily_results = response.get("results", [])

        # Map Tavily results to the required format
        sources_list = []
        documents = []

        for i, result in enumerate(tavily_results):
            title = result.get("title", "Tavily Result")
            url = result.get("url", "")
            # Treat an explicit None the same as a missing content field.
            content = result.get("content") or ""

            # Create a source entry
            sources_list.append({
                "id": self.source_id_counter,
                "title": title,
                "description": content[:100],
                "url": url,
            })

            # Create a document entry; document.id matches the source id above
            documents.append({
                "chunk_id": f"tavily_chunk_{i}",
                "content": content,
                "score": result.get("score", 0.0),
                "document": {
                    "id": self.source_id_counter,
                    "title": title,
                    "document_type": "TAVILY_API",
                    "metadata": {
                        "url": url,
                        "published_date": result.get("published_date", ""),
                        "source": "TAVILY_API",
                    },
                },
            })
            self.source_id_counter += 1

        # Create result object
        result_object = {
            "id": 3,
            "name": "Tavily Search",
            "type": "TAVILY_API",
            "sources": sources_list,
        }

        return result_object, documents

    except Exception as e:
        # Best-effort connector: log the error and return empty results
        print(f"Error searching with Tavily: {str(e)}")
        return empty_result()
|
||||
|
||||
async def search_slack(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
    """
    Search for slack and return both the source information and langchain documents

    Args:
        user_query: The user's query
        user_id: The user's ID
        search_space_id: The search space ID to search in
        top_k: Maximum number of results to return

    Returns:
        tuple: (sources_info, langchain_documents)
    """
    slack_chunks = await self.retriever.hybrid_search(
        query_text=user_query,
        top_k=top_k,
        user_id=user_id,
        search_space_id=search_space_id,
        document_type="SLACK_CONNECTOR"
    )

    # Map slack_chunks to the required format
    mapped_sources = {}
    for i, chunk in enumerate(slack_chunks):
        # Fix for UI: overwrite the document id with the running source id so
        # the chunk and its source entry share the same identifier
        slack_chunks[i]['document']['id'] = self.source_id_counter

        # Extract document metadata
        document = chunk.get('document', {})
        metadata = document.get('metadata', {})

        # Slack-specific metadata
        channel_name = metadata.get('channel_name', 'Unknown Channel')
        channel_id = metadata.get('channel_id', '')
        message_date = metadata.get('start_date', '')

        # Create a more descriptive title for Slack messages
        title = f"Slack: {channel_name}"
        if message_date:
            title += f" ({message_date})"

        # Preview of the content; the ellipsis is added only when the content
        # was actually cut (content of exactly 100 characters is complete)
        content = chunk.get('content', '')
        description = content[:100]
        if len(content) > 100:
            description += "..."

        # Link to the Slack channel when its id is known, otherwise leave empty
        url = ""
        if channel_id:
            url = f"https://slack.com/app_redirect?channel={channel_id}"

        source = {
            "id": self.source_id_counter,
            "title": title,
            "description": description,
            "url": url,
        }

        self.source_id_counter += 1

        # Use channel_id and chunk_id (falling back to the loop index) as a
        # unique identifier for tracking unique sources
        source_key = f"{channel_id}_{chunk.get('chunk_id', i)}"
        if source_key and source_key not in mapped_sources:
            mapped_sources[source_key] = source

    # Convert to list of sources
    sources_list = list(mapped_sources.values())

    # Create result object
    result_object = {
        "id": 4,
        "name": "Slack",
        "type": "SLACK_CONNECTOR",
        "sources": sources_list,
    }

    return result_object, slack_chunks
|
||||
|
||||
async def search_notion(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
    """
    Search for Notion pages and return both the source information and langchain documents

    Args:
        user_query: The user's query
        user_id: The user's ID
        search_space_id: The search space ID to search in
        top_k: Maximum number of results to return

    Returns:
        tuple: (sources_info, langchain_documents)
    """
    notion_chunks = await self.retriever.hybrid_search(
        query_text=user_query,
        top_k=top_k,
        user_id=user_id,
        search_space_id=search_space_id,
        document_type="NOTION_CONNECTOR"
    )

    # Map notion_chunks to the required format
    mapped_sources = {}
    for i, chunk in enumerate(notion_chunks):
        # Fix for UI: overwrite the document id with the running source id so
        # the chunk and its source entry share the same identifier
        notion_chunks[i]['document']['id'] = self.source_id_counter

        # Extract document metadata
        document = chunk.get('document', {})
        metadata = document.get('metadata', {})

        # Notion-specific metadata
        page_title = metadata.get('page_title', 'Untitled Page')
        page_id = metadata.get('page_id', '')
        indexed_at = metadata.get('indexed_at', '')

        # Create a more descriptive title for Notion pages
        title = f"Notion: {page_title}"
        if indexed_at:
            title += f" (indexed: {indexed_at})"

        # Preview of the content; the ellipsis is added only when the content
        # was actually cut (content of exactly 100 characters is complete)
        content = chunk.get('content', '')
        description = content[:100]
        if len(content) > 100:
            description += "..."

        # Link to the Notion page when its id is known, otherwise leave empty
        url = ""
        if page_id:
            # The page id is used with its dashes removed
            compact_page_id = page_id.replace('-', '')
            url = f"https://notion.so/{compact_page_id}"

        source = {
            "id": self.source_id_counter,
            "title": title,
            "description": description,
            "url": url,
        }

        self.source_id_counter += 1

        # Use page_id and chunk_id (falling back to the loop index) as a
        # unique identifier for tracking unique sources
        source_key = f"{page_id}_{chunk.get('chunk_id', i)}"
        if source_key and source_key not in mapped_sources:
            mapped_sources[source_key] = source

    # Convert to list of sources
    sources_list = list(mapped_sources.values())

    # Create result object
    result_object = {
        "id": 5,
        "name": "Notion",
        "type": "NOTION_CONNECTOR",
        "sources": sources_list,
    }

    return result_object, notion_chunks
|
||||
136
surfsense_backend/app/utils/document_converters.py
Normal file
136
surfsense_backend/app/utils/document_converters.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
async def convert_element_to_markdown(element) -> str:
    """
    Convert an Unstructured element to markdown format based on its category.

    Args:
        element: The Unstructured API element object; its ``metadata`` mapping
            must contain a "category" entry and its ``page_content`` holds the text

    Returns:
        str: Markdown formatted string ("" for elements without text, other
        than page breaks)

    Raises:
        KeyError: If the element metadata has no "category" entry
    """
    element_category = element.metadata["category"]
    content = element.page_content

    # A page break is rendered from its category alone, so it must be handled
    # before the empty-content shortcut below or it would never be emitted.
    if element_category == "PageBreak":
        return "\n---\n"

    if not content:
        return ""

    if element_category == "Table":
        # Prefer the HTML rendering when the metadata provides one; otherwise
        # fall back to the plain text instead of failing on the missing key.
        table_html = element.metadata.get("text_as_html")
        if table_html:
            return f"```html\n{table_html}\n```"
        return f"{content}\n\n"

    markdown_mapping = {
        "Formula": lambda x: f"```math\n{x}\n```",
        "FigureCaption": lambda x: f"*Figure: {x}*",
        "NarrativeText": lambda x: f"{x}\n\n",
        "ListItem": lambda x: f"- {x}\n",
        "Title": lambda x: f"# {x}\n\n",
        "Address": lambda x: f"> {x}\n\n",
        "EmailAddress": lambda x: f"`{x}`",
        # Image elements contribute nothing to the markdown output
        "Image": lambda x: "",
        "Header": lambda x: f"## {x}\n\n",
        "Footer": lambda x: f"*{x}*\n\n",
        "CodeSnippet": lambda x: f"```\n{x}\n```",
        "PageNumber": lambda x: f"*Page {x}*\n\n",
        "UncategorizedText": lambda x: f"{x}\n\n"
    }

    # Unknown categories pass their text through unchanged
    converter = markdown_mapping.get(element_category, lambda x: x)
    return converter(content)
|
||||
|
||||
|
||||
async def convert_document_to_markdown(elements):
    """
    Render every element of a document and concatenate the results.

    Args:
        elements: List of Unstructured API elements

    Returns:
        str: The markdown of all elements joined in order, with empty
        renderings left out
    """
    # Elements are converted one after another, in their original order.
    rendered = [await convert_element_to_markdown(item) for item in elements]
    return "".join(piece for piece in rendered if piece)
|
||||
|
||||
def convert_chunks_to_langchain_documents(chunks):
    """
    Wrap hybrid search chunks as LangChain Document objects.

    Each chunk's text is embedded in a small XML envelope that exposes the
    source_id used for citations, and the chunk/document fields are copied
    into the Document metadata.

    Args:
        chunks: List of chunk dictionaries from hybrid search results

    Returns:
        List of LangChain Document objects

    Raises:
        ImportError: If langchain_core cannot be imported
    """
    try:
        from langchain_core.documents import Document as LangChainDocument
    except ImportError:
        raise ImportError(
            "LangChain is not installed. Please install it with `pip install langchain langchain-core`"
        )

    converted = []

    for chunk in chunks:
        text = chunk.get("content", "")

        # Chunk-level fields come first
        metadata = {
            "chunk_id": chunk.get("chunk_id"),
            "score": chunk.get("score"),
            "rank": chunk.get("rank"),
        }

        if "document" in chunk:
            doc = chunk["document"]
            metadata["document_id"] = doc.get("id")
            metadata["document_title"] = doc.get("title")
            metadata["document_type"] = doc.get("document_type")

            if "metadata" in doc:
                doc_meta = doc["metadata"]

                # Prefix document metadata keys to avoid conflicts
                for key, value in doc_meta.items():
                    metadata[f"doc_meta_{key}"] = value

                # "url" takes precedence over "sourceURL" as the source
                if "url" in doc_meta:
                    metadata["source"] = doc_meta["url"]
                elif "sourceURL" in doc_meta:
                    metadata["source"] = doc_meta["sourceURL"]

        # The document id doubles as the source_id for citation purposes
        if "document_id" in metadata:
            metadata["source_id"] = metadata["document_id"]

        # source_id is present exactly when the chunk carried a document
        source_id = metadata.get("source_id", "unknown")

        # XML envelope for citation mode, with an explicit source_id
        page_content = (
            "\n"
            "<document>\n"
            "<metadata>\n"
            f"<source_id>{source_id}</source_id>\n"
            "</metadata>\n"
            "<content>\n"
            "<text>\n"
            f"{text}\n"
            "</text>\n"
            "</content>\n"
            "</document>\n"
        )

        converted.append(LangChainDocument(page_content=page_content, metadata=metadata))

    return converted
|
||||
95
surfsense_backend/app/utils/reranker_service.py
Normal file
95
surfsense_backend/app/utils/reranker_service.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from rerankers import Document as RerankerDocument
|
||||
|
||||
class RerankerService:
    """
    Service for reranking documents using a configured reranker
    """

    def __init__(self, reranker_instance=None):
        """
        Initialize the reranker service

        Args:
            reranker_instance: The reranker instance to use for reranking
        """
        self.reranker_instance = reranker_instance

    def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Rerank documents using the configured reranker

        Args:
            query_text: The query text to use for reranking
            documents: List of document dictionaries to rerank

        Returns:
            List[Dict[str, Any]]: Shallow copies of the input documents in the
            reranker's order, with "score" and "rank" replaced by the reranker's
            values. The input list is returned unchanged when no reranker is
            configured, when it is empty, or when reranking fails.
        """
        if not self.reranker_instance or not documents:
            return documents

        try:
            # Create Document objects for the rerankers library and remember
            # which input document each reranker id stands for. Documents
            # without a "chunk_id" get a positional id, so the mapping must be
            # keyed by the id actually sent to the reranker or they would be
            # lost when the results are mapped back.
            reranker_docs = []
            originals_by_id = {}
            for i, doc in enumerate(documents):
                chunk_id = doc.get("chunk_id", f"chunk_{i}")
                content = doc.get("content", "")
                score = doc.get("score", 0.0)
                document_info = doc.get("document", {})

                # setdefault keeps the first document when ids repeat
                originals_by_id.setdefault(chunk_id, doc)

                reranker_docs.append(
                    RerankerDocument(
                        text=content,
                        doc_id=chunk_id,
                        metadata={
                            'document_id': document_info.get("id", ""),
                            'document_title': document_info.get("title", ""),
                            'document_type': document_info.get("document_type", ""),
                            'rrf_score': score
                        }
                    )
                )

            # Rerank using the configured reranker
            reranking_results = self.reranker_instance.rank(
                query=query_text,
                docs=reranker_docs
            )

            # Convert the results back to serializable dictionaries
            serialized_results = []
            for result in reranking_results.results:
                original_doc = originals_by_id.get(result.document.doc_id)
                if original_doc is not None:
                    # Copy so the caller's dictionaries are left untouched
                    reranked_doc = original_doc.copy()
                    reranked_doc["score"] = float(result.score)
                    reranked_doc["rank"] = result.rank
                    serialized_results.append(reranked_doc)

            return serialized_results

        except Exception as e:
            # Log the error
            logging.error(f"Error during reranking: {str(e)}")
            # Fall back to original documents without reranking
            return documents

    @staticmethod
    def get_reranker_instance(config=None) -> Optional['RerankerService']:
        """
        Get a reranker service instance based on configuration

        Args:
            config: Configuration object that may contain a reranker_instance

        Returns:
            Optional[RerankerService]: A reranker service instance or None
        """
        if config and hasattr(config, 'reranker_instance') and config.reranker_instance:
            return RerankerService(config.reranker_instance)
        return None
|
||||
211
surfsense_backend/app/utils/research_service.py
Normal file
211
surfsense_backend/app/utils/research_service.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
import asyncio
|
||||
import re
|
||||
from typing import List, Dict, Any, AsyncGenerator, Callable, Optional
|
||||
from langchain.schema import Document
|
||||
from gpt_researcher.agent import GPTResearcher
|
||||
from gpt_researcher.utils.enum import ReportType, Tone, ReportSource
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
class ResearchService:
    """
    Helpers for running GPTResearcher over supplied documents and producing
    reports whose citations use the documents' source_id values
    """

    @staticmethod
    async def create_custom_prompt(user_query: str) -> str:
        """
        Build the research prompt that enforces IEEE-style [source_id] citations

        Args:
            user_query: The user's query, inserted at the end of the prompt

        Returns:
            str: The full prompt text
        """
        citation_prompt = f"""
You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.

<instructions>
1. Carefully analyze all provided documents in the <document> section's.
2. Extract relevant information that addresses the user's query.
3. Synthesize a comprehensive, well-structured answer using information from these documents.
4. For EVERY piece of information you include from the documents, add an IEEE-style citation in square brackets [X] where X is the source_id from the document's metadata.
5. Make sure ALL factual statements from the documents have proper citations.
6. If multiple documents support the same point, include all relevant citations [X], [Y].
7. Present information in a logical, coherent flow.
8. Use your own words to connect ideas, but cite ALL information from the documents.
9. If documents contain conflicting information, acknowledge this and present both perspectives with appropriate citations.
10. Do not make up or include information not found in the provided documents.
11. CRITICAL: You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
12. CRITICAL: Every citation MUST be in the IEEE format [X] where X is the exact source_id value.
13. CRITICAL: Never renumber or reorder citations - always use the original source_id values.
14. CRITICAL: Do not return citations as clickable links.
15. CRITICAL: Never format citations as markdown links like "([1](https://example.com))". Always use plain square brackets only.
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
</instructions>

<format>
- Write in clear, professional language suitable for academic or technical audiences
- Organize your response with appropriate paragraphs, headings, and structure
- Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [X], [Y], [Z]
- No need to return references section. Just citation numbers in answer.
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
</format>

<input_example>
<document>
<metadata>
<source_id>1</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands.
</text>
</content>
</document>

<document>
<metadata>
<source_id>13</source_id>
</metadata>
<content>
<text>
Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020.
</text>
</content>
</document>

<document>
<metadata>
<source_id>21</source_id>
</metadata>
<content>
<text>
The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral.
</text>
</content>
</document>
</input_example>

<output_example>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13].
</output_example>

<incorrect_citation_formats>
DO NOT use any of these incorrect citation formats:
- Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([1])
- Using hyperlinked text: [link to source 1](https://example.com)
- Using footnote style: ... reef system¹
- Making up citation numbers when source_id is unknown

ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
</incorrect_citation_formats>

Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.

Now, please research the following query:

<user_query_to_research>
{user_query}
</user_query_to_research>
"""

        return citation_prompt

    @staticmethod
    async def stream_research(
        user_query: str,
        documents: Optional[List["Document"]] = None,
        on_progress: Optional[Callable] = None,
        research_mode: str = "GENERAL"
    ) -> str:
        """
        Stream the research process using GPTResearcher

        Args:
            user_query: The user's query
            documents: List of Document objects to use for research
            on_progress: Optional async callback awaited with each progress
                message emitted by the researcher
            research_mode: Research mode to use: "GENERAL", "DEEP" or "DEEPER"

        Returns:
            str: The final research report, with citation formatting normalised

        Raises:
            ValueError: If research_mode is not one of the supported modes
        """
        # Resolve the report type first so that an unsupported mode fails
        # immediately with a clear message instead of leaving the variable
        # unassigned further down.
        if research_mode == "GENERAL":
            research_report_type = ReportType.CustomReport.value
        elif research_mode == "DEEP":
            research_report_type = ReportType.ResearchReport.value
        elif research_mode == "DEEPER":
            research_report_type = ReportType.DetailedReport.value
        else:
            raise ValueError(
                f"Unsupported research_mode: {research_mode!r}. "
                "Expected one of 'GENERAL', 'DEEP', 'DEEPER'."
            )

        # Create a custom websocket-like object to capture streaming output
        class StreamingWebsocket:
            async def send_json(self, data):
                if on_progress:
                    try:
                        # Filter out excessive logging of the prompt
                        if data.get("type") == "logs":
                            output = data.get("output", "")
                            # Check if this is a verbose prompt log
                            if "You are a research assistant tasked with analyzing documents" in output and len(output) > 500:
                                # Replace with a shorter message
                                data["output"] = f"Processing research for query: {user_query}"

                        result = await on_progress(data)
                        return result
                    except Exception as e:
                        # A failing progress callback must not abort the research
                        print(f"Error in on_progress callback: {e}")
                return None

        streaming_websocket = StreamingWebsocket()

        custom_prompt_for_ieee_citations = await ResearchService.create_custom_prompt(user_query)

        # Initialize GPTResearcher with the streaming websocket
        researcher = GPTResearcher(
            query=custom_prompt_for_ieee_citations,
            report_type=research_report_type,
            report_format="IEEE",
            report_source=ReportSource.LangChainDocuments.value,
            tone=Tone.Formal,
            documents=documents,
            verbose=True,
            websocket=streaming_websocket
        )

        # Conduct research
        await researcher.conduct_research()

        # Generate report with streaming
        report = await researcher.write_report()

        # Fix citation format
        report = ResearchService.fix_citation_format(report)

        return report

    @staticmethod
    def fix_citation_format(text: str) -> str:
        """
        Fix any incorrectly formatted citations in the text.

        Args:
            text: The text to fix

        Returns:
            str: The text with fixed citations (falsy input is returned as is)
        """
        if not text:
            return text

        # More specific pattern to match only numeric citations in markdown-style links
        # This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
        pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'

        # Replace with just [X] where X is the number
        text = re.sub(pattern, r'[\1]', text)

        # Also match other incorrect formats like ([1]) and convert to [1]
        # Only match if the content inside brackets is a number
        text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)

        return text
|
||||
99
surfsense_backend/app/utils/streaming_service.py
Normal file
99
surfsense_backend/app/utils/streaming_service.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
import json
|
||||
from typing import List, Dict, Any, Generator
|
||||
|
||||
class StreamingService:
    """
    Accumulates terminal messages, sources and answer content as message
    annotations and serialises them into prefixed JSON lines
    """

    def __init__(self):
        # Next id to hand out to a terminal message
        self.terminal_idx = 1
        # Fixed slots: [0] terminal info, [1] sources, [2] answer
        self.message_annotations = [
            {"type": kind, "content": []}
            for kind in ("TERMINAL_INFO", "SOURCES", "ANSWER")
        ]

    def add_terminal_message(self, text: str, message_type: str = "info") -> str:
        """
        Append a terminal message and return the serialised annotations

        Args:
            text: The message text
            message_type: The message type (info, success, error)

        Returns:
            str: The formatted response string
        """
        entry = {
            "id": self.terminal_idx,
            "text": text,
            "type": message_type,
        }
        terminal_slot = self.message_annotations[0]
        terminal_slot["content"].append(entry)
        self.terminal_idx += 1
        return self._format_annotations()

    def update_sources(self, sources: List[Dict[str, Any]]) -> str:
        """
        Replace the sources content and return the serialised annotations

        Args:
            sources: List of source objects

        Returns:
            str: The formatted response string
        """
        sources_slot = self.message_annotations[1]
        sources_slot["content"] = sources
        return self._format_annotations()

    def update_answer(self, answer_content: List[str]) -> str:
        """
        Replace the answer annotation and return the serialised annotations

        Args:
            answer_content: The answer content as a list of strings

        Returns:
            str: The formatted response string
        """
        # The whole slot is swapped for a fresh dictionary
        self.message_annotations[2] = dict(type="ANSWER", content=answer_content)
        return self._format_annotations()

    def _format_annotations(self) -> str:
        """
        Serialise the annotations as a single "8:"-prefixed JSON line

        Returns:
            str: The formatted annotations string
        """
        payload = json.dumps(self.message_annotations)
        return "8:" + payload + "\n"

    def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
        """
        Build the "d:"-prefixed completion line with token usage

        Args:
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens

        Returns:
            str: The formatted completion string
        """
        usage = {
            "promptTokens": prompt_tokens,
            "completionTokens": completion_tokens,
            "totalTokens": prompt_tokens + completion_tokens,
        }
        payload = json.dumps({"finishReason": "stop", "usage": usage})
        return "d:" + payload + "\n"
|
||||
4
surfsense_backend/main.py
Normal file
4
surfsense_backend/main.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
import uvicorn
|
||||
|
||||
if __name__ == "__main__":
    # Start uvicorn only when this file is executed as a script; the
    # application is referenced by its "module:attribute" import string.
    uvicorn.run(
        "app.app:app",
        host="0.0.0.0",
        log_level="info",
    )
|
||||
27
surfsense_backend/pyproject.toml
Normal file
27
surfsense_backend/pyproject.toml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
[project]
|
||||
name = "surf-new-backend"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"asyncpg>=0.30.0",
|
||||
"chonkie[all]>=0.4.1",
|
||||
"fastapi>=0.115.8",
|
||||
"fastapi-users[oauth,sqlalchemy]>=14.0.1",
|
||||
"firecrawl-py>=1.12.0",
|
||||
"gpt-researcher>=0.12.12",
|
||||
"langchain-community>=0.3.17",
|
||||
"langchain-unstructured>=0.1.6",
|
||||
"litellm>=1.61.4",
|
||||
"markdownify>=0.14.1",
|
||||
"notion-client>=2.3.0",
|
||||
"pgvector>=0.3.6",
|
||||
"playwright>=1.50.0",
|
||||
"rerankers[flashrank]>=0.7.1",
|
||||
"slack-sdk>=3.34.0",
|
||||
"tavily-python>=0.3.2",
|
||||
"unstructured-client>=0.30.0",
|
||||
"uvicorn[standard]>=0.34.0",
|
||||
"validators>=0.34.0",
|
||||
]
|
||||
3271
surfsense_backend/uv.lock
generated
Normal file
3271
surfsense_backend/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue