diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 3249f5ae3..6d1148a85 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -18,6 +18,7 @@ embedding: model: "" api_version: "" embed_batch_size: 100 + dimensions: repair_llm_output: true # when the output is not a valid json, try to repair it diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py index 20de47999..f9b41b9dc 100644 --- a/metagpt/configs/embedding_config.py +++ b/metagpt/configs/embedding_config.py @@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel): --------- api_type: "openai" api_key: "YOU_API_KEY" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "azure" api_key: "YOU_API_KEY" base_url: "YOU_BASE_URL" api_version: "YOU_API_VERSION" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "gemini" api_key: "YOU_API_KEY" @@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel): api_type: "ollama" base_url: "YOU_BASE_URL" model: "YOU_MODEL" + dimensions: "YOUR_MODEL_DIMENSIONS" """ api_type: Optional[EmbeddingType] = None @@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel): model: Optional[str] = None embed_batch_size: Optional[int] = None + dimensions: Optional[int] = None # output dimension of embedding model @field_validator("api_type", mode="before") @classmethod diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index e7b2e5ce9..bedba164c 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from metagpt.config2 import config from metagpt.configs.embedding_config import EmbeddingType +from metagpt.logs import logger from metagpt.rag.interface import RAGObject @@ -34,16 +35,21 @@ class IndexRetrieverConfig(BaseRetrieverConfig): class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" - dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") + dimensions: int = Field( + default=config.embedding.dimensions, description="Dimensionality of the vectors for FAISS index construction." + ) _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { EmbeddingType.GEMINI: 768, - EmbeddingType.OLLAMA: 4096, } @model_validator(mode="after") def check_dimensions(self): - if self.dimensions == 0: + if self.dimensions is None: + if config.embedding.api_type not in self._embedding_type_to_dimensions: + logger.info( + f"You didn't set the dimensions in config when using {config.embedding.api_type}, default to 1536" + ) self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) return self