mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 09:16:22 +02:00
- Modified the document extraction and citation formatting to accommodate a new structure that includes a `chunks` list for each document. - Enhanced the citation format to reference `chunk_id` instead of `source_id`, ensuring accurate citations in the UI. - Updated various components, including the connector service and reranker service, to handle the new document format and maintain compatibility with existing functionalities. - Improved documentation and comments to reflect changes in the data structure and citation requirements.
357 lines
13 KiB
Python
357 lines
13 KiB
Python
from datetime import datetime
|
|
|
|
|
|
class ChucksHybridSearchRetriever:
|
|
def __init__(self, db_session):
|
|
"""
|
|
Initialize the hybrid search retriever with a database session.
|
|
|
|
Args:
|
|
db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
|
|
"""
|
|
self.db_session = db_session
|
|
|
|
async def vector_search(
|
|
self,
|
|
query_text: str,
|
|
top_k: int,
|
|
search_space_id: int,
|
|
start_date: datetime | None = None,
|
|
end_date: datetime | None = None,
|
|
) -> list:
|
|
"""
|
|
Perform vector similarity search on chunks.
|
|
|
|
Args:
|
|
query_text: The search query text
|
|
top_k: Number of results to return
|
|
search_space_id: The search space ID to search within
|
|
start_date: Optional start date for filtering documents by updated_at
|
|
end_date: Optional end date for filtering documents by updated_at
|
|
|
|
Returns:
|
|
List of chunks sorted by vector similarity
|
|
"""
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
from app.config import config
|
|
from app.db import Chunk, Document
|
|
|
|
# Get embedding for the query
|
|
embedding_model = config.embedding_model_instance
|
|
query_embedding = embedding_model.embed(query_text)
|
|
|
|
# Build the query filtered by search space
|
|
query = (
|
|
select(Chunk)
|
|
.options(joinedload(Chunk.document).joinedload(Document.search_space))
|
|
.join(Document, Chunk.document_id == Document.id)
|
|
.where(Document.search_space_id == search_space_id)
|
|
)
|
|
|
|
# Add time-based filtering if provided
|
|
if start_date is not None:
|
|
query = query.where(Document.updated_at >= start_date)
|
|
if end_date is not None:
|
|
query = query.where(Document.updated_at <= end_date)
|
|
|
|
# Add vector similarity ordering
|
|
query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)
|
|
|
|
# Execute the query
|
|
result = await self.db_session.execute(query)
|
|
chunks = result.scalars().all()
|
|
|
|
return chunks
|
|
|
|
async def full_text_search(
|
|
self,
|
|
query_text: str,
|
|
top_k: int,
|
|
search_space_id: int,
|
|
start_date: datetime | None = None,
|
|
end_date: datetime | None = None,
|
|
) -> list:
|
|
"""
|
|
Perform full-text keyword search on chunks.
|
|
|
|
Args:
|
|
query_text: The search query text
|
|
top_k: Number of results to return
|
|
search_space_id: The search space ID to search within
|
|
start_date: Optional start date for filtering documents by updated_at
|
|
end_date: Optional end date for filtering documents by updated_at
|
|
|
|
Returns:
|
|
List of chunks sorted by text relevance
|
|
"""
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
from app.db import Chunk, Document
|
|
|
|
# Create tsvector and tsquery for PostgreSQL full-text search
|
|
tsvector = func.to_tsvector("english", Chunk.content)
|
|
tsquery = func.plainto_tsquery("english", query_text)
|
|
|
|
# Build the query filtered by search space
|
|
query = (
|
|
select(Chunk)
|
|
.options(joinedload(Chunk.document).joinedload(Document.search_space))
|
|
.join(Document, Chunk.document_id == Document.id)
|
|
.where(Document.search_space_id == search_space_id)
|
|
.where(
|
|
tsvector.op("@@")(tsquery)
|
|
) # Only include results that match the query
|
|
)
|
|
|
|
# Add time-based filtering if provided
|
|
if start_date is not None:
|
|
query = query.where(Document.updated_at >= start_date)
|
|
if end_date is not None:
|
|
query = query.where(Document.updated_at <= end_date)
|
|
|
|
# Add text search ranking
|
|
query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)
|
|
|
|
# Execute the query
|
|
result = await self.db_session.execute(query)
|
|
chunks = result.scalars().all()
|
|
|
|
return chunks
|
|
|
|
async def hybrid_search(
|
|
self,
|
|
query_text: str,
|
|
top_k: int,
|
|
search_space_id: int,
|
|
document_type: str | None = None,
|
|
start_date: datetime | None = None,
|
|
end_date: datetime | None = None,
|
|
) -> list:
|
|
"""
|
|
Hybrid search that returns **documents** (not individual chunks).
|
|
|
|
Each returned item is a document-grouped dict that preserves real DB chunk IDs so
|
|
downstream agents can cite with `[citation:<chunk_id>]`.
|
|
|
|
Args:
|
|
query_text: The search query text
|
|
top_k: Number of documents to return
|
|
search_space_id: The search space ID to search within
|
|
document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
|
|
start_date: Optional start date for filtering documents by updated_at
|
|
end_date: Optional end date for filtering documents by updated_at
|
|
|
|
Returns:
|
|
List of dictionaries containing document data and relevance scores. Each dict contains:
|
|
- chunk_id: a "primary" chunk id for compatibility (best-ranked chunk for the doc)
|
|
- content: concatenated chunk content (useful for reranking)
|
|
- chunks: list[{chunk_id, content}] for citation-aware prompting
|
|
- document: {id, title, document_type, metadata}
|
|
"""
|
|
from sqlalchemy import func, select, text
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
from app.config import config
|
|
from app.db import Chunk, Document, DocumentType
|
|
|
|
# Get embedding for the query
|
|
embedding_model = config.embedding_model_instance
|
|
query_embedding = embedding_model.embed(query_text)
|
|
|
|
# RRF constants
|
|
k = 60
|
|
n_results = top_k * 5 # Fetch extra chunks for better document-level fusion
|
|
|
|
# Create tsvector and tsquery for PostgreSQL full-text search
|
|
tsvector = func.to_tsvector("english", Chunk.content)
|
|
tsquery = func.plainto_tsquery("english", query_text)
|
|
|
|
# Base conditions for chunk filtering - search space is required
|
|
base_conditions = [Document.search_space_id == search_space_id]
|
|
|
|
# Add document type filter if provided
|
|
if document_type is not None:
|
|
# Convert string to enum value if needed
|
|
if isinstance(document_type, str):
|
|
try:
|
|
doc_type_enum = DocumentType[document_type]
|
|
base_conditions.append(Document.document_type == doc_type_enum)
|
|
except KeyError:
|
|
# If the document type doesn't exist in the enum, return empty results
|
|
return []
|
|
else:
|
|
base_conditions.append(Document.document_type == document_type)
|
|
|
|
# Add time-based filtering if provided
|
|
if start_date is not None:
|
|
base_conditions.append(Document.updated_at >= start_date)
|
|
if end_date is not None:
|
|
base_conditions.append(Document.updated_at <= end_date)
|
|
|
|
# CTE for semantic search filtered by search space
|
|
semantic_search_cte = (
|
|
select(
|
|
Chunk.id,
|
|
func.rank()
|
|
.over(order_by=Chunk.embedding.op("<=>")(query_embedding))
|
|
.label("rank"),
|
|
)
|
|
.join(Document, Chunk.document_id == Document.id)
|
|
.where(*base_conditions)
|
|
)
|
|
|
|
semantic_search_cte = (
|
|
semantic_search_cte.order_by(Chunk.embedding.op("<=>")(query_embedding))
|
|
.limit(n_results)
|
|
.cte("semantic_search")
|
|
)
|
|
|
|
# CTE for keyword search filtered by search space
|
|
keyword_search_cte = (
|
|
select(
|
|
Chunk.id,
|
|
func.rank()
|
|
.over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
|
|
.label("rank"),
|
|
)
|
|
.join(Document, Chunk.document_id == Document.id)
|
|
.where(*base_conditions)
|
|
.where(tsvector.op("@@")(tsquery))
|
|
)
|
|
|
|
keyword_search_cte = (
|
|
keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
|
|
.limit(n_results)
|
|
.cte("keyword_search")
|
|
)
|
|
|
|
# Final combined query using a FULL OUTER JOIN with RRF scoring
|
|
final_query = (
|
|
select(
|
|
Chunk,
|
|
(
|
|
func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
|
|
+ func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
|
|
).label("score"),
|
|
)
|
|
.select_from(
|
|
semantic_search_cte.outerjoin(
|
|
keyword_search_cte,
|
|
semantic_search_cte.c.id == keyword_search_cte.c.id,
|
|
full=True,
|
|
)
|
|
)
|
|
.join(
|
|
Chunk,
|
|
Chunk.id
|
|
== func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
|
|
)
|
|
.options(joinedload(Chunk.document))
|
|
.order_by(text("score DESC"))
|
|
.limit(top_k)
|
|
)
|
|
|
|
# Execute the query
|
|
result = await self.db_session.execute(final_query)
|
|
chunks_with_scores = result.all()
|
|
|
|
# If no results were found, return an empty list
|
|
if not chunks_with_scores:
|
|
return []
|
|
|
|
# Convert to serializable dictionaries
|
|
serialized_chunk_results: list[dict] = []
|
|
for chunk, score in chunks_with_scores:
|
|
serialized_chunk_results.append(
|
|
{
|
|
"chunk_id": chunk.id,
|
|
"content": chunk.content,
|
|
"score": float(score), # Ensure score is a Python float
|
|
"document": {
|
|
"id": chunk.document.id,
|
|
"title": chunk.document.title,
|
|
"document_type": chunk.document.document_type.value
|
|
if hasattr(chunk.document, "document_type")
|
|
else None,
|
|
"metadata": chunk.document.document_metadata,
|
|
},
|
|
}
|
|
)
|
|
|
|
# Group by document, preserving ranking order by best chunk rank
|
|
doc_scores: dict[int, float] = {}
|
|
doc_order: list[int] = []
|
|
for item in serialized_chunk_results:
|
|
doc_id = item.get("document", {}).get("id")
|
|
if doc_id is None:
|
|
continue
|
|
if doc_id not in doc_scores:
|
|
doc_scores[doc_id] = item.get("score", 0.0)
|
|
doc_order.append(doc_id)
|
|
else:
|
|
# Use the best score as doc score
|
|
doc_scores[doc_id] = max(doc_scores[doc_id], item.get("score", 0.0))
|
|
|
|
# Keep only top_k documents by initial rank order.
|
|
doc_ids = doc_order[:top_k]
|
|
if not doc_ids:
|
|
return []
|
|
|
|
# Fetch ALL chunks for selected documents in a single query so the final prompt can cite
|
|
# any chunk from those documents.
|
|
chunk_query = (
|
|
select(Chunk)
|
|
.options(joinedload(Chunk.document))
|
|
.join(Document, Chunk.document_id == Document.id)
|
|
.where(Document.id.in_(doc_ids))
|
|
.where(*base_conditions)
|
|
.order_by(Chunk.document_id, Chunk.id)
|
|
)
|
|
chunks_result = await self.db_session.execute(chunk_query)
|
|
all_chunks = chunks_result.scalars().all()
|
|
|
|
# Assemble final doc-grouped results in the same order as doc_ids
|
|
doc_map: dict[int, dict] = {
|
|
doc_id: {
|
|
"document_id": doc_id,
|
|
"content": "",
|
|
"score": float(doc_scores.get(doc_id, 0.0)),
|
|
"chunks": [],
|
|
"document": {},
|
|
"source": None,
|
|
}
|
|
for doc_id in doc_ids
|
|
}
|
|
|
|
for chunk in all_chunks:
|
|
doc = chunk.document
|
|
doc_id = doc.id
|
|
if doc_id not in doc_map:
|
|
continue
|
|
doc_entry = doc_map[doc_id]
|
|
doc_entry["document"] = {
|
|
"id": doc.id,
|
|
"title": doc.title,
|
|
"document_type": doc.document_type.value
|
|
if getattr(doc, "document_type", None)
|
|
else None,
|
|
"metadata": doc.document_metadata or {},
|
|
}
|
|
doc_entry["source"] = (
|
|
doc.document_type.value if getattr(doc, "document_type", None) else None
|
|
)
|
|
doc_entry["chunks"].append({"chunk_id": chunk.id, "content": chunk.content})
|
|
|
|
# Fill concatenated content (useful for reranking)
|
|
final_docs: list[dict] = []
|
|
for doc_id in doc_ids:
|
|
entry = doc_map[doc_id]
|
|
entry["content"] = "\n\n".join(
|
|
c["content"] for c in entry.get("chunks", []) if c.get("content")
|
|
)
|
|
final_docs.append(entry)
|
|
|
|
return final_docs
|