2026-03-22 00:43:53 +05:30
|
|
|
import contextlib
|
2026-02-27 16:32:30 -08:00
|
|
|
import time
|
2025-12-12 02:42:20 -08:00
|
|
|
from datetime import datetime
|
|
|
|
|
|
2026-02-27 16:32:30 -08:00
|
|
|
from app.utils.perf import get_perf_logger
|
|
|
|
|
|
2026-02-28 19:40:24 -08:00
|
|
|
# Upper bound on chunks loaded per document in hybrid_search; caps the
# chunk fetch so a single very large file cannot pull hundreds of chunks
# into one search response.
_MAX_FETCH_CHUNKS_PER_DOC = 30
|
|
|
|
|
|
2025-12-12 02:42:20 -08:00
|
|
|
|
2025-03-20 22:56:24 -07:00
|
|
|
class DocumentHybridSearchRetriever:
    """Retriever combining vector-similarity and full-text search over documents."""

    def __init__(self, db_session):
        """Bind the retriever to an async database session.

        Args:
            db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
        """
        self.db_session = db_session
|
|
|
|
|
|
2025-07-24 14:43:48 -07:00
|
|
|
async def vector_search(
    self,
    query_text: str,
    top_k: int,
    search_space_id: int,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> list:
    """
    Perform vector similarity search on documents.

    Args:
        query_text: The search query text
        top_k: Number of results to return
        search_space_id: The search space ID to search within
        start_date: Optional start date for filtering documents by updated_at
        end_date: Optional end date for filtering documents by updated_at

    Returns:
        List of documents sorted by vector similarity (closest first)
    """
    from sqlalchemy import func, select
    from sqlalchemy.orm import joinedload

    from app.config import config
    from app.db import Document

    perf = get_perf_logger()
    t0 = time.perf_counter()

    # Get embedding for the query
    embedding_model = config.embedding_model_instance
    query_embedding = embedding_model.embed(query_text)

    # Build the query filtered by search space.
    # Also exclude documents in "deleting" state (background deletion in
    # progress) — consistent with the base_conditions used by hybrid_search.
    query = (
        select(Document)
        .options(joinedload(Document.search_space))
        .where(Document.search_space_id == search_space_id)
        .where(func.coalesce(Document.status["state"].astext, "ready") != "deleting")
    )

    # Add time-based filtering if provided
    if start_date is not None:
        query = query.where(Document.updated_at >= start_date)
    if end_date is not None:
        query = query.where(Document.updated_at <= end_date)

    # Add vector similarity ordering ("<=>" is the pgvector distance operator)
    query = query.order_by(Document.embedding.op("<=>")(query_embedding)).limit(
        top_k
    )

    # Execute the query
    result = await self.db_session.execute(query)
    documents = result.scalars().all()

    perf.info(
        "[doc_search] vector_search in %.3fs results=%d space=%d",
        time.perf_counter() - t0,
        len(documents),
        search_space_id,
    )

    return documents
|
|
|
|
|
|
2025-07-24 14:43:48 -07:00
|
|
|
async def full_text_search(
    self,
    query_text: str,
    top_k: int,
    search_space_id: int,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
) -> list:
    """
    Perform full-text keyword search on documents.

    Args:
        query_text: The search query text
        top_k: Number of results to return
        search_space_id: The search space ID to search within
        start_date: Optional start date for filtering documents by updated_at
        end_date: Optional end date for filtering documents by updated_at

    Returns:
        List of documents sorted by text relevance
    """
    from sqlalchemy import func, select
    from sqlalchemy.orm import joinedload

    from app.db import Document

    perf = get_perf_logger()
    t0 = time.perf_counter()

    # Create tsvector and tsquery for PostgreSQL full-text search
    tsvector = func.to_tsvector("english", Document.content)
    tsquery = func.plainto_tsquery("english", query_text)

    # Build the query filtered by search space.
    # Also exclude documents in "deleting" state (background deletion in
    # progress) — consistent with the base_conditions used by hybrid_search.
    query = (
        select(Document)
        .options(joinedload(Document.search_space))
        .where(Document.search_space_id == search_space_id)
        .where(func.coalesce(Document.status["state"].astext, "ready") != "deleting")
        .where(
            tsvector.op("@@")(tsquery)
        )  # Only include results that match the query
    )

    # Add time-based filtering if provided
    if start_date is not None:
        query = query.where(Document.updated_at >= start_date)
    if end_date is not None:
        query = query.where(Document.updated_at <= end_date)

    # Add text search ranking
    query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)

    # Execute the query
    result = await self.db_session.execute(query)
    documents = result.scalars().all()

    perf.info(
        "[doc_search] full_text_search in %.3fs results=%d space=%d",
        time.perf_counter() - t0,
        len(documents),
        search_space_id,
    )

    return documents
