mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-25 19:15:18 +02:00
Merge branch 'main' of https://github.com/MODSetter/SurfSense into add-github-connector
This commit is contained in:
commit
a69bbb32f7
13 changed files with 565 additions and 81 deletions
0
surfsense_backend/app/agents/__init__.py
Normal file
0
surfsense_backend/app/agents/__init__.py
Normal file
0
surfsense_backend/app/agents/researcher/__init__.py
Normal file
0
surfsense_backend/app/agents/researcher/__init__.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
"""New LangGraph Agent.
|
||||
|
||||
This package re-exports the compiled ``graph`` object defined in ``.graph``.
|
||||
"""
|
||||
|
||||
from .graph import graph
|
||||
|
||||
__all__ = ["graph"]
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
"""Define the configurable parameters for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional, List
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Configuration:
|
||||
"""The configuration for the agent."""
|
||||
|
||||
# Input parameters provided at invocation
|
||||
sub_section_title: str
|
||||
sub_questions: List[str]
|
||||
connectors_to_search: List[str]
|
||||
user_id: str
|
||||
search_space_id: int
|
||||
top_k: int = 20 # Default top_k value
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: Optional[RunnableConfig] = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
_fields = {f.name for f in fields(cls) if f.init}
|
||||
return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
from langgraph.graph import StateGraph
|
||||
from .state import State
|
||||
from .nodes import fetch_relevant_documents, write_sub_section
|
||||
from .configuration import Configuration
|
||||
|
||||
# Build the two-step pipeline: fetch documents, then write the sub-section.
workflow = StateGraph(State, config_schema=Configuration)

# Register each node under the name the edges below refer to.
for _node_name, _node_fn in (
    ("fetch_relevant_documents", fetch_relevant_documents),
    ("write_sub_section", write_sub_section),
):
    workflow.add_node(_node_name, _node_fn)

# Linear flow: start -> fetch_relevant_documents -> write_sub_section -> end.
for _source, _target in (
    ("__start__", "fetch_relevant_documents"),
    ("fetch_relevant_documents", "write_sub_section"),
    ("write_sub_section", "__end__"),
):
    workflow.add_edge(_source, _target)

# Compile the workflow into an executable graph and label it.
graph = workflow.compile()
graph.name = "Sub Section Writer"  # This defines the custom name in LangSmith
|
||||
|
|
@ -0,0 +1,244 @@
|
|||
from .configuration import Configuration
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from .state import State
|
||||
from typing import Any, Dict
|
||||
from app.utils.connector_service import ConnectorService
|
||||
from app.utils.reranker_service import RerankerService
|
||||
from app.config import config as app_config
|
||||
from .prompts import citation_system_prompt
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
# Connector type -> (ConnectorService method name, whether the method takes
# a ``search_space_id`` argument). Tavily is the only search without one.
_CONNECTOR_SEARCH_METHODS = {
    "YOUTUBE_VIDEO": ("search_youtube", True),
    "EXTENSION": ("search_extension", True),
    "CRAWLED_URL": ("search_crawled_urls", True),
    "FILE": ("search_files", True),
    "TAVILY_API": ("search_tavily", False),
    "SLACK_CONNECTOR": ("search_slack", True),
    "NOTION_CONNECTOR": ("search_notion", True),
}


def _deduplicate_documents(documents: list) -> list:
    """Return *documents* without repeats, keeping the first occurrence.

    A document is a repeat when its truthy ``chunk_id`` or its ``content``
    was already seen. Content is compared directly rather than through
    ``hash()`` so that a hash collision can never drop a distinct document.
    """
    seen_chunk_ids = set()
    seen_contents = set()
    unique_docs = []

    for doc in documents:
        chunk_id = doc.get("chunk_id")
        content = doc.get("content", "")

        if (chunk_id and chunk_id in seen_chunk_ids) or content in seen_contents:
            continue

        if chunk_id:
            seen_chunk_ids.add(chunk_id)
        seen_contents.add(content)
        unique_docs.append(doc)

    return unique_docs


def _to_reranker_input(index: int, doc: Dict[str, Any]) -> Dict[str, Any]:
    """Copy the fields of *doc* that are passed to the reranker, filling defaults."""
    # ``or {}`` also covers a "document" key that is present but None.
    document = doc.get("document") or {}
    return {
        "chunk_id": doc.get("chunk_id", f"chunk_{index}"),
        "content": doc.get("content", ""),
        "score": doc.get("score", 0.0),
        "document": {
            "id": document.get("id", ""),
            "title": document.get("title", ""),
            "document_type": document.get("document_type", ""),
            "metadata": document.get("metadata", {}),
        },
    }


async def fetch_relevant_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
    """
    Fetch relevant documents for the sub-section using specified connectors.

    For every sub-question in the configuration, each selected connector
    (YouTube, Extension, Crawled URLs, Files, Tavily API, Slack, Notion) is
    searched in turn. Connector names that are not recognised are skipped.
    The collected chunks are deduplicated and, when a reranker instance is
    available, reranked against the sub-section title and sorted by score
    in descending order.

    Args:
        state: Agent state; its ``db_session`` is handed to ConnectorService.
        config: Runnable config from which the Configuration is built.

    Returns:
        Dict containing the resulting documents in the
        "relevant_documents_fetched" key.
    """
    configuration = Configuration.from_runnable_config(config)

    # Initialize services
    connector_service = ConnectorService(state.db_session)
    reranker_service = RerankerService.get_reranker_instance(app_config)

    all_raw_documents = []  # Everything returned by the connectors, pre-dedup

    for sub_question in configuration.sub_questions:
        # The sub-question is used verbatim as the search query.
        for connector in configuration.connectors_to_search:
            search_spec = _CONNECTOR_SEARCH_METHODS.get(connector)
            if search_spec is None:
                continue  # Unknown connector type: nothing to search

            method_name, takes_search_space = search_spec
            search_kwargs = {
                "user_query": sub_question,
                "user_id": configuration.user_id,
                "top_k": configuration.top_k,
            }
            if takes_search_space:
                search_kwargs["search_space_id"] = configuration.search_space_id

            # Searches are awaited one at a time, matching the original flow.
            # NOTE(review): they all go through the same db session; confirm
            # it tolerates concurrent use before parallelising.
            search = getattr(connector_service, method_name)
            _, chunks = await search(**search_kwargs)
            all_raw_documents.extend(chunks)

    # Deduplicate before reranking so duplicates are not processed twice
    deduplicated_docs = _deduplicate_documents(all_raw_documents)

    reranked_docs = deduplicated_docs
    if deduplicated_docs and reranker_service:
        # Use the main sub_section_title for reranking context
        rerank_query = configuration.sub_section_title
        reranker_input_docs = [
            _to_reranker_input(i, doc) for i, doc in enumerate(deduplicated_docs)
        ]

        reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)

        # Sort by score in descending order
        reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)

    # Update state with fetched documents
    return {"relevant_documents_fetched": reranked_docs}
|
||||
|
||||
|
||||
|
||||
def _format_document(index: int, doc: Dict[str, Any]) -> str:
    """Render one fetched document as the <document> snippet used in the prompt."""
    doc_info = doc.get("document") or {}
    # Use the document ID, or index+1 when the document carries no ID.
    source_id = doc_info.get("id", f"{index + 1}")
    content = doc.get("content", "")
    return (
        "\n"
        "<document>\n"
        "<metadata>\n"
        f"<source_id>{source_id}</source_id>\n"
        "</metadata>\n"
        "<content>\n"
        f"{content}\n"
        "</content>\n"
        "</document>\n"
    )


async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
    """
    Write the sub-section using the fetched documents.

    The documents stored in ``state.relevant_documents_fetched`` are rendered
    into the XML layout described by the citation system prompt and sent to
    the LLM together with the sub-section title and the sub-questions. When
    no documents are available a fixed notice is returned and the LLM is not
    called.

    Args:
        state: Agent state holding ``relevant_documents_fetched``.
        config: Runnable config from which the Configuration is built.

    Returns:
        Dict containing the final answer in the "final_answer" key.
    """
    configuration = Configuration.from_runnable_config(config)
    documents = state.relevant_documents_fetched

    # If no documents were found, return a message indicating this
    if not documents:
        return {
            "final_answer": "No relevant documents were found to answer this question. Please try refining your search or providing more specific questions."
        }

    documents_text = "\n".join(
        _format_document(i, doc) for i, doc in enumerate(documents)
    )
    questions = "\n".join(f"- {q}" for q in configuration.sub_questions)

    # The title is interpolated here; previously the sentence ended at the
    # colon and the LLM never saw which sub-section it was writing.
    human_message_content = (
        "\n"
        f"Please write a comprehensive answer for the title: {configuration.sub_section_title}\n"
        "\n"
        "Address the following questions:\n"
        "<questions>\n"
        f"{questions}\n"
        "</questions>\n"
        "\n"
        "Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.\n"
        "<documents>\n"
        f"{documents_text}\n"
        "</documents>\n"
    )

    messages = [
        SystemMessage(content=citation_system_prompt),
        HumanMessage(content=human_message_content),
    ]

    # Fetch the LLM only once we know there is something to ask it.
    llm = app_config.fast_llm_instance
    response = await llm.ainvoke(messages)

    return {"final_answer": response.content}
|
||||
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
# System prompt instructing the LLM to answer from the supplied <document>
# blocks and to cite them as [source_id]. This is a plain string on purpose:
# it has no placeholders, so an f-string prefix would only turn any literal
# brace added later into a format field.
citation_system_prompt = """
You are a research assistant tasked with analyzing documents and providing comprehensive answers with proper citations in IEEE format.

