retrieval: add search scope models and hybrid chunk search

This commit is contained in:
CREDO23 2026-06-25 08:23:29 +02:00
parent 26a1431e87
commit 608192057f
2 changed files with 249 additions and 0 deletions

View file

@ -0,0 +1,202 @@
"""Hybrid (semantic + keyword) chunk search with reciprocal-rank fusion.
Only matched chunks are citable, so the fused result already holds every passage
shown; there is no second per-document fetch. Returns the top ``top_k``
documents, each carrying its matched chunks in reading order.
"""
from __future__ import annotations
import asyncio
import contextlib
from sqlalchemy import func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.config import config
from app.db import Chunk, Document, DocumentType
from .models import ChunkHit, DocumentHit, SearchScope
# Constant added to each leg's rank in the fusion score: 1 / (_RRF_K + rank).
_RRF_K = 60
# Each search leg and the fused result are limited to top_k * this many chunks.
_CANDIDATE_MULTIPLIER = 5  # fused-chunk pool size relative to top_k
# Upper bound on the chunks kept per document before they are put in reading order.
_MAX_PASSAGES_PER_DOC = 12
async def search_chunks(
    db_session: AsyncSession,
    *,
    search_space_id: int,
    query: str,
    scope: SearchScope,
    top_k: int,
    query_embedding: list[float] | None = None,
) -> list[DocumentHit]:
    """Return the top ``top_k`` documents for ``query`` within ``scope``.

    Args:
        db_session: Async session the fused search query is executed on.
        search_space_id: Search space the results are restricted to.
        query: Raw query text; feeds the keyword leg and, when no embedding
            is supplied, is embedded for the semantic leg.
        scope: Optional filters (document types, document ids, date range).
        top_k: Maximum number of documents to return. Values ``<= 0`` yield
            an empty list without touching the embedder or the database.
        query_embedding: Precomputed embedding of ``query``. When ``None`` the
            configured embedding model is called in a worker thread.

    Returns:
        Up to ``top_k`` ``DocumentHit`` objects, each carrying its matched
        chunks.
    """
    # A non-positive top_k can never produce hits. Bail out before paying for
    # an embedding call, sending a zero/negative LIMIT to the database, or
    # slicing the grouped result with a negative bound.
    if top_k <= 0:
        return []

    document_types = _resolve_document_types(scope.document_types)
    # ``None`` means "no type filter"; an empty list means types were requested
    # but none were recognized, so nothing can match.
    if document_types is not None and not document_types:
        return []

    if query_embedding is None:
        # The embed call is synchronous; run it off the event loop.
        query_embedding = await asyncio.to_thread(
            config.embedding_model_instance.embed, query
        )

    conditions = _base_conditions(search_space_id, scope, document_types)
    rows = await _fused_chunks(
        db_session,
        query=query,
        query_embedding=query_embedding,
        conditions=conditions,
        candidate_pool=top_k * _CANDIDATE_MULTIPLIER,
    )
    return _group_into_documents(rows, top_k=top_k)
def _resolve_document_types(
    raw: tuple[str, ...] | None,
) -> list[DocumentType] | None:
    """Translate type names into ``DocumentType`` members.

    Returns ``None`` when no names were given (no type filter). Otherwise
    returns the members whose names were recognized, in input order; unknown
    names are dropped, so the list is empty when none of them is known.
    """
    if not raw:
        return None

    matched: list[DocumentType] = []
    for type_name in raw:
        try:
            member = DocumentType[type_name]
        except KeyError:
            # Unknown type name: skip it rather than fail the whole search.
            continue
        matched.append(member)
    return matched
