diff --git a/config/config2.example.yaml b/config/config2.example.yaml index a24892c2a..776ea6f54 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -80,7 +80,7 @@ exp_pool: enable_write: false persist_path: .chroma_exp_data # The directory. retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding. - use_llm_ranker: false # If `use_llm_ranker` is true, then it will use LLM Reranker to get better result, but it is not always guaranteed that the output will be parseable for reranking. + use_llm_ranker: true # If `use_llm_ranker` is true, then it will use LLM Reranker to get better result, but it is not always guaranteed that the output will be parseable for reranking. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index 7611dda27..8d33b25aa 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -21,4 +21,4 @@ class ExperiencePoolConfig(YamlModel): retrieval_type: ExperiencePoolRetrievalType = Field( default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool." ) - use_llm_ranker: bool = Field(default=False, description="Use LLM Reranker to get better result.") + use_llm_ranker: bool = Field(default=True, description="Use LLM Reranker to get better result.") diff --git a/metagpt/rag/factories/ranker.py b/metagpt/rag/factories/ranker.py index 7abda162a..c825c228c 100644 --- a/metagpt/rag/factories/ranker.py +++ b/metagpt/rag/factories/ranker.py @@ -38,6 +38,7 @@ class RankerFactory(ConfigBasedFactory): def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank: config.llm = self._extract_llm(config, **kwargs) + return LLMRerank(**config.model_dump()) def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank: diff --git a/metagpt/rag/prompts/__init__.py b/metagpt/rag/prompts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metagpt/rag/prompts/default_prompts.py b/metagpt/rag/prompts/default_prompts.py new file mode 100644 index 000000000..12a5e2f06 --- /dev/null +++ b/metagpt/rag/prompts/default_prompts.py @@ -0,0 +1,34 @@ +"""Set of default prompts.""" + +from llama_index.core.prompts.base import PromptTemplate +from llama_index.core.prompts.prompt_type import PromptType + +DEFAULT_CHOICE_SELECT_PROMPT_TMPL = """ +You are a highly efficient assistant, tasked with evaluating a list of documents to a given question. + +I will provide you with a question with a list of documents. Your task is to respond with the numbers of the documents you should consult to answer the question, in order of relevance, as well as the relevance score. + + +## Question +{query_str} + +## Documents +{context_str} + +## Format Example +Doc: 9, Relevance: 7 + +## Instructions +- Understand the question. +- Evaluate the relevance between the question and the documents. +- The relevance score is a number from 1-10 based on how relevant you think the document is to the question. +- Do not include any documents that are not relevant to the question. + +## Constraint +Format: Just print the result in format like **Format Example**. + +## Action +Follow instructions, generate output and make sure it follows the **Constraint**. +""" + +DEFAULT_CHOICE_SELECT_PROMPT = PromptTemplate(DEFAULT_CHOICE_SELECT_PROMPT_TMPL, prompt_type=PromptType.CHOICE_SELECT) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 5be2b050b..4180536a3 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -6,6 +6,7 @@ from typing import Any, ClassVar, List, Literal, Optional, Union from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex +from llama_index.core.prompts import BasePromptTemplate from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator @@ -14,6 +15,7 @@ from metagpt.config2 import Config from metagpt.configs.embedding_config import EmbeddingType from metagpt.logs import logger from metagpt.rag.interface import RAGObject +from metagpt.rag.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT class BaseRetrieverConfig(BaseModel): @@ -124,6 +126,9 @@ class LLMRankerConfig(BaseRankerConfig): default=None, description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.", ) + choice_select_prompt: Optional[BasePromptTemplate] = Field( + default=DEFAULT_CHOICE_SELECT_PROMPT, description="Choice select prompt." + ) class ColbertRerankConfig(BaseRankerConfig):