<instructions>
1. Carefully analyze all provided documents in the <document> section's.
2. Extract relevant information that addresses the user's query.
3. Synthesize a comprehensive, well-structured answer using information from these documents.
4. For EVERY piece of information you include from the documents, add an IEEE-style citation in square brackets [X] where X is the source_id from the document's metadata.
5. Make sure ALL factual statements from the documents have proper citations.
6. If multiple documents support the same point, include all relevant citations [X], [Y].
7. Present information in a logical, coherent flow.
8. Use your own words to connect ideas, but cite ALL information from the documents.
9. If documents contain conflicting information, acknowledge this and present both perspectives with appropriate citations.
10. Do not make up or include information not found in the provided documents.
11. CRITICAL: You MUST use the exact source_id value from each document's metadata for citations. Do not create your own citation numbers.
12. CRITICAL: Every citation MUST be in the IEEE format [X] where X is the exact source_id value.
13. CRITICAL: Never renumber or reorder citations - always use the original source_id values.
14. CRITICAL: Do not return citations as clickable links.
15. CRITICAL: Never format citations as markdown links like "([1](https://example.com))". Always use plain square brackets only.
16. CRITICAL: Citations must ONLY appear as [X] or [X], [Y], [Z] format - never with parentheses, hyperlinks, or other formatting.
17. CRITICAL: Never make up citation numbers. Only use source_id values that are explicitly provided in the document metadata.
18. CRITICAL: If you are unsure about a source_id, do not include a citation rather than guessing or making one up.
</instructions>