def _base_conditions(
    search_space_id: int,
    scope: SearchScope,
    document_types: list[DocumentType] | None,
) -> list:
    """Build the WHERE predicates applied to both the semantic and keyword legs.

    Always restricts to the given search space and excludes documents whose
    status state is ``"deleting"`` (a missing state counts as ``"ready"``).
    Type, id and date-range predicates are added only when the matching scope
    field is set.
    """
    in_search_space = Document.search_space_id == search_space_id
    # Documents without a recorded state are treated as ready.
    document_state = func.coalesce(Document.status["state"].astext, "ready")
    filters = [in_search_space, document_state != "deleting"]

    if document_types:
        filters.append(Document.document_type.in_(document_types))

    if scope.document_ids:
        filters.append(Document.id.in_(scope.document_ids))

    # The date range is matched against the document's last update, inclusive.
    lower_bound = scope.start_date
    if lower_bound is not None:
        filters.append(Document.updated_at >= lower_bound)

    upper_bound = scope.end_date
    if upper_bound is not None:
        filters.append(Document.updated_at <= upper_bound)

    return filters
async def _fused_chunks(
    db_session: AsyncSession,
    *,
    query: str,
    query_embedding: list[float],
    conditions: list,
    candidate_pool: int,
):
    """Run the semantic and keyword legs and fuse them with reciprocal-rank fusion.

    Each leg ranks the chunks that pass ``conditions`` and keeps its best
    ``candidate_pool`` rows. The legs are combined with a FULL OUTER JOIN so a
    chunk found by only one leg still contributes; its score is the sum of
    ``1 / (_RRF_K + rank)`` over the legs that found it.

    Args:
        db_session: Async session the statement is executed on.
        query: Raw query text for the full-text (keyword) leg.
        query_embedding: Query vector for the semantic leg.
        conditions: Document-level predicates applied to both legs.
        candidate_pool: Row limit for each leg and for the fused result.

    Returns:
        ``(Chunk, score)`` rows ordered by score descending, ties broken by
        chunk id, with each chunk's document eagerly loaded.
    """
    tsvector = func.to_tsvector("english", Chunk.content)
    tsquery = func.plainto_tsquery("english", query)

    # NOTE(review): ``<=>`` is presumably pgvector's cosine-distance operator
    # (smaller is closer) — confirm against the column type of Chunk.embedding.
    distance = Chunk.embedding.op("<=>")(query_embedding)
    keyword_relevance = func.ts_rank_cd(tsvector, tsquery).desc()

    semantic = (
        select(
            Chunk.id,
            func.rank().over(order_by=distance).label("rank"),
        )
        .join(Document, Chunk.document_id == Document.id)
        .where(*conditions)
        .order_by(distance)
        .limit(candidate_pool)
        .cte("semantic_search")
    )

    keyword = (
        select(
            Chunk.id,
            func.rank().over(order_by=keyword_relevance).label("rank"),
        )
        .join(Document, Chunk.document_id == Document.id)
        .where(*conditions)
        # Only chunks that actually match the text query take part in this leg.
        .where(tsvector.op("@@")(tsquery))
        .order_by(keyword_relevance)
        .limit(candidate_pool)
        .cte("keyword_search")
    )

    # A chunk missing from one leg has a NULL rank there; coalesce turns that
    # leg's contribution into 0 instead of nulling the whole sum.
    score = (
        func.coalesce(1.0 / (_RRF_K + semantic.c.rank), 0.0)
        + func.coalesce(1.0 / (_RRF_K + keyword.c.rank), 0.0)
    ).label("score")

    fused = (
        select(Chunk, score)
        .select_from(
            semantic.outerjoin(keyword, semantic.c.id == keyword.c.id, full=True)
        )
        .join(Chunk, Chunk.id == func.coalesce(semantic.c.id, keyword.c.id))
        .options(joinedload(Chunk.document))
        # RRF scores tie routinely (rank r in only one leg scores the same in
        # either leg). Without a tie-breaker the LIMIT cut-off and the row
        # order are arbitrary, so identical searches could return different
        # results. Ordering by the label expression also avoids a raw-text
        # reference to "score".
        .order_by(score.desc(), Chunk.id)
        .limit(candidate_pool)
    )

    result = await db_session.execute(fused)
    return result.all()