|
|
|
|
|
|
2025-07-24 14:43:48 -07:00
|
|
|
async def hybrid_search(
    self,
    query_text: str,
    top_k: int,
    search_space_id: int,
    document_type: str | list[str] | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    query_embedding: list | None = None,
) -> list:
    """
    Hybrid search that returns **documents** (not individual chunks).

    Combines a semantic (vector) ranking and a keyword (full-text) ranking
    with Reciprocal Rank Fusion (RRF). Each returned item is a
    document-grouped dict that preserves real DB chunk IDs so
    downstream agents can cite with `[citation:<chunk_id>]`.

    Args:
        query_text: The search query text
        top_k: Number of documents to return
        search_space_id: The search space ID to search within
        document_type: Optional document type to filter results (e.g., "FILE", "CRAWLED_URL")
        start_date: Optional start date for filtering documents by updated_at
        end_date: Optional end date for filtering documents by updated_at
        query_embedding: Pre-computed embedding vector. If None, will be computed here.

    Returns:
        List of dicts with keys: "document_id", "content" (concatenated chunk
        text), "score" (RRF score), "chunks", "document", "source". Empty list
        when nothing matches or when no requested document_type name is valid.
    """
    from sqlalchemy import func, select, text
    from sqlalchemy.orm import joinedload

    from app.config import config
    from app.db import Chunk, Document, DocumentType

    perf = get_perf_logger()
    t0 = time.perf_counter()

    # Embed lazily: callers that already computed the embedding can pass it in.
    if query_embedding is None:
        embedding_model = config.embedding_model_instance
        query_embedding = embedding_model.embed(query_text)

    # RRF constants
    k = 60  # RRF damping constant; larger k flattens rank differences
    n_results = top_k * 2  # Fetch extra documents for better fusion

    # Create tsvector and tsquery for PostgreSQL full-text search
    tsvector = func.to_tsvector("english", Document.content)
    tsquery = func.plainto_tsquery("english", query_text)

    # Base conditions for document filtering - search space is required.
    # Exclude documents in "deleting" state (background deletion in progress).
    base_conditions = [
        Document.search_space_id == search_space_id,
        func.coalesce(Document.status["state"].astext, "ready") != "deleting",
    ]

    # Add document type filter if provided (single string or list of strings)
    if document_type is not None:
        type_list = (
            document_type if isinstance(document_type, list) else [document_type]
        )
        doc_type_enums = []
        for dt in type_list:
            if isinstance(dt, str):
                # Unknown type names are silently skipped rather than raising.
                with contextlib.suppress(KeyError):
                    doc_type_enums.append(DocumentType[dt])
            else:
                # Already an enum (or enum-like) value; use as-is.
                doc_type_enums.append(dt)
        # Every requested name was invalid -> nothing can match.
        if not doc_type_enums:
            return []
        if len(doc_type_enums) == 1:
            base_conditions.append(Document.document_type == doc_type_enums[0])
        else:
            base_conditions.append(Document.document_type.in_(doc_type_enums))

    # Add time-based filtering if provided
    if start_date is not None:
        base_conditions.append(Document.updated_at >= start_date)
    if end_date is not None:
        base_conditions.append(Document.updated_at <= end_date)

    # CTE for semantic search filtered by search space.
    # rank() window over the vector distance gives each doc its semantic rank.
    semantic_search_cte = select(
        Document.id,
        func.rank()
        .over(order_by=Document.embedding.op("<=>")(query_embedding))
        .label("rank"),
    ).where(*base_conditions)

    semantic_search_cte = (
        semantic_search_cte.order_by(Document.embedding.op("<=>")(query_embedding))
        .limit(n_results)
        .cte("semantic_search")
    )

    # CTE for keyword search filtered by search space
    keyword_search_cte = (
        select(
            Document.id,
            func.rank()
            .over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
            .label("rank"),
        )
        .where(*base_conditions)
        .where(tsvector.op("@@")(tsquery))
    )

    keyword_search_cte = (
        keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
        .limit(n_results)
        .cte("keyword_search")
    )

    # Final combined query using a FULL OUTER JOIN with RRF scoring.
    # A document missing from one ranking contributes 0 for that side
    # (coalesce), so appearing in both rankings boosts the score.
    final_query = (
        select(
            Document,
            (
                func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
                + func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
            ).label("score"),
        )
        .select_from(
            semantic_search_cte.outerjoin(
                keyword_search_cte,
                semantic_search_cte.c.id == keyword_search_cte.c.id,
                full=True,
            )
        )
        .join(
            Document,
            Document.id
            == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
        )
        .options(joinedload(Document.search_space))
        .order_by(text("score DESC"))
        .limit(top_k)
    )

    # Execute the query
    result = await self.db_session.execute(final_query)
    documents_with_scores = result.all()

    # If no results were found, return an empty list
    if not documents_with_scores:
        return []

    # Collect document IDs for chunk fetching (order preserves RRF ranking)
    doc_ids: list[int] = [doc.id for doc, _score in documents_with_scores]

    # Fetch chunks for these documents, capped per document to avoid
    # loading hundreds of chunks for a single large file.
    chunks_query = (
        select(Chunk)
        .options(joinedload(Chunk.document))
        .where(Chunk.document_id.in_(doc_ids))
        .order_by(Chunk.document_id, Chunk.id)
    )
    chunks_result = await self.db_session.execute(chunks_query)
    raw_chunks = chunks_result.scalars().all()

    # Keep at most _MAX_FETCH_CHUNKS_PER_DOC chunks per document
    # (rows arrive grouped by document_id because of the ORDER BY above).
    doc_chunk_counts: dict[int, int] = {}
    chunks: list = []
    for chunk in raw_chunks:
        did = chunk.document_id
        count = doc_chunk_counts.get(did, 0)
        if count < _MAX_FETCH_CHUNKS_PER_DOC:
            chunks.append(chunk)
            doc_chunk_counts[did] = count + 1

    # Assemble doc-grouped results
    doc_map: dict[int, dict] = {
        doc.id: {
            "document_id": doc.id,
            "content": "",
            "score": float(score),
            "chunks": [],
            "document": {
                "id": doc.id,
                "title": doc.title,
                "document_type": doc.document_type.value
                if getattr(doc, "document_type", None)
                else None,
                "metadata": doc.document_metadata or {},
            },
            "source": doc.document_type.value
            if getattr(doc, "document_type", None)
            else None,
        }
        for doc, score in documents_with_scores
    }

    # Attach each (capped) chunk to its owning document entry.
    for chunk in chunks:
        doc_id = chunk.document_id
        if doc_id not in doc_map:
            continue
        doc_map[doc_id]["chunks"].append(
            {"chunk_id": chunk.id, "content": chunk.content}
        )

    # Fill concatenated content (useful for reranking)
    final_docs: list[dict] = []
    for doc_id in doc_ids:
        entry = doc_map[doc_id]
        entry["content"] = "\n\n".join(
            c["content"] for c in entry.get("chunks", []) if c.get("content")
        )
        final_docs.append(entry)

    perf.info(
        "[doc_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
        time.perf_counter() - t0,
        len(final_docs),
        search_space_id,
        document_type,
    )

    return final_docs
|