<format>
- Write in clear, professional language suitable for academic or technical audiences
- Organize your response with appropriate paragraphs, headings, and structure
- Every fact from the documents must have an IEEE-style citation in square brackets [X] where X is the EXACT source_id from the document's metadata
- Citations should appear at the end of the sentence containing the information they support
- Multiple citations should be separated by commas: [X], [Y], [Z]
- No need to return references section. Just citation numbers in answer.
- NEVER create your own citation numbering system - use the exact source_id values from the documents.
- NEVER format citations as clickable links or as markdown links like "([1](https://example.com))". Always use plain square brackets only.
- NEVER make up citation numbers if you are unsure about the source_id. It is better to omit the citation than to guess.
</format>

<input_example>
<document>
<metadata>
<source_id>1</source_id>
</metadata>
<content>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia. It comprises over 2,900 individual reefs and 900 islands.
</content>
</document>

<document>
<metadata>
<source_id>13</source_id>
</metadata>
<content>
Climate change poses a significant threat to coral reefs worldwide. Rising ocean temperatures have led to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020.
</content>
</document>

<document>
<metadata>
<source_id>21</source_id>
</metadata>
<content>
The Great Barrier Reef was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity. It is home to over 1,500 species of fish and 400 types of coral.
</content>
</document>
</input_example>

