diff --git a/tests/unit/test_retrieval/test_document_rag_diversity_selection.py b/tests/unit/test_retrieval/test_document_rag_diversity_selection.py index 47d03cb8..6dcd9458 100644 --- a/tests/unit/test_retrieval/test_document_rag_diversity_selection.py +++ b/tests/unit/test_retrieval/test_document_rag_diversity_selection.py @@ -1,30 +1,8 @@ -import importlib.util -from pathlib import Path - - -REPO_ROOT = Path(__file__).resolve().parents[3] -RERANK_PATH = ( - REPO_ROOT - / "trustgraph-flow" - / "trustgraph" - / "retrieval" - / "document_rag" - / "rerank.py" +from trustgraph.retrieval.document_rag.rerank import ( + RerankCandidate, normalize_candidate_scores, mmr_select, + _pair_diversity_penalty ) -spec = importlib.util.spec_from_file_location( - "document_rag_diversity_rerank", - RERANK_PATH, -) -rerank = importlib.util.module_from_spec(spec) -spec.loader.exec_module(rerank) - -RerankCandidate = rerank.RerankCandidate -normalize_candidate_scores = rerank.normalize_candidate_scores -mmr_select = rerank.mmr_select -_pair_diversity_penalty = rerank._pair_diversity_penalty - - def candidate(index, chunk_id, text, score): return RerankCandidate( index=index,