make rag configurable

This commit is contained in:
usamimeri_renko 2024-05-02 01:48:04 +08:00
parent f201b2f5f3
commit 39ed967c80
3 changed files with 14 additions and 3 deletions

View file

@@ -18,6 +18,7 @@ embedding:
model: ""
api_version: ""
embed_batch_size: 100
dimensions:
repair_llm_output: true # when the output is not a valid json, try to repair it

View file

@@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel):
---------
api_type: "openai"
api_key: "YOU_API_KEY"
dimensions: "YOUR_MODEL_DIMENSIONS"
api_type: "azure"
api_key: "YOU_API_KEY"
base_url: "YOU_BASE_URL"
api_version: "YOU_API_VERSION"
dimensions: "YOUR_MODEL_DIMENSIONS"
api_type: "gemini"
api_key: "YOU_API_KEY"
@@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel):
api_type: "ollama"
base_url: "YOU_BASE_URL"
model: "YOU_MODEL"
dimensions: "YOUR_MODEL_DIMENSIONS"
"""
api_type: Optional[EmbeddingType] = None
@@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel):
model: Optional[str] = None
embed_batch_size: Optional[int] = None
dimensions: Optional[int] = None # output dimension of embedding model
@field_validator("api_type", mode="before")
@classmethod

View file

@@ -12,6 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.logs import logger
from metagpt.rag.interface import RAGObject
@@ -34,16 +35,21 @@ class IndexRetrieverConfig(BaseRetrieverConfig):
class FAISSRetrieverConfig(IndexRetrieverConfig):
"""Config for FAISS-based retrievers."""
-dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")
+dimensions: int = Field(
+default=config.embedding.dimensions, description="Dimensionality of the vectors for FAISS index construction."
+)
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
EmbeddingType.GEMINI: 768,
EmbeddingType.OLLAMA: 4096,
}
@model_validator(mode="after")
def check_dimensions(self):
-if self.dimensions == 0:
+if self.dimensions is None:
if config.embedding.api_type not in self._embedding_type_to_dimensions:
logger.info(
f"You didn't set the dimensions in config when using {config.embedding.api_type}, default to 1536"
)
self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)
return self