<output_example>
The Great Barrier Reef is the world's largest coral reef system, stretching over 2,300 kilometers along the coast of Queensland, Australia [1]. It was designated a UNESCO World Heritage Site in 1981 due to its outstanding universal value and biological diversity [21]. The reef is home to over 1,500 species of fish and 400 types of coral [21]. Unfortunately, climate change poses a significant threat to coral reefs worldwide, with rising ocean temperatures leading to mass coral bleaching events in the Great Barrier Reef in 2016, 2017, and 2020 [13]. The reef system comprises over 2,900 individual reefs and 900 islands [1], making it an ecological treasure that requires protection from multiple threats [1], [13].
</output_example>

<incorrect_citation_formats>
DO NOT use any of these incorrect citation formats:
- Using parentheses and markdown links: ([1](https://github.com/MODSetter/SurfSense))
- Using parentheses around brackets: ([1])
- Using hyperlinked text: [link to source 1](https://example.com)
- Using footnote style: ... reef system¹
- Making up citation numbers when source_id is unknown

ONLY use plain square brackets [1] or multiple citations [1], [2], [3]
</incorrect_citation_formats>

Note that the citation numbers match exactly with the source_id values (1, 13, and 21) and are not renumbered sequentially. Citations follow IEEE style with square brackets and appear at the end of sentences.
"""
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
"""Define the state structures for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@dataclass
class State:
    """Mutable state shared by the agent's nodes while the graph runs.

    It carries the database session supplied by the caller together with
    the values that the nodes produce.
    See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
    for background on graph state.
    """

    # Runtime context: async database session, required at construction.
    db_session: AsyncSession

    # Node outputs; both stay None until a node fills them in.
    relevant_documents_fetched: Optional[List[Any]] = None
    final_answer: Optional[str] = None
|
||||
|
||||
|
|
@ -46,7 +46,7 @@ async def handle_chat_data(
|
|||
response = StreamingResponse(stream_connector_search_results(
|
||||
user_query,
|
||||
user.id,
|
||||
search_space_id,
|
||||
search_space_id, # Already converted to int in lines 32-37
|
||||
session,
|
||||
research_mode,
|
||||
selected_connectors
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from app.utils.document_converters import convert_chunks_to_langchain_documents
|
|||
|
||||
async def stream_connector_search_results(
|
||||
user_query: str,
|
||||
user_id: int,
|
||||
user_id: str,
|
||||
search_space_id: int,
|
||||
session: AsyncSession,
|
||||
research_mode: str,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class ConnectorService:
|
|||
self.retriever = ChucksHybridSearchRetriever(session)
|
||||
self.source_id_counter = 1
|
||||
|
||||
async def search_crawled_urls(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_crawled_urls(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
"""
|
||||
Search for crawled URLs and return both the source information and langchain documents
|
||||
|
||||
|
|
@ -28,16 +28,16 @@ class ConnectorService:
|
|||
document_type="CRAWLED_URL"
|
||||
)
|
||||
|
||||
# Map crawled_urls_chunks to the required format
|
||||
mapped_sources = {}
|
||||
# Process each chunk and create sources directly without deduplication
|
||||
sources_list = []
|
||||
for i, chunk in enumerate(crawled_urls_chunks):
|
||||
#Fix for UI
|
||||
# Fix for UI
|
||||
crawled_urls_chunks[i]['document']['id'] = self.source_id_counter
|
||||
# Extract document metadata
|
||||
document = chunk.get('document', {})
|
||||
metadata = document.get('metadata', {})
|
||||
|
||||
# Create a mapped source entry
|
||||
# Create a source entry
|
||||
source = {
|
||||
"id": self.source_id_counter,
|
||||
"title": document.get('title', 'Untitled Document'),
|
||||
|
|
@ -46,14 +46,7 @@ class ConnectorService:
|
|||
}
|
||||
|
||||
self.source_id_counter += 1
|
||||
|
||||
# Use a unique identifier for tracking unique sources
|
||||
source_key = source.get("url") or source.get("title")
|
||||
if source_key and source_key not in mapped_sources:
|
||||
mapped_sources[source_key] = source
|
||||
|
||||
# Convert to list of sources
|
||||
sources_list = list(mapped_sources.values())
|
||||
sources_list.append(source)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
|
|
@ -63,10 +56,9 @@ class ConnectorService:
|
|||
"sources": sources_list,
|
||||
}
|
||||
|
||||
|
||||
return result_object, crawled_urls_chunks
|
||||
|
||||
async def search_files(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_files(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
"""
|
||||
Search for files and return both the source information and langchain documents
|
||||
|
||||
|
|
@ -81,16 +73,16 @@ class ConnectorService:
|
|||
document_type="FILE"
|
||||
)
|
||||
|
||||
# Map crawled_urls_chunks to the required format
|
||||
mapped_sources = {}
|
||||
# Process each chunk and create sources directly without deduplication
|
||||
sources_list = []
|
||||
for i, chunk in enumerate(files_chunks):
|
||||
#Fix for UI
|
||||
# Fix for UI
|
||||
files_chunks[i]['document']['id'] = self.source_id_counter
|
||||
# Extract document metadata
|
||||
document = chunk.get('document', {})
|
||||
metadata = document.get('metadata', {})
|
||||
|
||||
# Create a mapped source entry
|
||||
# Create a source entry
|
||||
source = {
|
||||
"id": self.source_id_counter,
|
||||
"title": document.get('title', 'Untitled Document'),
|
||||
|
|
@ -99,14 +91,7 @@ class ConnectorService:
|
|||
}
|
||||
|
||||
self.source_id_counter += 1
|
||||
|
||||
# Use a unique identifier for tracking unique sources
|
||||
source_key = source.get("url") or source.get("title")
|
||||
if source_key and source_key not in mapped_sources:
|
||||
mapped_sources[source_key] = source
|
||||
|
||||
# Convert to list of sources
|
||||
sources_list = list(mapped_sources.values())
|
||||
sources_list.append(source)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
|
|
@ -118,7 +103,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, files_chunks
|
||||
|
||||
async def get_connector_by_type(self, user_id: int, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
|
||||
async def get_connector_by_type(self, user_id: str, connector_type: SearchSourceConnectorType) -> Optional[SearchSourceConnector]:
|
||||
"""
|
||||
Get a connector by type for a specific user
|
||||
|
||||
|
|
@ -138,7 +123,7 @@ class ConnectorService:
|
|||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def search_tavily(self, user_query: str, user_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_tavily(self, user_query: str, user_id: str, top_k: int = 20) -> tuple:
|
||||
"""
|
||||
Search using Tavily API and return both the source information and documents
|
||||
|
||||
|
|
@ -177,13 +162,10 @@ class ConnectorService:
|
|||
# Extract results from Tavily response
|
||||
tavily_results = response.get("results", [])
|
||||
|
||||
# Map Tavily results to the required format
|
||||
# Process each result and create sources directly without deduplication
|
||||
sources_list = []
|
||||
documents = []
|
||||
|
||||
# Start IDs from 1000 to avoid conflicts with other connectors
|
||||
base_id = 100
|
||||
|
||||
for i, result in enumerate(tavily_results):
|
||||
|
||||
# Create a source entry
|
||||
|
|
@ -234,7 +216,7 @@ class ConnectorService:
|
|||
"sources": [],
|
||||
}, []
|
||||
|
||||
async def search_slack(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_slack(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
"""
|
||||
Search for slack and return both the source information and langchain documents
|
||||
|
||||
|
|
@ -249,10 +231,10 @@ class ConnectorService:
|
|||
document_type="SLACK_CONNECTOR"
|
||||
)
|
||||
|
||||
# Map slack_chunks to the required format
|
||||
mapped_sources = {}
|
||||
# Process each chunk and create sources directly without deduplication
|
||||
sources_list = []
|
||||
for i, chunk in enumerate(slack_chunks):
|
||||
#Fix for UI
|
||||
# Fix for UI
|
||||
slack_chunks[i]['document']['id'] = self.source_id_counter
|
||||
# Extract document metadata
|
||||
document = chunk.get('document', {})
|
||||
|
|
@ -286,14 +268,7 @@ class ConnectorService:
|
|||
}
|
||||
|
||||
self.source_id_counter += 1
|
||||
|
||||
# Use channel_id and content as a unique identifier for tracking unique sources
|
||||
source_key = f"{channel_id}_{chunk.get('chunk_id', i)}"
|
||||
if source_key and source_key not in mapped_sources:
|
||||
mapped_sources[source_key] = source
|
||||
|
||||
# Convert to list of sources
|
||||
sources_list = list(mapped_sources.values())
|
||||
sources_list.append(source)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
|
|
@ -305,7 +280,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, slack_chunks
|
||||
|
||||
async def search_notion(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_notion(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
"""
|
||||
Search for Notion pages and return both the source information and langchain documents
|
||||
|
||||
|
|
@ -326,8 +301,8 @@ class ConnectorService:
|
|||
document_type="NOTION_CONNECTOR"
|
||||
)
|
||||
|
||||
# Map notion_chunks to the required format
|
||||
mapped_sources = {}
|
||||
# Process each chunk and create sources directly without deduplication
|
||||
sources_list = []
|
||||
for i, chunk in enumerate(notion_chunks):
|
||||
# Fix for UI
|
||||
notion_chunks[i]['document']['id'] = self.source_id_counter
|
||||
|
|
@ -365,14 +340,7 @@ class ConnectorService:
|
|||
}
|
||||
|
||||
self.source_id_counter += 1
|
||||
|
||||
# Use page_id and content as a unique identifier for tracking unique sources
|
||||
source_key = f"{page_id}_{chunk.get('chunk_id', i)}"
|
||||
if source_key and source_key not in mapped_sources:
|
||||
mapped_sources[source_key] = source
|
||||
|
||||
# Convert to list of sources
|
||||
sources_list = list(mapped_sources.values())
|
||||
sources_list.append(source)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
|
|
@ -384,7 +352,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, notion_chunks
|
||||
|
||||
async def search_extension(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_extension(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
"""
|
||||
Search for extension data and return both the source information and langchain documents
|
||||
|
||||
|
|
@ -405,8 +373,8 @@ class ConnectorService:
|
|||
document_type="EXTENSION"
|
||||
)
|
||||
|
||||
# Map extension_chunks to the required format
|
||||
mapped_sources = {}
|
||||
# Process each chunk and create sources directly without deduplication
|
||||
sources_list = []
|
||||
for i, chunk in enumerate(extension_chunks):
|
||||
# Fix for UI
|
||||
extension_chunks[i]['document']['id'] = self.source_id_counter
|
||||
|
|
@ -462,14 +430,7 @@ class ConnectorService:
|
|||
}
|
||||
|
||||
self.source_id_counter += 1
|
||||
|
||||
# Use URL and timestamp as a unique identifier for tracking unique sources
|
||||
source_key = f"{webpage_url}_{visit_date}"
|
||||
if source_key and source_key not in mapped_sources:
|
||||
mapped_sources[source_key] = source
|
||||
|
||||
# Convert to list of sources
|
||||
sources_list = list(mapped_sources.values())
|
||||
sources_list.append(source)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
|
|
@ -481,7 +442,7 @@ class ConnectorService:
|
|||
|
||||
return result_object, extension_chunks
|
||||
|
||||
async def search_youtube(self, user_query: str, user_id: int, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
async def search_youtube(self, user_query: str, user_id: str, search_space_id: int, top_k: int = 20) -> tuple:
|
||||
"""
|
||||
Search for YouTube videos and return both the source information and langchain documents
|
||||
|
||||
|
|
@ -502,8 +463,8 @@ class ConnectorService:
|
|||
document_type="YOUTUBE_VIDEO"
|
||||
)
|
||||
|
||||
# Map youtube_chunks to the required format
|
||||
mapped_sources = {}
|
||||
# Process each chunk and create sources directly without deduplication
|
||||
sources_list = []
|
||||
for i, chunk in enumerate(youtube_chunks):
|
||||
# Fix for UI
|
||||
youtube_chunks[i]['document']['id'] = self.source_id_counter
|
||||
|
|
@ -541,18 +502,11 @@ class ConnectorService:
|
|||
}
|
||||
|
||||
self.source_id_counter += 1
|
||||
|
||||
# Use video_id as a unique identifier for tracking unique sources
|
||||
source_key = video_id or f"youtube_{i}"
|
||||
if source_key and source_key not in mapped_sources:
|
||||
mapped_sources[source_key] = source
|
||||
|
||||
# Convert to list of sources
|
||||
sources_list = list(mapped_sources.values())
|
||||
sources_list.append(source)
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
"id": 6, # Assign a unique ID for the YouTube connector
|
||||
"id": 7, # Assign a unique ID for the YouTube connector
|
||||
"name": "YouTube Videos",
|
||||
"type": "YOUTUBE_VIDEO",
|
||||
"sources": sources_list,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue