diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py
index 83e147dd9..bdb22cb4a 100644
--- a/examples/rag_pipeline.py
+++ b/examples/rag_pipeline.py
@@ -4,6 +4,7 @@ import asyncio
 from pydantic import BaseModel
 
 from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
+from metagpt.logs import logger
 from metagpt.rag.engines import SimpleEngine
 from metagpt.rag.schema import (
     BM25RetrieverConfig,
@@ -85,10 +86,10 @@ class RAGExample:
         travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
         travel_filepath = TRAVEL_DOC_PATH
 
-        print("[Before add docs]")
+        logger.info("[Before add docs]")
         await self.rag_pipeline(question=travel_question, print_title=False)
 
-        print("[After add docs]")
+        logger.info("[After add docs]")
         self.engine.add_docs([travel_filepath])
         await self.rag_pipeline(question=travel_question, print_title=False)
 
@@ -110,19 +111,19 @@ class RAGExample:
         player = Player(name="Mike")
         question = f"{player.rag_key()}"
 
-        print("[Before add objs]")
+        logger.info("[Before add objs]")
         await self._retrieve_and_print(question)
 
-        print("[After add objs]")
+        logger.info("[After add objs]")
         self.engine.add_objs([player])
         nodes = await self._retrieve_and_print(question)
 
-        print("[Object Detail]")
+        logger.info("[Object Detail]")
         try:
             player: Player = nodes[0].metadata["obj"]
-            print(player.name)
+            logger.info(player.name)
         except Exception as e:
-            print(f"ERROR: nodes is empty, llm don't answer correctly, exception: {e}")
+            logger.info(f"ERROR: nodes is empty, llm don't answer correctly, exception: {e}")
 
     async def rag_ini_objs(self):
         """This example show how to from objs, will print something like:
@@ -162,20 +163,20 @@ class RAGExample:
 
     @staticmethod
     def _print_title(title):
-        print(f"{'#'*50} {title} {'#'*50}")
+        logger.info(f"{'#'*30} {title} {'#'*30}")
 
     @staticmethod
     def _print_result(result, state="Retrieve"):
         """print retrieve or query result"""
-        print(f"{state} Result:")
+        logger.info(f"{state} Result:")
 
         if state == "Retrieve":
             for i, node in enumerate(result):
-                print(f"{i}. {node.text[:10]}..., {node.score}")
-            print()
+                logger.info(f"{i}. {node.text[:10]}..., {node.score}")
+            logger.info("")
             return
 
-        print(f"{result}\n")
+        logger.info(f"{result}\n")
 
     async def _retrieve_and_print(self, question):
         nodes = await self.engine.aretrieve(question)
diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py
index cceb9dd03..9afee9b35 100644
--- a/metagpt/rag/engines/simple.py
+++ b/metagpt/rag/engines/simple.py
@@ -29,11 +29,11 @@ from llama_index.core.schema import (
 from metagpt.rag.factories import (
     get_index,
     get_rag_embedding,
-    get_rag_llm,
     get_rankers,
     get_retriever,
 )
 from metagpt.rag.interface import RAGObject
+from metagpt.rag.llm import get_rag_llm
 from metagpt.rag.retrievers.base import ModifiableRAGRetriever
 from metagpt.rag.schema import (
     BaseIndexConfig,
diff --git a/metagpt/rag/factories/__init__.py b/metagpt/rag/factories/__init__.py
index df2d38502..d7fcc27ed 100644
--- a/metagpt/rag/factories/__init__.py
+++ b/metagpt/rag/factories/__init__.py
@@ -1,8 +1,7 @@
 """RAG factories"""
 from metagpt.rag.factories.retriever import get_retriever
 from metagpt.rag.factories.ranker import get_rankers
-from metagpt.rag.factories.llm import get_rag_llm
 from metagpt.rag.factories.embedding import get_rag_embedding
 from metagpt.rag.factories.index import get_index
 
-__all__ = ["get_retriever", "get_rankers", "get_rag_llm", "get_rag_embedding", "get_index"]
+__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index"]
diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py
index 67c2f3d06..ebabf7b8a 100644
--- a/metagpt/rag/factories/embedding.py
+++ b/metagpt/rag/factories/embedding.py
@@ -1,7 +1,4 @@
-"""RAG LLM Factory.
-
-The LLM of LlamaIndex and the LLM of MG are not the same.
-"""
+"""RAG Embedding Factory."""
 from llama_index.core.embeddings import BaseEmbedding
 from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
 from llama_index.embeddings.openai import OpenAIEmbedding
@@ -12,7 +9,7 @@ from metagpt.rag.factories.base import GenericFactory
 
 
 class RAGEmbeddingFactory(GenericFactory):
-    """Create LlamaIndex LLM with MG config."""
+    """Create LlamaIndex Embedding with MetaGPT's config."""
 
     def __init__(self):
         creators = {
diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py
deleted file mode 100644
index c5d12079e..000000000
--- a/metagpt/rag/factories/llm.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""RAG LLM Factory.
-
-The LLM of LlamaIndex and the LLM of MG are not the same.
-"""
-from llama_index.core.llms import LLM
-from llama_index.llms.azure_openai import AzureOpenAI
-from llama_index.llms.gemini import Gemini
-from llama_index.llms.ollama import Ollama
-from llama_index.llms.openai import OpenAI
-
-from metagpt.config2 import config
-from metagpt.configs.llm_config import LLMType
-from metagpt.rag.factories.base import GenericFactory
-
-
-class RAGLLMFactory(GenericFactory):
-    """Create LlamaIndex LLM with MG config."""
-
-    def __init__(self):
-        creators = {
-            LLMType.OPENAI: self._create_openai,
-            LLMType.AZURE: self._create_azure,
-            LLMType.GEMINI: self._create_gemini,
-            LLMType.OLLAMA: self._create_ollama,
-        }
-        super().__init__(creators)
-
-    def get_rag_llm(self, key: LLMType = None) -> LLM:
-        """Key is LLMType, default use config.llm.api_type."""
-        return super().get_instance(key or config.llm.api_type)
-
-    def _create_openai(self):
-        return OpenAI(
-            api_base=config.llm.base_url,
-            api_key=config.llm.api_key,
-            api_version=config.llm.api_version,
-            model=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_azure(self):
-        return AzureOpenAI(
-            azure_endpoint=config.llm.base_url,
-            api_key=config.llm.api_key,
-            api_version=config.llm.api_version,
-            deployment_name=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_gemini(self):
-        return Gemini(
-            api_base=config.llm.base_url,
-            api_key=config.llm.api_key,
-            model_name=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_ollama(self):
-        return Ollama(base_url=config.llm.base_url, model=config.llm.model, temperature=config.llm.temperature)
-
-
-get_rag_llm = RAGLLMFactory().get_rag_llm
diff --git a/metagpt/rag/llm.py b/metagpt/rag/llm.py
new file mode 100644
index 000000000..81ac4e1b7
--- /dev/null
+++ b/metagpt/rag/llm.py
@@ -0,0 +1,48 @@
+"""RAG LLM."""
+from typing import Any
+
+from llama_index.core.llms import (
+    CompletionResponse,
+    CompletionResponseGen,
+    CustomLLM,
+    LLMMetadata,
+)
+from llama_index.core.llms.callbacks import llm_completion_callback
+
+from metagpt.config2 import config
+from metagpt.llm import LLM
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.utils.async_helper import run_coroutine_in_new_loop
+
+
+class RAGLLM(CustomLLM):
+    """LlamaIndex's LLM is different from MetaGPT's LLM.
+
+    Inherit CustomLLM from LlamaIndex so that MetaGPT's LLM can be used by LlamaIndex.
+    """
+
+    model_infer: BaseLLM
+    model_name: str = config.llm.model
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(model_name=self.model_name)
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
+
+    @llm_completion_callback()
+    async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
+        text = await self.model_infer.aask(msg=prompt, stream=False)
+        return CompletionResponse(text=text)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        ...
+
+
+def get_rag_llm(model_infer: BaseLLM = None):
+    """Get an LLM that can be used by LlamaIndex."""
+    return RAGLLM(model_infer=model_infer or LLM())
diff --git a/metagpt/utils/async_helper.py b/metagpt/utils/async_helper.py
new file mode 100644
index 000000000..ee440ef44
--- /dev/null
+++ b/metagpt/utils/async_helper.py
@@ -0,0 +1,22 @@
+import asyncio
+import threading
+from typing import Any
+
+
+def run_coroutine_in_new_loop(coroutine) -> Any:
+    """Runs a coroutine in a new, separate event loop on a different thread.
+
+    This function is useful when trying to execute an async function from synchronous code and hitting the error `RuntimeError: This event loop is already running`.
+    """
+    new_loop = asyncio.new_event_loop()
+    t = threading.Thread(target=lambda: new_loop.run_forever())
+    t.start()
+
+    future = asyncio.run_coroutine_threadsafe(coroutine, new_loop)
+
+    try:
+        return future.result()
+    finally:
+        new_loop.call_soon_threadsafe(new_loop.stop)
+        t.join()
+        new_loop.close()
diff --git a/requirements.txt b/requirements.txt
index c5760899c..326fa8bb9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,9 +14,6 @@ llama-index-core==0.10.15
 llama-index-embeddings-azure-openai==0.1.6
 llama-index-embeddings-openai==0.1.5
 llama-index-llms-azure-openai==0.1.4
-llama-index-llms-gemini==0.1.4
-llama-index-llms-ollama==0.1.2
-llama-index-llms-openai==0.1.5
 llama-index-readers-file==0.1.4
 llama-index-retrievers-bm25==0.1.3
 llama-index-vector-stores-faiss==0.1.1
diff --git a/tests/metagpt/rag/factories/test_llm.py b/tests/metagpt/rag/factories/test_llm.py
deleted file mode 100644
index 94e3a8f67..000000000
--- a/tests/metagpt/rag/factories/test_llm.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import pytest
-from llama_index.llms.azure_openai import AzureOpenAI
-from llama_index.llms.gemini import Gemini
-from llama_index.llms.ollama import Ollama
-from llama_index.llms.openai import OpenAI
-
-from metagpt.configs.llm_config import LLMType
-from metagpt.rag.factories.llm import RAGLLMFactory
-
-
-class TestRAGLLMFactory:
-    @pytest.fixture(autouse=True)
-    def setup(self, mocker):
-        # Mock the config object for all tests in this class
-        self.mock_config = mocker.MagicMock()
-        self.mock_config.llm.api_type = LLMType.OPENAI
-        self.mock_config.llm.base_url = "http://example.com"
-        self.mock_config.llm.api_key = "test_api_key"
-        self.mock_config.llm.api_version = "v1"
-        self.mock_config.llm.model = "test_model"
-        self.mock_config.llm.max_token = 100
-        self.mock_config.llm.temperature = 0.5
-        mocker.patch("metagpt.rag.factories.llm.config", self.mock_config)
-        self.factory = RAGLLMFactory()
-
-    @pytest.mark.parametrize(
-        "llm_type,expected_class",
-        [
-            (LLMType.OPENAI, OpenAI),
-            (LLMType.AZURE, AzureOpenAI),
-            (LLMType.GEMINI, Gemini),
-            (LLMType.OLLAMA, Ollama),
-        ],
-    )
-    def test_creates_correct_llm_instance(self, llm_type, expected_class, mocker):
-        # Mock the LLM constructors
-        mocker.patch.object(expected_class, "__init__", return_value=None)
-        instance = self.factory.get_rag_llm(key=llm_type)
-        assert isinstance(instance, expected_class)
-        expected_class.__init__.assert_called_once()
-
-    def test_uses_default_llm_type_when_no_key_provided(self, mocker):
-        # Assume the default API type is OPENAI for this test
-        mock = mocker.patch.object(OpenAI, "__init__", return_value=None)
-        instance = self.factory.get_rag_llm()
-        assert isinstance(instance, OpenAI)
-        mock.assert_called_once_with(
-            api_base=self.mock_config.llm.base_url,
-            api_key=self.mock_config.llm.api_key,
-            api_version=self.mock_config.llm.api_version,
-            model=self.mock_config.llm.model,
-            max_tokens=self.mock_config.llm.max_token,
-            temperature=self.mock_config.llm.temperature,
-        )
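Note (not part of the patch): below is a minimal sketch of how the new `run_coroutine_in_new_loop` helper is intended to be used from synchronous code; the function names `fake_llm_call` and `sync_entry_point` are hypothetical and used only for illustration. `RAGLLM.complete` relies on the same helper to drive its async `acomplete` (and thus MetaGPT's `aask`) from LlamaIndex's synchronous call path, even when an event loop is already running in the calling thread.

```python
# Illustrative usage only; fake_llm_call and sync_entry_point are hypothetical names.
import asyncio

from metagpt.utils.async_helper import run_coroutine_in_new_loop


async def fake_llm_call(prompt: str) -> str:
    """Stand-in for an awaited LLM request such as BaseLLM.aask()."""
    await asyncio.sleep(0)
    return f"echo: {prompt}"


def sync_entry_point() -> str:
    # The coroutine runs on its own event loop in a worker thread, so this call
    # succeeds even if the calling thread already has a running event loop.
    return run_coroutine_in_new_loop(fake_llm_call("hello"))


if __name__ == "__main__":
    print(sync_entry_point())  # -> echo: hello
```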