From 39ed967c80a55ad226f662f7e5123f3c79235920 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Thu, 2 May 2024 01:48:04 +0800 Subject: [PATCH 1/4] make rag configurable --- config/config2.example.yaml | 1 + metagpt/configs/embedding_config.py | 4 ++++ metagpt/rag/schema.py | 12 +++++++++--- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 3249f5ae3..6d1148a85 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -18,6 +18,7 @@ embedding: model: "" api_version: "" embed_batch_size: 100 + dimensions: repair_llm_output: true # when the output is not a valid json, try to repair it diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py index 20de47999..f9b41b9dc 100644 --- a/metagpt/configs/embedding_config.py +++ b/metagpt/configs/embedding_config.py @@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel): --------- api_type: "openai" api_key: "YOU_API_KEY" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "azure" api_key: "YOU_API_KEY" base_url: "YOU_BASE_URL" api_version: "YOU_API_VERSION" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "gemini" api_key: "YOU_API_KEY" @@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel): api_type: "ollama" base_url: "YOU_BASE_URL" model: "YOU_MODEL" + dimensions: "YOUR_MODEL_DIMENSIONS" """ api_type: Optional[EmbeddingType] = None @@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel): model: Optional[str] = None embed_batch_size: Optional[int] = None + dimensions: Optional[int] = None # output dimension of embedding model @field_validator("api_type", mode="before") @classmethod diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index e7b2e5ce9..bedba164c 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from metagpt.config2 import config from metagpt.configs.embedding_config import EmbeddingType +from metagpt.logs import logger from metagpt.rag.interface import RAGObject @@ -34,16 +35,21 @@ class IndexRetrieverConfig(BaseRetrieverConfig): class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" - dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") + dimensions: int = Field( + default=config.embedding.dimensions, description="Dimensionality of the vectors for FAISS index construction." + ) _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { EmbeddingType.GEMINI: 768, - EmbeddingType.OLLAMA: 4096, } @model_validator(mode="after") def check_dimensions(self): - if self.dimensions == 0: + if self.dimensions is None: + if config.embedding.api_type not in self._embedding_type_to_dimensions: + logger.info( + f"You didn't set the dimensions in config when using {config.embedding.api_type}, default to 1536" + ) self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) return self From fa8b35cef4fb4de9db2b0f0b79cfe85718e26775 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Fri, 3 May 2024 16:46:41 +0800 Subject: [PATCH 2/4] add comment --- config/config2.example.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 6d1148a85..e57ec3ee8 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -18,7 +18,7 @@ embedding: model: "" api_version: "" embed_batch_size: 100 - dimensions: + dimensions: # output dimension of embedding model repair_llm_output: true # when the output is not a valid json, try to repair it From 553702fa6131b042c0af272de00860a27e506c31 Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Tue, 7 May 2024 12:39:18 +0800 Subject: [PATCH 3/4] make dimensions default --- metagpt/rag/schema.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index bedba164c..ccd727687 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -35,22 +35,23 @@ class IndexRetrieverConfig(BaseRetrieverConfig): class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" - dimensions: int = Field( - default=config.embedding.dimensions, description="Dimensionality of the vectors for FAISS index construction." - ) + dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { EmbeddingType.GEMINI: 768, + EmbeddingType.OLLAMA: 4096, } @model_validator(mode="after") def check_dimensions(self): - if self.dimensions is None: + if self.dimensions == 0: + self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( + config.embedding.api_type, 1536 + ) if config.embedding.api_type not in self._embedding_type_to_dimensions: - logger.info( - f"You didn't set the dimensions in config when using {config.embedding.api_type}, default to 1536" + logger.warning( + f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536" ) - self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) return self From d31d7507ef893e344af52e01ae07eace5b3962df Mon Sep 17 00:00:00 2001 From: usamimeri_renko <1710269958@qq.com> Date: Tue, 7 May 2024 16:27:25 +0800 Subject: [PATCH 4/4] fix --- metagpt/rag/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index ccd727687..618880a22 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -48,7 +48,7 @@ class FAISSRetrieverConfig(IndexRetrieverConfig): self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( config.embedding.api_type, 1536 ) - if config.embedding.api_type not in self._embedding_type_to_dimensions: + if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions: logger.warning( f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536" )