def _group_into_documents(rows, *, top_k: int) -> list[DocumentHit]:
    """Bucket fused ``(chunk, score)`` rows per document and keep the first ``top_k``.

    Documents are emitted in the order they are first encountered in ``rows``
    and a document's score is the score of its first row, so the result is
    ranked best-first only when ``rows`` is already sorted by score descending.
    Each document's chunks are passed through ``_reading_order``.
    """
    # Dicts preserve insertion order, so a single mapping tracks both the
    # per-document data and the first-seen ranking.
    grouped: dict[int, tuple[Document, float, list[ChunkHit]]] = {}

    for chunk, score in rows:
        document = chunk.document
        entry = grouped.get(document.id)
        if entry is None:
            # First row seen for this document: its score becomes the document score.
            entry = (document, float(score), [])
            grouped[document.id] = entry
        entry[2].append(
            ChunkHit(
                chunk_id=chunk.id,
                content=chunk.content,
                position=chunk.position,
                score=float(score),
            )
        )

    hits: list[DocumentHit] = []
    for document_id, (document, first_score, doc_chunks) in list(grouped.items())[
        :top_k
    ]:
        hits.append(
            DocumentHit(
                document_id=document_id,
                title=document.title,
                document_type=_type_value(document),
                metadata=document.document_metadata or {},
                score=first_score,
                chunks=_reading_order(doc_chunks),
            )
        )
    return hits
def _reading_order(chunks: list[ChunkHit]) -> list[ChunkHit]:
    """Trim to the highest-scoring chunks, then return them sorted by position.

    At most ``_MAX_PASSAGES_PER_DOC`` chunks survive the trim. The input list
    is not modified.
    """
    by_relevance = list(chunks)
    by_relevance.sort(key=lambda hit: hit.score, reverse=True)
    kept = by_relevance[:_MAX_PASSAGES_PER_DOC]
    kept.sort(key=lambda hit: hit.position)
    return kept
def _type_value(document: Document) -> str | None:
    """Return the ``value`` of the document's type, or ``None`` if it has none.

    A document with no ``document_type`` attribute is treated like one whose
    type is ``None``.
    """
    kind = getattr(document, "document_type", None)
    if kind is None:
        return None
    return kind.value
# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = ["search_chunks"]

View file

@ -0,0 +1,47 @@
"""Value objects for knowledge-base retrieval: the query scope and raw hits.
``SearchScope`` is the optional filter a search runs under. ``DocumentHit`` /
``ChunkHit`` are the retriever's typed output — matched chunks grouped by their
document, which the adapter turns into renderable ``RetrievedDocument``s.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass(frozen=True)
class SearchScope:
    """Filters narrowing a search; ``None``/empty means "whole knowledge base".

    Every field is optional and defaults to ``None``. Instances are frozen
    (immutable after construction).
    """

    # Names of the document types to keep; ``None`` applies no type filter.
    document_types: tuple[str, ...] | None = None
    # Ids of the documents to keep; ``None`` applies no id filter.
    document_ids: tuple[int, ...] | None = None
    # Lower bound of the date range; ``None`` leaves the range open at the start.
    start_date: datetime | None = None
    # Upper bound of the date range; ``None`` leaves the range open at the end.
    end_date: datetime | None = None
@dataclass(frozen=True)
class ChunkHit:
    """One matched chunk, with the position that orders it within its document.

    Instances are frozen (immutable after construction).
    """

    # Identifier of the matched chunk.
    chunk_id: int
    # Text content of the chunk.
    content: str
    # Ordering key of the chunk inside its document.
    position: int
    # Relevance score given to this chunk by the search.
    score: float
@dataclass(frozen=True)
class DocumentHit:
    """A document and the chunks that matched the query, ordered by position.

    The instance is frozen, but ``metadata`` and ``chunks`` are ordinary
    mutable containers.
    """

    # Identifier of the document.
    document_id: int
    # Title of the document.
    title: str
    # String form of the document's type, or ``None`` when it has none.
    document_type: str | None
    # Free-form metadata attached to the document.
    metadata: dict[str, Any]
    # Document-level relevance score.
    score: float
    # Matched chunks of this document; defaults to a fresh empty list.
    chunks: list[ChunkHit] = field(default_factory=list)
# Value objects exported by this module.
__all__ = ["ChunkHit", "DocumentHit", "SearchScope"]