mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
retrieval: add reranking wrapper and context service
This commit is contained in:
parent
407bfcd94f
commit
4fe208557a
3 changed files with 135 additions and 0 deletions
|
|
@ -0,0 +1,18 @@
|
|||
"""Knowledge-base retrieval: hybrid search rendered as citable evidence.
|
||||
|
||||
Public surface is the service (``search_knowledge_base_context``) and its input
|
||||
value object (``SearchScope``); the rest are building blocks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .models import ChunkHit, DocumentHit, SearchScope
|
||||
from .service import build_context, search_knowledge_base_context
|
||||
|
||||
__all__ = [
|
||||
"ChunkHit",
|
||||
"DocumentHit",
|
||||
"SearchScope",
|
||||
"build_context",
|
||||
"search_knowledge_base_context",
|
||||
]
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""Reorder retrieved documents with the configured reranker (no-op if disabled).
|
||||
|
||||
Ranking is by concatenated matched-chunk content; ``DocumentHit`` order is
|
||||
rewritten to follow the reranker's result.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .models import DocumentHit
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.reranker_service import RerankerService
|
||||
|
||||
|
||||
def rerank_hits(
    query: str,
    hits: list[DocumentHit],
    reranker: RerankerService | None,
) -> list[DocumentHit]:
    """Return ``hits`` reordered by the reranker; unchanged when none is set.

    The reranker's answer is only accepted when it maps back to an exact
    permutation of ``hits``. Entries with unknown ids are ignored, repeated
    ids are collapsed to their first occurrence, and if any input hit is
    missing afterwards the original order is returned untouched.

    Args:
        query: The user query the documents are scored against.
        hits: Retrieved documents in their original (pre-rerank) order.
        reranker: Object exposing ``rerank_documents(query, documents)``, or
            ``None`` to disable reranking.

    Returns:
        A list containing every element of ``hits`` exactly once, in the
        reranker's order, or ``hits`` itself when reranking is skipped or
        its result cannot be trusted.
    """
    if reranker is None or len(hits) < 2:
        return hits

    hit_by_id = {hit.document_id: hit for hit in hits}
    if len(hit_by_id) != len(hits):
        # Two hits share an id, so the id -> hit mapping is ambiguous and the
        # reranked order could not be mapped back without losing a hit.
        return hits

    # NOTE(review): this relies on the reranker echoing "document_id" back on
    # each returned dict -- confirm against RerankerService.rerank_documents.
    ranked = reranker.rerank_documents(query, [_as_document(hit) for hit in hits])

    reordered: list[DocumentHit] = []
    seen_ids: set[Any] = set()
    for doc in ranked:
        doc_id = doc.get("document_id")
        if doc_id not in hit_by_id or doc_id in seen_ids:
            # Unknown ids cannot be mapped; repeated ids would duplicate a hit.
            continue
        seen_ids.add(doc_id)
        reordered.append(hit_by_id[doc_id])

    # Fall back to the original order if the reranker dropped or garbled ids.
    return reordered if len(reordered) == len(hits) else hits


def _as_document(hit: DocumentHit) -> dict[str, Any]:
    """The minimal dict shape ``RerankerService.rerank_documents`` scores on.

    ``content`` is the hit's matched chunks joined with blank lines, so the
    document is ranked on everything that matched rather than on one chunk.
    """
    return {
        "document_id": hit.document_id,
        "content": "\n\n".join(chunk.content for chunk in hit.chunks),
        "score": hit.score,
        "document": {
            "id": hit.document_id,
            "title": hit.title,
            "document_type": hit.document_type,
        },
    }
|
||||
|
||||
|
||||
__all__ = ["rerank_hits"]
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
"""Search the knowledge base and render it as model-facing ``<retrieved_context>``.
|
||||
|
||||
The retrieval spine end to end: hybrid search → rerank → adapt → render, with
|
||||
each shown passage registered for ``[n]`` citation along the way.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.multi_agent_chat.shared.citations import CitationRegistry
|
||||
from app.agents.chat.multi_agent_chat.shared.retrieved_context import (
|
||||
render_retrieved_context,
|
||||
)
|
||||
|
||||
from .adapter import to_retrieved_document
|
||||
from .hybrid_search import search_chunks
|
||||
from .models import DocumentHit, SearchScope
|
||||
from .reranking import rerank_hits
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.reranker_service import RerankerService
|
||||
|
||||
_DEFAULT_TOP_K = 10
|
||||
|
||||
|
||||
async def search_knowledge_base_context(
    db_session: AsyncSession,
    *,
    search_space_id: int,
    query: str,
    registry: CitationRegistry,
    scope: SearchScope | None = None,
    reranker: RerankerService | None = None,
    top_k: int = _DEFAULT_TOP_K,
) -> str | None:
    """Run hybrid search for ``query`` and render the hits as citable context.

    The search is issued against ``search_space_id`` with ``top_k`` and the
    given ``scope`` (a default ``SearchScope()`` when none is supplied); the
    resulting hits are handed to ``build_context`` together with ``registry``
    and the optional ``reranker``.

    Returns ``None`` when nothing matched, so the caller can skip the block.
    """
    # A falsy scope falls back to an unrestricted default scope.
    effective_scope = scope or SearchScope()

    matched = await search_chunks(
        db_session,
        search_space_id=search_space_id,
        query=query,
        scope=effective_scope,
        top_k=top_k,
    )

    rendered = build_context(query, matched, registry, reranker=reranker)
    return rendered
|
||||
|
||||
|
||||
def build_context(
    query: str,
    hits: list[DocumentHit],
    registry: CitationRegistry,
    *,
    reranker: RerankerService | None = None,
) -> str | None:
    """Turn already-retrieved ``hits`` into the rendered context block.

    Three steps, in order: rerank the hits for ``query``, adapt each hit with
    ``to_retrieved_document``, then render the adapted documents against
    ``registry``. Takes no database session, so it can be unit-tested with
    hand-built hits.
    """
    ordered_hits = rerank_hits(query, hits, reranker)

    adapted = []
    for hit in ordered_hits:
        adapted.append(to_retrieved_document(hit))

    return render_retrieved_context(adapted, registry)
|
||||
|
||||
|
||||
__all__ = ["build_context", "search_knowledge_base_context"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue