import time
from datetime import datetime

from app.utils.perf import get_perf_logger

_MAX_FETCH_CHUNKS_PER_DOC = 30


class ChucksHybridSearchRetriever:
    def __init__(self, db_session):
        """
        Initialize the hybrid search retriever with a database session.

        Args:
            db_session: SQLAlchemy AsyncSession from FastAPI dependency injection
        """
        self.db_session = db_session

    async def vector_search(
        self,
        query_text: str,
        top_k: int,
        search_space_id: int,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
    ) -> list:
        """
        Perform vector similarity search on chunks.

        Args:
            query_text: The search query text
            top_k: Number of results to return
            search_space_id: The search space ID to search within
            start_date: Optional start date for filtering documents by updated_at
            end_date: Optional end date for filtering documents by updated_at

        Returns:
            List of chunks sorted by vector similarity
        """
        from sqlalchemy import select
        from sqlalchemy.orm import joinedload

        from app.config import config
        from app.db import Chunk, Document

        perf = get_perf_logger()
        t0 = time.perf_counter()

        # Get embedding for the query
        embedding_model = config.embedding_model_instance
        t_embed = time.perf_counter()
        query_embedding = embedding_model.embed(query_text)
        perf.debug(
            "[chunk_search] vector_search embedding in %.3fs",
            time.perf_counter() - t_embed,
        )

        # Build the query filtered by search space
        query = (
            select(Chunk)
            .options(joinedload(Chunk.document).joinedload(Document.search_space))
            .join(Document, Chunk.document_id == Document.id)
            .where(Document.search_space_id == search_space_id)
        )

        # Add time-based filtering if provided
        if start_date is not None:
            query = query.where(Document.updated_at >= start_date)
        if end_date is not None:
            query = query.where(Document.updated_at <= end_date)

        # Add vector similarity ordering
        query = query.order_by(Chunk.embedding.op("<=>")(query_embedding)).limit(top_k)

        # Execute the query
        t_db = time.perf_counter()
        result = await self.db_session.execute(query)
        chunks = result.scalars().all()
        perf.info(
            "[chunk_search] vector_search DB query in %.3fs results=%d (total %.3fs) space=%d",
            time.perf_counter() - t_db,
            len(chunks),
            time.perf_counter() - t0,
            search_space_id,
        )
        return chunks
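
    # Usage sketch for the vector path (hedged: the session object and the
    # search_space_id below are placeholders, not defined in this module):
    #
    #     retriever = ChucksHybridSearchRetriever(db_session)
    #     chunks = await retriever.vector_search(
    #         query_text="quarterly revenue summary",
    #         top_k=10,
    #         search_space_id=1,
    #     )
    #
    # Assuming the pgvector extension is installed, `Chunk.embedding.op("<=>")`
    # is its cosine-distance operator, so smaller distances sort first and the
    # LIMIT keeps the top_k nearest chunks.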

    async def full_text_search(
        self,
        query_text: str,
        top_k: int,
        search_space_id: int,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
    ) -> list:
        """
        Perform full-text keyword search on chunks.

        Args:
            query_text: The search query text
            top_k: Number of results to return
            search_space_id: The search space ID to search within
            start_date: Optional start date for filtering documents by updated_at
            end_date: Optional end date for filtering documents by updated_at

        Returns:
            List of chunks sorted by text relevance
        """
        from sqlalchemy import func, select
        from sqlalchemy.orm import joinedload

        from app.db import Chunk, Document

        perf = get_perf_logger()
        t0 = time.perf_counter()

        # Create tsvector and tsquery for PostgreSQL full-text search
        tsvector = func.to_tsvector("english", Chunk.content)
        tsquery = func.plainto_tsquery("english", query_text)

        # Build the query filtered by search space
        query = (
            select(Chunk)
            .options(joinedload(Chunk.document).joinedload(Document.search_space))
            .join(Document, Chunk.document_id == Document.id)
            .where(Document.search_space_id == search_space_id)
            .where(
                tsvector.op("@@")(tsquery)
            )  # Only include results that match the query
        )

        # Add time-based filtering if provided
        if start_date is not None:
            query = query.where(Document.updated_at >= start_date)
        if end_date is not None:
            query = query.where(Document.updated_at <= end_date)

        # Add text search ranking
        query = query.order_by(func.ts_rank_cd(tsvector, tsquery).desc()).limit(top_k)

        # Execute the query
        result = await self.db_session.execute(query)
        chunks = result.scalars().all()
        perf.info(
            "[chunk_search] full_text_search in %.3fs results=%d space=%d",
            time.perf_counter() - t0,
            len(chunks),
            search_space_id,
        )
        return chunks
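
    # Full-text matching sketch (PostgreSQL semantics only; exact lexemes depend
    # on the installed "english" dictionary):
    #
    #     to_tsvector('english', 'Retrievers fetch the relevant chunks')
    #         -> 'chunk':5 'fetch':2 'relev':4 'retriev':1
    #     plainto_tsquery('english', 'fetching chunks')
    #         -> 'fetch' & 'chunk'
    #
    # The `@@` operator keeps only chunks whose tsvector matches the tsquery,
    # and ts_rank_cd orders the survivors by cover-density relevance.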

    async def hybrid_search(
        self,
        query_text: str,
        top_k: int,
        search_space_id: int,
        document_type: str | None = None,
        start_date: datetime | None = None,
        end_date: datetime | None = None,
        query_embedding: list | None = None,
    ) -> list:
        """
        Hybrid search that returns **documents** (not individual chunks).

        Each returned item is a document-grouped dict that preserves real DB
        chunk IDs so downstream agents can cite with `[citation:]`.

        Args:
            query_text: The search query text
            top_k: Number of documents to return
            search_space_id: The search space ID to search within
            document_type: Optional document type to filter results
                (e.g., "FILE", "CRAWLED_URL")
            start_date: Optional start date for filtering documents by updated_at
            end_date: Optional end date for filtering documents by updated_at
            query_embedding: Pre-computed embedding vector. If None, will be computed here.

        Returns:
            List of dictionaries containing document data and relevance scores.
            Each dict contains:
                - document_id: the document's database id
                - content: concatenated chunk content (useful for reranking)
                - score: best RRF score among the document's matched chunks
                - chunks: list[{chunk_id, content}] for citation-aware prompting
                - document: {id, title, document_type, metadata}
                - source: the document_type value (or None)
        """
        from sqlalchemy import func, select, text
        from sqlalchemy.orm import joinedload

        from app.config import config
        from app.db import Chunk, Document, DocumentType

        perf = get_perf_logger()
        t0 = time.perf_counter()

        if query_embedding is None:
            embedding_model = config.embedding_model_instance
            t_embed = time.perf_counter()
            query_embedding = embedding_model.embed(query_text)
            perf.debug(
                "[chunk_search] hybrid_search embedding in %.3fs",
                time.perf_counter() - t_embed,
            )

        # RRF constants
        k = 60
        n_results = top_k * 5  # Fetch extra chunks for better document-level fusion

        # Create tsvector and tsquery for PostgreSQL full-text search
        tsvector = func.to_tsvector("english", Chunk.content)
        tsquery = func.plainto_tsquery("english", query_text)

        # Base conditions for chunk filtering - search space is required
        base_conditions = [Document.search_space_id == search_space_id]

        # Add document type filter if provided
        if document_type is not None:
            # Convert string to enum value if needed
            if isinstance(document_type, str):
                try:
                    doc_type_enum = DocumentType[document_type]
                    base_conditions.append(Document.document_type == doc_type_enum)
                except KeyError:
                    # If the document type doesn't exist in the enum, return empty results
                    return []
            else:
                base_conditions.append(Document.document_type == document_type)

        # Add time-based filtering if provided
        if start_date is not None:
            base_conditions.append(Document.updated_at >= start_date)
        if end_date is not None:
            base_conditions.append(Document.updated_at <= end_date)

        # CTE for semantic search filtered by search space
        semantic_search_cte = (
            select(
                Chunk.id,
                func.rank()
                .over(order_by=Chunk.embedding.op("<=>")(query_embedding))
                .label("rank"),
            )
            .join(Document, Chunk.document_id == Document.id)
            .where(*base_conditions)
        )
        semantic_search_cte = (
            semantic_search_cte.order_by(Chunk.embedding.op("<=>")(query_embedding))
            .limit(n_results)
            .cte("semantic_search")
        )

        # CTE for keyword search filtered by search space
        keyword_search_cte = (
            select(
                Chunk.id,
                func.rank()
                .over(order_by=func.ts_rank_cd(tsvector, tsquery).desc())
                .label("rank"),
            )
            .join(Document, Chunk.document_id == Document.id)
            .where(*base_conditions)
            .where(tsvector.op("@@")(tsquery))
        )
        keyword_search_cte = (
            keyword_search_cte.order_by(func.ts_rank_cd(tsvector, tsquery).desc())
            .limit(n_results)
            .cte("keyword_search")
        )

        # Final combined query using a FULL OUTER JOIN with RRF scoring
        final_query = (
            select(
                Chunk,
                (
                    func.coalesce(1.0 / (k + semantic_search_cte.c.rank), 0.0)
                    + func.coalesce(1.0 / (k + keyword_search_cte.c.rank), 0.0)
                ).label("score"),
            )
            .select_from(
                semantic_search_cte.outerjoin(
                    keyword_search_cte,
                    semantic_search_cte.c.id == keyword_search_cte.c.id,
                    full=True,
                )
            )
            .join(
                Chunk,
                Chunk.id
                == func.coalesce(semantic_search_cte.c.id, keyword_search_cte.c.id),
            )
            .options(joinedload(Chunk.document))
            .order_by(text("score DESC"))
            .limit(top_k)
        )

        # Execute the RRF query
        t_rrf = time.perf_counter()
        result = await self.db_session.execute(final_query)
        chunks_with_scores = result.all()
        perf.info(
            "[chunk_search] hybrid_search RRF query in %.3fs results=%d space=%d type=%s",
            time.perf_counter() - t_rrf,
            len(chunks_with_scores),
            search_space_id,
            document_type,
        )

        # If no results were found, return an empty list
        if not chunks_with_scores:
            return []
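
        # Worked RRF example (illustrative numbers, not taken from this query):
        # with k = 60, a chunk ranked 1st by the semantic CTE and 3rd by the
        # keyword CTE scores
        #     1 / (60 + 1) + 1 / (60 + 3) ≈ 0.0164 + 0.0159 ≈ 0.0323
        # while a chunk found only by keyword search at rank 2 scores
        #     0 + 1 / (60 + 2) ≈ 0.0161
        # so chunks surfaced by both retrieval paths float to the top before the
        # results are regrouped by document below.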

        # Convert to serializable dictionaries
        serialized_chunk_results: list[dict] = []
        for chunk, score in chunks_with_scores:
            serialized_chunk_results.append(
                {
                    "chunk_id": chunk.id,
                    "content": chunk.content,
                    "score": float(score),  # Ensure score is a Python float
                    "document": {
                        "id": chunk.document.id,
                        "title": chunk.document.title,
                        "document_type": chunk.document.document_type.value
                        if hasattr(chunk.document, "document_type")
                        else None,
                        "metadata": chunk.document.document_metadata,
                    },
                }
            )

        # Group by document, preserving ranking order by best chunk rank
        doc_scores: dict[int, float] = {}
        doc_order: list[int] = []
        for item in serialized_chunk_results:
            doc_id = item.get("document", {}).get("id")
            if doc_id is None:
                continue
            if doc_id not in doc_scores:
                doc_scores[doc_id] = item.get("score", 0.0)
                doc_order.append(doc_id)
            else:
                # Use the best score as doc score
                doc_scores[doc_id] = max(doc_scores[doc_id], item.get("score", 0.0))

        # Keep only top_k documents by initial rank order.
        doc_ids = doc_order[:top_k]
        if not doc_ids:
            return []

        # Fetch chunks for selected documents. We cap per document to avoid
        # loading hundreds of chunks for a single large file while still
        # ensuring the chunks that matched the RRF query are always included.
        chunk_query = (
            select(Chunk)
            .options(joinedload(Chunk.document))
            .join(Document, Chunk.document_id == Document.id)
            .where(Document.id.in_(doc_ids))
            .where(*base_conditions)
            .order_by(Chunk.document_id, Chunk.id)
        )
        chunks_result = await self.db_session.execute(chunk_query)
        raw_chunks = chunks_result.scalars().all()

        matched_chunk_ids: set[int] = {
            item["chunk_id"] for item in serialized_chunk_results
        }
        doc_chunk_counts: dict[int, int] = {}
        all_chunks: list = []
        for chunk in raw_chunks:
            did = chunk.document_id
            count = doc_chunk_counts.get(did, 0)
            if chunk.id in matched_chunk_ids or count < _MAX_FETCH_CHUNKS_PER_DOC:
                all_chunks.append(chunk)
                doc_chunk_counts[did] = count + 1

        # Assemble final doc-grouped results in the same order as doc_ids
        doc_map: dict[int, dict] = {
            doc_id: {
                "document_id": doc_id,
                "content": "",
                "score": float(doc_scores.get(doc_id, 0.0)),
                "chunks": [],
                "document": {},
                "source": None,
            }
            for doc_id in doc_ids
        }

        for chunk in all_chunks:
            doc = chunk.document
            doc_id = doc.id
            if doc_id not in doc_map:
                continue
            doc_entry = doc_map[doc_id]
            doc_entry["document"] = {
                "id": doc.id,
                "title": doc.title,
                "document_type": doc.document_type.value
                if getattr(doc, "document_type", None)
                else None,
                "metadata": doc.document_metadata or {},
            }
            doc_entry["source"] = (
                doc.document_type.value
                if getattr(doc, "document_type", None)
                else None
            )
            doc_entry["chunks"].append({"chunk_id": chunk.id, "content": chunk.content})

        # Fill concatenated content (useful for reranking)
        final_docs: list[dict] = []
        for doc_id in doc_ids:
            entry = doc_map[doc_id]
            entry["content"] = "\n\n".join(
                c["content"] for c in entry.get("chunks", []) if c.get("content")
            )
            final_docs.append(entry)

        perf.info(
            "[chunk_search] hybrid_search TOTAL in %.3fs docs=%d space=%d type=%s",
            time.perf_counter() - t0,
            len(final_docs),
            search_space_id,
            document_type,
        )
        return final_docs
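

# End-to-end usage sketch (hedged: `get_async_session_context` and the ids below
# are placeholders for whatever session factory and search space the app provides):
#
#     async with get_async_session_context() as session:
#         retriever = ChucksHybridSearchRetriever(session)
#         docs = await retriever.hybrid_search(
#             query_text="onboarding checklist",
#             top_k=5,
#             search_space_id=1,
#             document_type="FILE",
#         )
#         for doc in docs:
#             print(doc["document"]["title"], doc["score"], len(doc["chunks"]))
#
# Each entry mirrors the shape documented in `hybrid_search`: document_id, content,
# score, chunks ([{chunk_id, content}, ...]), document, and source.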