mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge pull request #1172 from seehi/feat-rag-embedding
Make RAG embedding configurable and add gpt-4-turbo in token_counter.
This commit is contained in:
commit
63ca5452bc
12 changed files with 268 additions and 37 deletions
|
|
@ -13,6 +13,16 @@ llm:
|
|||
# - gpt-4 8k: "gpt-4"
|
||||
# See for more: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
|
||||
|
||||
# RAG Embedding.
|
||||
# For backward compatibility, if the embedding is not set and the llm's api_type is either openai or azure, the llm's config will be used.
|
||||
embedding:
|
||||
api_type: "" # openai / azure / gemini / ollama etc. Check EmbeddingType for more options.
|
||||
base_url: ""
|
||||
api_key: ""
|
||||
model: ""
|
||||
api_version: ""
|
||||
embed_batch_size: 100
|
||||
|
||||
repair_llm_output: true # when the output is not a valid json, try to repair it
|
||||
|
||||
proxy: "YOUR_PROXY" # for tools like requests, playwright, selenium, etc.
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
|
|||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
ChromaIndexConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
|
|
@ -51,7 +50,7 @@ class RAGExample:
|
|||
if not self._engine:
|
||||
self._engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
return self._engine
|
||||
|
|
@ -61,7 +60,7 @@ class RAGExample:
|
|||
self._engine = value
|
||||
|
||||
async def run_pipeline(self, question=QUESTION, print_title=True):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
"""This example run rag pipeline, use faiss retriever and llm ranker, will print something like:
|
||||
|
||||
Retrieve Result:
|
||||
0. Productivi..., 10.0
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional
|
|||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
|
|
@ -47,6 +48,9 @@ class Config(CLIParams, YamlModel):
|
|||
# Key Parameters
|
||||
llm: LLMConfig
|
||||
|
||||
# RAG Embedding
|
||||
embedding: EmbeddingConfig = EmbeddingConfig()
|
||||
|
||||
# Global Proxy. Will be used if llm.proxy is not set
|
||||
proxy: str = ""
|
||||
|
||||
|
|
|
|||
50
metagpt/configs/embedding_config.py
Normal file
50
metagpt/configs/embedding_config.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class EmbeddingType(Enum):
|
||||
OPENAI = "openai"
|
||||
AZURE = "azure"
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class EmbeddingConfig(YamlModel):
|
||||
"""Config for Embedding.
|
||||
|
||||
Examples:
|
||||
---------
|
||||
api_type: "openai"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
||||
api_type: "azure"
|
||||
api_key: "YOU_API_KEY"
|
||||
base_url: "YOU_BASE_URL"
|
||||
api_version: "YOU_API_VERSION"
|
||||
|
||||
api_type: "gemini"
|
||||
api_key: "YOU_API_KEY"
|
||||
|
||||
api_type: "ollama"
|
||||
base_url: "YOU_BASE_URL"
|
||||
model: "YOU_MODEL"
|
||||
"""
|
||||
|
||||
api_type: Optional[EmbeddingType] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
|
||||
model: Optional[str] = None
|
||||
embed_batch_size: Optional[int] = None
|
||||
|
||||
@field_validator("api_type", mode="before")
|
||||
@classmethod
|
||||
def check_api_type(cls, v):
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
|
@ -26,6 +26,9 @@ class GenericFactory:
|
|||
if creator:
|
||||
return creator(**kwargs)
|
||||
|
||||
self._raise_for_key(key)
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"Creator not registered for key: {key}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,37 +1,103 @@
|
|||
"""RAG Embedding Factory."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.embeddings.gemini import GeminiEmbedding
|
||||
from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.base import GenericFactory
|
||||
|
||||
|
||||
class RAGEmbeddingFactory(GenericFactory):
|
||||
"""Create LlamaIndex Embedding with MetaGPT's config."""
|
||||
"""Create LlamaIndex Embedding with MetaGPT's embedding config."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
EmbeddingType.OPENAI: self._create_openai,
|
||||
EmbeddingType.AZURE: self._create_azure,
|
||||
EmbeddingType.GEMINI: self._create_gemini,
|
||||
EmbeddingType.OLLAMA: self._create_ollama,
|
||||
# For backward compatibility
|
||||
LLMType.OPENAI: self._create_openai,
|
||||
LLMType.AZURE: self._create_azure,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
|
||||
"""Key is LLMType, default use config.llm.api_type."""
|
||||
return super().get_instance(key or config.llm.api_type)
|
||||
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
|
||||
"""Key is EmbeddingType."""
|
||||
return super().get_instance(key or self._resolve_embedding_type())
|
||||
|
||||
def _create_openai(self):
|
||||
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
|
||||
def _resolve_embedding_type(self) -> EmbeddingType | LLMType:
|
||||
"""Resolves the embedding type.
|
||||
|
||||
def _create_azure(self):
|
||||
return AzureOpenAIEmbedding(
|
||||
azure_endpoint=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE.
|
||||
Raise TypeError if embedding type not found.
|
||||
"""
|
||||
if config.embedding.api_type:
|
||||
return config.embedding.api_type
|
||||
|
||||
if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
|
||||
return config.llm.api_type
|
||||
|
||||
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
|
||||
|
||||
def _create_openai(self) -> OpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
api_base=config.embedding.base_url or config.llm.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return OpenAIEmbedding(**params)
|
||||
|
||||
def _create_azure(self) -> AzureOpenAIEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key or config.llm.api_key,
|
||||
azure_endpoint=config.embedding.base_url or config.llm.base_url,
|
||||
api_version=config.embedding.api_version or config.llm.api_version,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return AzureOpenAIEmbedding(**params)
|
||||
|
||||
def _create_gemini(self) -> GeminiEmbedding:
|
||||
params = dict(
|
||||
api_key=config.embedding.api_key,
|
||||
api_base=config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return GeminiEmbedding(**params)
|
||||
|
||||
def _create_ollama(self) -> OllamaEmbedding:
|
||||
params = dict(
|
||||
base_url=config.embedding.base_url,
|
||||
)
|
||||
|
||||
self._try_set_model_and_batch_size(params)
|
||||
|
||||
return OllamaEmbedding(**params)
|
||||
|
||||
def _try_set_model_and_batch_size(self, params: dict):
|
||||
"""Set the model_name and embed_batch_size only when they are specified."""
|
||||
if config.embedding.model:
|
||||
params["model_name"] = config.embedding.model
|
||||
|
||||
if config.embedding.embed_batch_size:
|
||||
params["embed_batch_size"] = config.embedding.embed_batch_size
|
||||
|
||||
def _raise_for_key(self, key: Any):
|
||||
raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}")
|
||||
|
||||
|
||||
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""RAG LLM."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
|
||||
|
|
@ -15,7 +15,7 @@ from pydantic import Field
|
|||
from metagpt.config2 import config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.async_helper import run_coroutine_in_new_loop
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.token_counter import TOKEN_MAX
|
||||
|
||||
|
||||
|
|
@ -39,7 +39,8 @@ class RAGLLM(CustomLLM):
|
|||
|
||||
@llm_completion_callback()
|
||||
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
|
||||
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
|
||||
NestAsyncio.apply_once()
|
||||
return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs))
|
||||
|
||||
@llm_completion_callback()
|
||||
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Any, ClassVar, Literal, Union
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
|
|
@ -31,7 +33,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig):
|
|||
class FAISSRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for FAISS-based retrievers."""
|
||||
|
||||
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
|
||||
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
|
||||
EmbeddingType.GEMINI: 768,
|
||||
EmbeddingType.OLLAMA: 4096,
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_dimensions(self):
|
||||
if self.dimensions == 0:
|
||||
self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
|
|
|
|||
|
|
@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
|
|||
new_loop.call_soon_threadsafe(new_loop.stop)
|
||||
t.join()
|
||||
new_loop.close()
|
||||
|
||||
|
||||
class NestAsyncio:
|
||||
"""Make asyncio event loop reentrant."""
|
||||
|
||||
is_applied = False
|
||||
|
||||
@classmethod
|
||||
def apply_once(cls):
|
||||
"""Ensures `nest_asyncio.apply()` is called only once."""
|
||||
if not cls.is_applied:
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
cls.is_applied = True
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ TOKEN_COSTS = {
|
|||
"gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
|
||||
"gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-turbo": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
|
||||
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
|
||||
|
|
@ -147,6 +148,7 @@ FIREWORKS_GRADE_TOKEN_COSTS = {
|
|||
TOKEN_MAX = {
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4-1106-vision-preview": 128000,
|
||||
|
|
@ -202,6 +204,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
|
|||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-0125-preview",
|
||||
"gpt-4-1106-preview",
|
||||
|
|
|
|||
3
setup.py
3
setup.py
|
|
@ -32,12 +32,15 @@ extras_require = {
|
|||
"llama-index-core==0.10.15",
|
||||
"llama-index-embeddings-azure-openai==0.1.6",
|
||||
"llama-index-embeddings-openai==0.1.5",
|
||||
"llama-index-embeddings-gemini==0.1.6",
|
||||
"llama-index-embeddings-ollama==0.1.2",
|
||||
"llama-index-llms-azure-openai==0.1.4",
|
||||
"llama-index-readers-file==0.1.4",
|
||||
"llama-index-retrievers-bm25==0.1.3",
|
||||
"llama-index-vector-stores-faiss==0.1.1",
|
||||
"llama-index-vector-stores-elasticsearch==0.1.6",
|
||||
"llama-index-vector-stores-chroma==0.1.6",
|
||||
"docx2txt==0.8",
|
||||
],
|
||||
"android_assistant": ["pyshine==0.0.9", "opencv-python==4.6.0.66"],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
|
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
|
|||
self.embedding_factory = RAGEmbeddingFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_embedding(self, mocker):
|
||||
def mock_config(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
|
||||
@staticmethod
|
||||
def mock_openai_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_embedding(self, mocker):
|
||||
@staticmethod
|
||||
def mock_azure_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
|
||||
|
||||
def test_get_rag_embedding_openai(self, mock_openai_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
|
||||
@staticmethod
|
||||
def mock_gemini_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
@staticmethod
|
||||
def mock_ollama_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
|
||||
|
||||
def test_get_rag_embedding_azure(self, mock_azure_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
|
||||
|
||||
# Assert
|
||||
mock_azure_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
|
||||
@pytest.mark.parametrize(
|
||||
("mock_func", "embedding_type"),
|
||||
[
|
||||
(mock_openai_embedding, LLMType.OPENAI),
|
||||
(mock_azure_embedding, LLMType.AZURE),
|
||||
(mock_openai_embedding, EmbeddingType.OPENAI),
|
||||
(mock_azure_embedding, EmbeddingType.AZURE),
|
||||
(mock_gemini_embedding, EmbeddingType.GEMINI),
|
||||
(mock_ollama_embedding, EmbeddingType.OLLAMA),
|
||||
],
|
||||
)
|
||||
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
|
||||
# Mock
|
||||
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
mock = mock_func(mocker)
|
||||
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(embedding_type)
|
||||
|
||||
# Assert
|
||||
mock.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_config):
|
||||
# Mock
|
||||
mock_openai_embedding = self.mock_openai_embedding(mocker)
|
||||
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.OPENAI
|
||||
|
||||
# Exec
|
||||
|
|
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
|
|||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, embed_batch_size, expected_params",
|
||||
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
|
||||
)
|
||||
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
|
||||
# Mock
|
||||
mock_config.embedding.model = model
|
||||
mock_config.embedding.embed_batch_size = embed_batch_size
|
||||
|
||||
# Setup
|
||||
test_params = {}
|
||||
|
||||
# Exec
|
||||
self.embedding_factory._try_set_model_and_batch_size(test_params)
|
||||
|
||||
# Assert
|
||||
assert test_params == expected_params
|
||||
|
||||
def test_resolve_embedding_type(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = EmbeddingType.OPENAI
|
||||
|
||||
# Exec
|
||||
embedding_type = self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
# Assert
|
||||
assert embedding_type == EmbeddingType.OPENAI
|
||||
|
||||
def test_resolve_embedding_type_exception(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.GEMINI
|
||||
|
||||
# Assert
|
||||
with pytest.raises(TypeError):
|
||||
self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
def test_raise_for_key(self):
|
||||
with pytest.raises(ValueError):
|
||||
self.embedding_factory._raise_for_key("key")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue