# SurfSense/surfsense_backend/app/services/reranker_service.py
import logging
from typing import Any, Optional
from rerankers import Document as RerankerDocument


class RerankerService:
    """
    Service for reranking documents using a configured reranker.

    The reranker instance is treated as an opaque object. This class only
    relies on it exposing ``rank(query=..., docs=...)``, which returns an
    object with a ``results`` iterable. Each item of ``results`` must carry
    ``document`` (with ``doc_id`` and ``metadata``), ``score`` and ``rank``
    attributes.
    """

    def __init__(self, reranker_instance=None):
        """
        Initialize the reranker service.

        Args:
            reranker_instance: The reranker instance to use for reranking.
                When falsy, ``rerank_documents`` returns its input unchanged.
        """
        self.reranker_instance = reranker_instance

    def rerank_documents(
        self, query_text: str, documents: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
        """
        Rerank documents using the configured reranker.

        Documents can be either:
        - Document-grouped (new format): has `document_id`, a `chunks` list and
          `content` (concatenated)
        - Chunk-based (legacy format): individual chunks with `chunk_id` and
          `content`

        Args:
            query_text: The query text to use for reranking.
            documents: List of document dictionaries to rerank.

        Returns:
            list[dict[str, Any]]: Shallow copies of the input documents in the
            order produced by the reranker. Each copy has ``score`` set to the
            reranker score (as float) and a ``rank`` key added. Nested values
            such as ``chunks`` are shared with the inputs. Results that cannot
            be mapped back to an input are skipped. If no reranker is
            configured, ``documents`` is empty, or reranking raises, the
            original ``documents`` list is returned as-is.
        """
        if not self.reranker_instance or not documents:
            return documents

        try:
            reranker_docs = [
                self._to_reranker_document(i, doc) for i, doc in enumerate(documents)
            ]
            reranking_results = self.reranker_instance.rank(
                query=query_text, docs=reranker_docs
            )
            return self._merge_results(documents, reranking_results.results)
        except Exception as e:
            # Deliberate best-effort: reranking is an optimization, so on any
            # failure fall back to the original order. Log with the traceback.
            logging.getLogger(__name__).exception("Error during reranking: %s", e)
            return documents

    @staticmethod
    def _to_reranker_document(index: int, doc: dict[str, Any]) -> "RerankerDocument":
        """Build the rerankers ``Document`` for ``doc`` at position ``index``."""
        # `or` (not a .get default) so an explicit None value also falls back,
        # instead of crashing and aborting the rerank for every document.
        document_info = doc.get("document") or {}
        return RerankerDocument(
            text=doc.get("content") or "",
            # Falsy/missing ids get a positional placeholder id.
            doc_id=doc.get("document_id") or f"doc_{index}",
            metadata={
                "document_id": document_info.get("id", ""),
                "document_title": document_info.get("title", ""),
                "document_type": document_info.get("document_type", ""),
                "rrf_score": doc.get("score", 0.0),
                # Exact position in `documents`; used to map results back.
                "original_index": index,
            },
        )

    @staticmethod
    def _merge_results(
        documents: list[dict[str, Any]], results: Any
    ) -> list[dict[str, Any]]:
        """Map reranker results back onto shallow copies of the input documents."""
        # Built once, so a fallback lookup is O(1) instead of a linear scan
        # per result. setdefault keeps the first document for a repeated id.
        by_id: dict[Any, dict[str, Any]] = {}
        for doc in documents:
            doc_id = doc.get("document_id")
            if doc_id is not None:
                by_id.setdefault(doc_id, doc)

        reranked: list[dict[str, Any]] = []
        for result in results:
            metadata = result.document.metadata or {}
            index = metadata.get("original_index")
            # Prefer the recorded index: it is unambiguous even when several
            # inputs share a document_id. Fall back to id lookup otherwise.
            if isinstance(index, int) and 0 <= index < len(documents):
                original_doc = documents[index]
            else:
                original_doc = by_id.get(result.document.doc_id)
            if original_doc is None:
                continue

            # Shallow copy: top-level keys are new, nested values (e.g. the
            # `chunks` list used for citations) are shared with the input.
            reranked_doc = original_doc.copy()
            reranked_doc["score"] = float(result.score)
            reranked_doc["rank"] = result.rank
            reranked.append(reranked_doc)
        return reranked

    @staticmethod
    def get_reranker_instance() -> Optional["RerankerService"]:
        """
        Get a reranker service instance from the global configuration.

        Returns:
            Optional[RerankerService]: A service wrapping
            ``app.config.config.reranker_instance`` when that attribute exists
            and is truthy, otherwise None.
        """
        # Imported lazily, presumably to avoid an import cycle / config load at
        # module import time — NOTE(review): confirm.
        from app.config import config

        reranker_instance = getattr(config, "reranker_instance", None)
        if reranker_instance:
            return RerankerService(reranker_instance)
        return None