Mirror of https://github.com/MODSetter/SurfSense.git
feat: SurfSense v0.0.6 init

parent 18fc19e8d9
commit da23012970
58 changed files with 8284 additions and 2076 deletions
12  surfsense_backend/app/utils/check_ownership.py  Normal file

@@ -0,0 +1,12 @@
from fastapi import HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import User

# Helper function to check user ownership
async def check_ownership(session: AsyncSession, model, item_id: int, user: User):
    item = await session.execute(select(model).filter(model.id == item_id, model.user_id == user.id))
    item = item.scalars().first()
    if not item:
        raise HTTPException(status_code=404, detail="Item not found or you don't have permission to access it")
    return item
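For reference, a minimal sketch of how this helper might be wired into a route; the route path, the Document model, and the session/user dependencies are assumptions for illustration, not part of this commit:

# Hypothetical usage sketch (route, model, and dependency names are assumptions).
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Document, User, get_async_session  # assumed imports
from app.users import current_active_user             # assumed import
from app.utils.check_ownership import check_ownership

router = APIRouter()

@router.get("/documents/{document_id}")
async def read_document(
    document_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    # Raises HTTPException(404) unless the item exists and belongs to this user
    return await check_ownership(session, Document, document_id, user)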
385  surfsense_backend/app/utils/connector_service.py  Normal file

@@ -0,0 +1,385 @@
import json
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.retriver.chunks_hybrid_search import ChucksHybridSearchRetriever
from app.db import SearchSourceConnector, SearchSourceConnectorType
from tavily import TavilyClient


class ConnectorService:
    def __init__(self, session: AsyncSession):
        self.session = session
        self.retriever = ChucksHybridSearchRetriever(session)
        self.source_id_counter = 1

    async def search_crawled_urls(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for crawled URLs and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        crawled_urls_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="CRAWLED_URL"
        )

        # Map crawled_urls_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(crawled_urls_chunks):
            # Fix for UI
            crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry
            source = {
                "id": self.source_id_counter,
                "title": document.get('title', 'Untitled Document'),
                "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                "url": metadata.get('url', '')
            }

            self.source_id_counter += 1

            # Use a unique identifier for tracking unique sources
            source_key = source.get("url") or source.get("title")
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 1,
            "name": "Crawled URLs",
            "type": "CRAWLED_URL",
            "sources": sources_list,
        }

        return result_object, crawled_urls_chunks

    async def search_files(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for files and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        files_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="FILE"
        )

        # Map files_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(files_chunks):
            # Fix for UI
            files_chunks[i]['document']['id'] = self.source_id_counter
            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry
            source = {
                "id": self.source_id_counter,
                "title": document.get('title', 'Untitled Document'),
                "description": metadata.get('og:description', metadata.get('ogDescription', chunk.get('content', '')[:100])),
                "url": metadata.get('url', '')
            }

            self.source_id_counter += 1

            # Use a unique identifier for tracking unique sources
            source_key = source.get("url") or source.get("title")
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 2,
            "name": "Files",
            "type": "FILE",
            "sources": sources_list,
        }

        return result_object, files_chunks

    async def get_connector_by_type(self, user_id: int, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
        """
        Get a connector by type for a specific user

        Args:
            user_id: The user's ID
            connector_type: The connector type to retrieve

        Returns:
            Optional[SearchSourceConnector]: The connector if found, None otherwise
        """
        result = await self.session.execute(
            select(SearchSourceConnector)
            .filter(
                SearchSourceConnector.user_id == user_id,
                SearchSourceConnector.connector_type == connector_type
            )
        )
        return result.scalars().first()

    async def search_tavily(self, user_query: str, user_id: int, top_k: int = 20) -> tuple:
        """
        Search using Tavily API and return both the source information and documents

        Args:
            user_query: The user's query
            user_id: The user's ID
            top_k: Maximum number of results to return

        Returns:
            tuple: (sources_info, documents)
        """
        # Get Tavily connector configuration
        tavily_connector = await self.get_connector_by_type(user_id, SearchSourceConnectorType.TAVILY_API)

        if not tavily_connector:
            # Return empty results if no Tavily connector is configured
            return {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": [],
            }, []

        # Initialize Tavily client with API key from connector config
        tavily_api_key = tavily_connector.config.get("TAVILY_API_KEY")
        tavily_client = TavilyClient(api_key=tavily_api_key)

        # Perform search with Tavily
        try:
            response = tavily_client.search(
                query=user_query,
                max_results=top_k,
                search_depth="advanced"  # Use advanced search for better results
            )

            # Extract results from Tavily response
            tavily_results = response.get("results", [])

            # Map Tavily results to the required format
            sources_list = []
            documents = []

            # Start IDs from 100 to avoid conflicts with other connectors (note: currently unused)
            base_id = 100

            for i, result in enumerate(tavily_results):
                # Create a source entry
                source = {
                    "id": self.source_id_counter,
                    "title": result.get("title", "Tavily Result"),
                    "description": result.get("content", "")[:100],
                    "url": result.get("url", "")
                }
                sources_list.append(source)

                # Create a document entry
                document = {
                    "chunk_id": f"tavily_chunk_{i}",
                    "content": result.get("content", ""),
                    "score": result.get("score", 0.0),
                    "document": {
                        "id": self.source_id_counter,
                        "title": result.get("title", "Tavily Result"),
                        "document_type": "TAVILY_API",
                        "metadata": {
                            "url": result.get("url", ""),
                            "published_date": result.get("published_date", ""),
                            "source": "TAVILY_API"
                        }
                    }
                }
                documents.append(document)
                self.source_id_counter += 1

            # Create result object
            result_object = {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": sources_list,
            }

            return result_object, documents

        except Exception as e:
            # Log the error and return empty results
            print(f"Error searching with Tavily: {str(e)}")
            return {
                "id": 3,
                "name": "Tavily Search",
                "type": "TAVILY_API",
                "sources": [],
            }, []

    async def search_slack(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for Slack messages and return both the source information and langchain documents

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        slack_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="SLACK_CONNECTOR"
        )

        # Map slack_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(slack_chunks):
            # Fix for UI
            slack_chunks[i]['document']['id'] = self.source_id_counter
            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry with Slack-specific metadata
            channel_name = metadata.get('channel_name', 'Unknown Channel')
            channel_id = metadata.get('channel_id', '')
            message_date = metadata.get('start_date', '')

            # Create a more descriptive title for Slack messages
            title = f"Slack: {channel_name}"
            if message_date:
                title += f" ({message_date})"

            # Create a more descriptive description for Slack messages
            description = chunk.get('content', '')[:100]
            if len(description) == 100:
                description += "..."

            # For URL, we can use a placeholder or construct a URL to the Slack channel if available
            url = ""
            if channel_id:
                url = f"https://slack.com/app_redirect?channel={channel_id}"

            source = {
                "id": self.source_id_counter,
                "title": title,
                "description": description,
                "url": url,
            }

            self.source_id_counter += 1

            # Use channel_id and chunk_id as a unique identifier for tracking unique sources
            source_key = f"{channel_id}_{chunk.get('chunk_id', i)}"
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 4,
            "name": "Slack",
            "type": "SLACK_CONNECTOR",
            "sources": sources_list,
        }

        return result_object, slack_chunks

    async def search_notion(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
        """
        Search for Notion pages and return both the source information and langchain documents

        Args:
            user_query: The user's query
            user_id: The user's ID
            search_space_id: The search space ID to search in
            top_k: Maximum number of results to return

        Returns:
            tuple: (sources_info, langchain_documents)
        """
        notion_chunks = await self.retriever.hybrid_search(
            query_text=user_query,
            top_k=top_k,
            user_id=user_id,
            search_space_id=search_space_id,
            document_type="NOTION_CONNECTOR"
        )

        # Map notion_chunks to the required format
        mapped_sources = {}
        for i, chunk in enumerate(notion_chunks):
            # Fix for UI
            notion_chunks[i]['document']['id'] = self.source_id_counter

            # Extract document metadata
            document = chunk.get('document', {})
            metadata = document.get('metadata', {})

            # Create a mapped source entry with Notion-specific metadata
            page_title = metadata.get('page_title', 'Untitled Page')
            page_id = metadata.get('page_id', '')
            indexed_at = metadata.get('indexed_at', '')

            # Create a more descriptive title for Notion pages
            title = f"Notion: {page_title}"
            if indexed_at:
                title += f" (indexed: {indexed_at})"

            # Create a more descriptive description for Notion pages
            description = chunk.get('content', '')[:100]
            if len(description) == 100:
                description += "..."

            # For URL, construct a URL to the Notion page if the page ID is available
            url = ""
            if page_id:
                # Notion page URLs follow this format
                url = f"https://notion.so/{page_id.replace('-', '')}"

            source = {
                "id": self.source_id_counter,
                "title": title,
                "description": description,
                "url": url,
            }

            self.source_id_counter += 1

            # Use page_id and chunk_id as a unique identifier for tracking unique sources
            source_key = f"{page_id}_{chunk.get('chunk_id', i)}"
            if source_key and source_key not in mapped_sources:
                mapped_sources[source_key] = source

        # Convert to list of sources
        sources_list = list(mapped_sources.values())

        # Create result object
        result_object = {
            "id": 5,
            "name": "Notion",
            "type": "NOTION_CONNECTOR",
            "sources": sources_list,
        }

        return result_object, notion_chunks
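For reference, a minimal sketch of how a caller might combine these connector searches into one result set; the function name and session setup are assumptions for illustration, not part of this commit:

# Hypothetical usage sketch (names are assumptions).
from app.utils.connector_service import ConnectorService

async def gather_sources(session, user_query: str, user_id: int, search_space_id: int):
    connector_service = ConnectorService(session)

    # Each call returns (sources_info, documents); source IDs stay unique
    # across calls because the service shares one source_id_counter.
    url_sources, url_chunks = await connector_service.search_crawled_urls(
        user_query, user_id, search_space_id, top_k=20
    )
    file_sources, file_chunks = await connector_service.search_files(
        user_query, user_id, search_space_id, top_k=20
    )
    tavily_sources, tavily_docs = await connector_service.search_tavily(
        user_query, user_id, top_k=20
    )

    all_sources = [url_sources, file_sources, tavily_sources]
    all_chunks = url_chunks + file_chunks + tavily_docs
    return all_sources, all_chunks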
136  surfsense_backend/app/utils/document_converters.py  Normal file

@@ -0,0 +1,136 @@
async def convert_element_to_markdown(element) -> str:
    """
    Convert an Unstructured element to markdown format based on its category.

    Args:
        element: The Unstructured API element object

    Returns:
        str: Markdown formatted string
    """
    element_category = element.metadata["category"]
    content = element.page_content

    if not content:
        return ""

    markdown_mapping = {
        "Formula": lambda x: f"```math\n{x}\n```",
        "FigureCaption": lambda x: f"*Figure: {x}*",
        "NarrativeText": lambda x: f"{x}\n\n",
        "ListItem": lambda x: f"- {x}\n",
        "Title": lambda x: f"# {x}\n\n",
        "Address": lambda x: f"> {x}\n\n",
        "EmailAddress": lambda x: f"`{x}`",
        "Image": lambda x: f"",
        "PageBreak": lambda x: "\n---\n",
        "Table": lambda x: f"```html\n{element.metadata['text_as_html']}\n```",
        "Header": lambda x: f"## {x}\n\n",
        "Footer": lambda x: f"*{x}*\n\n",
        "CodeSnippet": lambda x: f"```\n{x}\n```",
        "PageNumber": lambda x: f"*Page {x}*\n\n",
        "UncategorizedText": lambda x: f"{x}\n\n"
    }

    converter = markdown_mapping.get(element_category, lambda x: x)
    return converter(content)


async def convert_document_to_markdown(elements):
    """
    Convert all document elements to markdown.

    Args:
        elements: List of Unstructured API elements

    Returns:
        str: Complete markdown document
    """
    markdown_parts = []

    for element in elements:
        markdown_text = await convert_element_to_markdown(element)
        if markdown_text:
            markdown_parts.append(markdown_text)

    return "".join(markdown_parts)


def convert_chunks_to_langchain_documents(chunks):
    """
    Convert chunks from hybrid search results to LangChain Document objects.

    Args:
        chunks: List of chunk dictionaries from hybrid search results

    Returns:
        List of LangChain Document objects
    """
    try:
        from langchain_core.documents import Document as LangChainDocument
    except ImportError:
        raise ImportError(
            "LangChain is not installed. Please install it with `pip install langchain langchain-core`"
        )

    langchain_docs = []

    for chunk in chunks:
        # Extract content from the chunk
        content = chunk.get("content", "")

        # Create metadata dictionary
        metadata = {
            "chunk_id": chunk.get("chunk_id"),
            "score": chunk.get("score"),
            "rank": chunk.get("rank") if "rank" in chunk else None,
        }

        # Add document information to metadata
        if "document" in chunk:
            doc = chunk["document"]
            metadata.update({
                "document_id": doc.get("id"),
                "document_title": doc.get("title"),
                "document_type": doc.get("document_type"),
            })

            # Add document metadata if available
            if "metadata" in doc:
                # Prefix document metadata keys to avoid conflicts
                doc_metadata = {f"doc_meta_{k}": v for k, v in doc.get("metadata", {}).items()}
                metadata.update(doc_metadata)

                # Add source URL if available in metadata
                if "url" in doc.get("metadata", {}):
                    metadata["source"] = doc["metadata"]["url"]
                elif "sourceURL" in doc.get("metadata", {}):
                    metadata["source"] = doc["metadata"]["sourceURL"]

        # Ensure source_id is set for citation purposes
        # Use document_id as the source_id if available
        if "document_id" in metadata:
            metadata["source_id"] = metadata["document_id"]

        # Update content for citation mode - format as XML with explicit source_id
        new_content = f"""
<document>
    <metadata>
        <source_id>{metadata.get("source_id", metadata.get("document_id", "unknown"))}</source_id>
    </metadata>
    <content>
        <text>
{content}
        </text>
    </content>
</document>
"""

        # Create LangChain Document
        langchain_doc = LangChainDocument(
            page_content=new_content,
            metadata=metadata
        )

        langchain_docs.append(langchain_doc)

    return langchain_docs
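For reference, a minimal sketch of calling these converters together; the element and chunk inputs are assumed to come from the Unstructured API and the hybrid search shown above:

# Hypothetical usage sketch.
from app.utils.document_converters import (
    convert_document_to_markdown,
    convert_chunks_to_langchain_documents,
)

async def build_research_inputs(elements, chunks):
    # elements: Unstructured elements exposing .metadata["category"] and .page_content
    markdown = await convert_document_to_markdown(elements)

    # chunks: hybrid-search dicts; each becomes a LangChain Document whose
    # page_content is wrapped in <document>/<source_id> XML for citation mode
    langchain_docs = convert_chunks_to_langchain_documents(chunks)
    return markdown, langchain_docs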
95  surfsense_backend/app/utils/reranker_service.py  Normal file

@@ -0,0 +1,95 @@
import logging
from typing import List, Dict, Any, Optional
from rerankers import Document as RerankerDocument

class RerankerService:
    """
    Service for reranking documents using a configured reranker
    """

    def __init__(self, reranker_instance=None):
        """
        Initialize the reranker service

        Args:
            reranker_instance: The reranker instance to use for reranking
        """
        self.reranker_instance = reranker_instance

    def rerank_documents(self, query_text: str, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Rerank documents using the configured reranker

        Args:
            query_text: The query text to use for reranking
            documents: List of document dictionaries to rerank

        Returns:
            List[Dict[str, Any]]: Reranked documents
        """
        if not self.reranker_instance or not documents:
            return documents

        try:
            # Create Document objects for the rerankers library
            reranker_docs = []
            for i, doc in enumerate(documents):
                chunk_id = doc.get("chunk_id", f"chunk_{i}")
                content = doc.get("content", "")
                score = doc.get("score", 0.0)
                document_info = doc.get("document", {})

                reranker_docs.append(
                    RerankerDocument(
                        text=content,
                        doc_id=chunk_id,
                        metadata={
                            'document_id': document_info.get("id", ""),
                            'document_title': document_info.get("title", ""),
                            'document_type': document_info.get("document_type", ""),
                            'rrf_score': score
                        }
                    )
                )

            # Rerank using the configured reranker
            reranking_results = self.reranker_instance.rank(
                query=query_text,
                docs=reranker_docs
            )

            # Process the results from the reranker
            # Convert to serializable dictionaries
            serialized_results = []
            for result in reranking_results.results:
                # Find the original document by id
                original_doc = next((doc for doc in documents if doc.get("chunk_id") == result.document.doc_id), None)
                if original_doc:
                    # Create a new document with the reranked score
                    reranked_doc = original_doc.copy()
                    reranked_doc["score"] = float(result.score)
                    reranked_doc["rank"] = result.rank
                    serialized_results.append(reranked_doc)

            return serialized_results

        except Exception as e:
            # Log the error
            logging.error(f"Error during reranking: {str(e)}")
            # Fall back to original documents without reranking
            return documents

    @staticmethod
    def get_reranker_instance(config=None) -> Optional['RerankerService']:
        """
        Get a reranker service instance based on configuration

        Args:
            config: Configuration object that may contain a reranker_instance

        Returns:
            Optional[RerankerService]: A reranker service instance or None
        """
        if config and hasattr(config, 'reranker_instance') and config.reranker_instance:
            return RerankerService(config.reranker_instance)
        return None
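For reference, a minimal sketch of reranking hybrid-search chunks with this service; the FlashRank model choice and the sample chunks are assumptions for illustration, not part of this commit:

# Hypothetical usage sketch, assuming a FlashRank reranker from the
# `rerankers` package; any reranker that library supports should work.
from rerankers import Reranker
from app.utils.reranker_service import RerankerService

reranker = Reranker("flashrank")  # assumed model choice
service = RerankerService(reranker)

chunks = [
    {"chunk_id": "c1", "content": "PostgreSQL supports full-text search.", "score": 0.41, "document": {"id": 1}},
    {"chunk_id": "c2", "content": "Redis is an in-memory key-value store.", "score": 0.39, "document": {"id": 2}},
]

# Returns the same dicts with "score" replaced by the reranker score and a
# "rank" field added; falls back to the original input on any error.
reranked = service.rerank_documents("postgres text search", chunks)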
211  surfsense_backend/app/utils/research_service.py  Normal file

@@ -0,0 +1,211 @@
import asyncio
import re
from typing import List, Dict, Any, AsyncGenerator, Callable, Optional
from langchain.schema import Document
from gpt_researcher.agent import GPTResearcher
from gpt_researcher.utils.enum import ReportType, Tone, ReportSource
from dotenv import load_dotenv

load_dotenv()

class ResearchService:
    @staticmethod
    async def create_custom_prompt(user_query: str) -> str:
        citation_prompt = f"""
You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.

<instructions>
1. Carefully analyze all provided documents in the <document> sections.
2. Extract relevant information that addresses the user's query.
3. Synthesize a comprehensive, well-structured answer using information from these documents.
4. For EVERY piece of information you include from the documents, add an IEEE-style citation in square brackets [X] where X is the source_id from the document's metadata.
5. Make sure ALL factual statements from the documents have proper citations.
6. If multiple documents support the same point, include all relevant citations [X], [Y].
7. Present information in a logical, coherent flow.
8. Use your own words to connect ideas, but cite ALL information from the documents.
9. If documents contain conflicting information, acknowledge this and present both perspectives with appropriate citations.
10. Do not make up or include information not found in the provided documents.
11. CRITICAL: You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
12. CRITICAL: Every citation MUST be in the IEEE format [X] where X is the exact source_id value.
13. CRITICAL: Never renumber or reorder citations - always use the original source_id values.
14. CRITICAL: Do not return citations as clickable links.
15. CRITICAL: Never format citations as markdown links like "([1](https://example.com))". Always use plain square brackets only.
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
</instructions>

<format>
- Write in clear, professional language suitable for academic or technical audiences
- Organize your response with appropriate paragraphs, headings, and structure
- Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [X], [Y], [Z]
- No need to return references section. Just citation numbers in answer.
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
</format>

<input_example>
<document>
    <metadata>
        <source_id>1</source_id>
    </metadata>
    <content>
        <text>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands.
        </text>
    </content>
</document>

<document>
    <metadata>
        <source_id>13</source_id>
    </metadata>
    <content>
        <text>
Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020.
        </text>
    </content>
</document>

<document>
    <metadata>
        <source_id>21</source_id>
    </metadata>
    <content>
        <text>
The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral.
        </text>
    </content>
</document>
</input_example>

<output_example>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13].
</output_example>

<incorrect_citation_formats>
DO NOT use any of these incorrect citation formats:
- Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([1])
- Using hyperlinked text: [link to source 1](https://example.com)
- Using footnote style: ... reef system¹
- Making up citation numbers when source_id is unknown

ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
</incorrect_citation_formats>

Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.

Now, please research the following query:

<user_query_to_research>
{user_query}
</user_query_to_research>
"""

        return citation_prompt

    @staticmethod
    async def stream_research(
        user_query: str,
        documents: List[Document] = None,
        on_progress: Optional[Callable] = None,
        research_mode: str = "GENERAL"
    ) -> str:
        """
        Stream the research process using GPTResearcher

        Args:
            user_query: The user's query
            documents: List of Document objects to use for research
            on_progress: Optional callback for progress updates
            research_mode: Research mode to use

        Returns:
            str: The final research report
        """
        # Create a custom websocket-like object to capture streaming output
        class StreamingWebsocket:
            async def send_json(self, data):
                if on_progress:
                    try:
                        # Filter out excessive logging of the prompt
                        if data.get("type") == "logs":
                            output = data.get("output", "")
                            # Check if this is a verbose prompt log
                            if "You are a research assistant tasked with analyzing documents" in output and len(output) > 500:
                                # Replace with a shorter message
                                data["output"] = f"Processing research for query: {user_query}"

                        result = await on_progress(data)
                        return result
                    except Exception as e:
                        print(f"Error in on_progress callback: {e}")
                return None

        streaming_websocket = StreamingWebsocket()

        custom_prompt_for_ieee_citations = await ResearchService.create_custom_prompt(user_query)

        if research_mode == "GENERAL":
            research_report_type = ReportType.CustomReport.value
        elif research_mode == "DEEP":
            research_report_type = ReportType.ResearchReport.value
        elif research_mode == "DEEPER":
            research_report_type = ReportType.DetailedReport.value
        # elif research_mode == "DEEPEST":
        #     research_report_type = ReportType.DeepResearch.value

        # Initialize GPTResearcher with the streaming websocket
        researcher = GPTResearcher(
            query=custom_prompt_for_ieee_citations,
            report_type=research_report_type,
            report_format="IEEE",
            report_source=ReportSource.LangChainDocuments.value,
            tone=Tone.Formal,
            documents=documents,
            verbose=True,
            websocket=streaming_websocket
        )

        # Conduct research
        await researcher.conduct_research()

        # Generate report with streaming
        report = await researcher.write_report()

        # Fix citation format
        report = ResearchService.fix_citation_format(report)

        return report

    @staticmethod
    def fix_citation_format(text: str) -> str:
        """
        Fix any incorrectly formatted citations in the text.

        Args:
            text: The text to fix

        Returns:
            str: The text with fixed citations
        """
        if not text:
            return text

        # More specific pattern to match only numeric citations in markdown-style links
        # This matches patterns like ([1](https://github.com/...)) but not general links like ([Click here](https://...))
        pattern = r'\(\[(\d+)\]\((https?://[^\)]+)\)\)'

        # Replace with just [X] where X is the number
        text = re.sub(pattern, r'[\1]', text)

        # Also match other incorrect formats like ([1]) and convert to [1]
        # Only match if the content inside brackets is a number
        text = re.sub(r'\(\[(\d+)\]\)', r'[\1]', text)

        return text
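For reference, a minimal sketch of driving a research run with this service; the query, documents, and progress callback are assumptions for illustration, not part of this commit:

# Hypothetical usage sketch.
from langchain.schema import Document
from app.utils.research_service import ResearchService

async def run_research(chunks):
    # Wrap retrieved chunks as LangChain Documents with citation source_ids
    docs = [Document(page_content=c["content"], metadata={"source_id": i})
            for i, c in enumerate(chunks, start=1)]

    async def on_progress(data):
        # e.g. forward GPTResearcher log events to the client
        print(data.get("type"), str(data.get("output", ""))[:80])

    report = await ResearchService.stream_research(
        user_query="What is SurfSense?",
        documents=docs,
        on_progress=on_progress,
        research_mode="GENERAL",
    )
    return report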
99  surfsense_backend/app/utils/streaming_service.py  Normal file

@@ -0,0 +1,99 @@
import json
from typing import List, Dict, Any, Generator

class StreamingService:
    def __init__(self):
        self.terminal_idx = 1
        self.message_annotations = [
            {
                "type": "TERMINAL_INFO",
                "content": []
            },
            {
                "type": "SOURCES",
                "content": []
            },
            {
                "type": "ANSWER",
                "content": []
            }
        ]

    def add_terminal_message(self, text: str, message_type: str = "info") -> str:
        """
        Add a terminal message to the annotations and return the formatted response

        Args:
            text: The message text
            message_type: The message type (info, success, error)

        Returns:
            str: The formatted response string
        """
        self.message_annotations[0]["content"].append({
            "id": self.terminal_idx,
            "text": text,
            "type": message_type
        })
        self.terminal_idx += 1
        return self._format_annotations()

    def update_sources(self, sources: List[Dict[str, Any]]) -> str:
        """
        Update the sources in the annotations and return the formatted response

        Args:
            sources: List of source objects

        Returns:
            str: The formatted response string
        """
        self.message_annotations[1]["content"] = sources
        return self._format_annotations()

    def update_answer(self, answer_content: List[str]) -> str:
        """
        Update the answer in the annotations and return the formatted response

        Args:
            answer_content: The answer content as a list of strings

        Returns:
            str: The formatted response string
        """
        self.message_annotations[2] = {
            "type": "ANSWER",
            "content": answer_content
        }
        return self._format_annotations()

    def _format_annotations(self) -> str:
        """
        Format the annotations as a string

        Returns:
            str: The formatted annotations string
        """
        return f'8:{json.dumps(self.message_annotations)}\n'

    def format_completion(self, prompt_tokens: int = 156, completion_tokens: int = 204) -> str:
        """
        Format a completion message

        Args:
            prompt_tokens: Number of prompt tokens
            completion_tokens: Number of completion tokens

        Returns:
            str: The formatted completion string
        """
        total_tokens = prompt_tokens + completion_tokens
        completion_data = {
            "finishReason": "stop",
            "usage": {
                "promptTokens": prompt_tokens,
                "completionTokens": completion_tokens,
                "totalTokens": total_tokens
            }
        }
        return f'd:{json.dumps(completion_data)}\n'
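For reference, the '8:' and 'd:' prefixes appear to follow the Vercel AI SDK data-stream framing for message annotations and completion data. A minimal sketch of emitting frames from a streaming endpoint; the generator wiring is an assumption for illustration, not part of this commit:

# Hypothetical usage sketch.
from app.utils.streaming_service import StreamingService

def stream_example():
    streaming_service = StreamingService()

    # Each call returns the full annotation frame to yield to the client.
    yield streaming_service.add_terminal_message("Searching connectors...", "info")
    yield streaming_service.update_sources([{"id": 1, "title": "Example", "url": ""}])
    yield streaming_service.update_answer(["The answer so far..."])
    yield streaming_service.format_completion(prompt_tokens=156, completion_tokens=204)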