Replace RAG LLM factory with LlamaIndex CustomLLM

seehi 2024-03-08 20:19:28 +08:00
parent 4712b2136b
commit 9fe9a4a2d1
9 changed files with 87 additions and 142 deletions
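For orientation, a minimal before/after sketch of what the change means for callers (distilled from the diffs below; `llm` is just an illustrative variable name): the LlamaIndex LLM is no longer built by a provider-specific factory, it now comes from wrapping MetaGPT's configured LLM in a LlamaIndex CustomLLM.

# Before this commit: the factory selected OpenAI/Azure/Gemini/Ollama from config
from metagpt.rag.factories import get_rag_llm

# After this commit: RAGLLM (a CustomLLM) wraps MetaGPT's own LLM
from metagpt.rag.llm import get_rag_llm

llm = get_rag_llm()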

View file

@@ -4,6 +4,7 @@ import asyncio
from pydantic import BaseModel
from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import (
BM25RetrieverConfig,
@@ -85,10 +86,10 @@ class RAGExample:
travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
travel_filepath = TRAVEL_DOC_PATH
print("[Before add docs]")
logger.info("[Before add docs]")
await self.rag_pipeline(question=travel_question, print_title=False)
print("[After add docs]")
logger.info("[After add docs]")
self.engine.add_docs([travel_filepath])
await self.rag_pipeline(question=travel_question, print_title=False)
@@ -110,19 +111,19 @@ class RAGExample:
player = Player(name="Mike")
question = f"{player.rag_key()}"
print("[Before add objs]")
logger.info("[Before add objs]")
await self._retrieve_and_print(question)
print("[After add objs]")
logger.info("[After add objs]")
self.engine.add_objs([player])
nodes = await self._retrieve_and_print(question)
print("[Object Detail]")
logger.info("[Object Detail]")
try:
player: Player = nodes[0].metadata["obj"]
print(player.name)
logger.info(player.name)
except Exception as e:
print(f"ERROR: nodes is empty, llm don't answer correctly, exception: {e}")
logger.info(f"ERROR: nodes is empty, llm don't answer correctly, exception: {e}")
async def rag_ini_objs(self):
"""This example show how to from objs, will print something like:
@@ -162,20 +163,20 @@ class RAGExample:
@staticmethod
def _print_title(title):
print(f"{'#'*50} {title} {'#'*50}")
logger.info(f"{'#'*30} {title} {'#'*30}")
@staticmethod
def _print_result(result, state="Retrieve"):
"""print retrieve or query result"""
print(f"{state} Result:")
logger.info(f"{state} Result:")
if state == "Retrieve":
for i, node in enumerate(result):
print(f"{i}. {node.text[:10]}..., {node.score}")
print()
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
logger.info("")
return
print(f"{result}\n")
logger.info(f"{result}\n")
async def _retrieve_and_print(self, question):
nodes = await self.engine.aretrieve(question)

View file

@@ -29,11 +29,11 @@ from llama_index.core.schema import (
from metagpt.rag.factories import (
get_index,
get_rag_embedding,
get_rag_llm,
get_rankers,
get_retriever,
)
from metagpt.rag.interface import RAGObject
from metagpt.rag.llm import get_rag_llm
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import (
BaseIndexConfig,

View file

@@ -1,8 +1,7 @@
"""RAG factories"""
from metagpt.rag.factories.retriever import get_retriever
from metagpt.rag.factories.ranker import get_rankers
from metagpt.rag.factories.llm import get_rag_llm
from metagpt.rag.factories.embedding import get_rag_embedding
from metagpt.rag.factories.index import get_index
__all__ = ["get_retriever", "get_rankers", "get_rag_llm", "get_rag_embedding", "get_index"]
__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index"]

View file

@@ -1,7 +1,4 @@
"""RAG LLM Factory.
The LLM of LlamaIndex and the LLM of MG are not the same.
"""
"""RAG Embedding Factory."""
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
@@ -12,7 +9,7 @@ from metagpt.rag.factories.base import GenericFactory
class RAGEmbeddingFactory(GenericFactory):
"""Create LlamaIndex LLM with MG config."""
"""Create LlamaIndex Embedding with MetaGPT's config."""
def __init__(self):
creators = {

View file

@@ -1,65 +0,0 @@
"""RAG LLM Factory.
The LLM of LlamaIndex and the LLM of MG are not the same.
"""
from llama_index.core.llms import LLM
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.gemini import Gemini
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory
class RAGLLMFactory(GenericFactory):
"""Create LlamaIndex LLM with MG config."""
def __init__(self):
creators = {
LLMType.OPENAI: self._create_openai,
LLMType.AZURE: self._create_azure,
LLMType.GEMINI: self._create_gemini,
LLMType.OLLAMA: self._create_ollama,
}
super().__init__(creators)
def get_rag_llm(self, key: LLMType = None) -> LLM:
"""Key is LLMType, default use config.llm.api_type."""
return super().get_instance(key or config.llm.api_type)
def _create_openai(self):
return OpenAI(
api_base=config.llm.base_url,
api_key=config.llm.api_key,
api_version=config.llm.api_version,
model=config.llm.model,
max_tokens=config.llm.max_token,
temperature=config.llm.temperature,
)
def _create_azure(self):
return AzureOpenAI(
azure_endpoint=config.llm.base_url,
api_key=config.llm.api_key,
api_version=config.llm.api_version,
deployment_name=config.llm.model,
max_tokens=config.llm.max_token,
temperature=config.llm.temperature,
)
def _create_gemini(self):
return Gemini(
api_base=config.llm.base_url,
api_key=config.llm.api_key,
model_name=config.llm.model,
max_tokens=config.llm.max_token,
temperature=config.llm.temperature,
)
def _create_ollama(self):
return Ollama(base_url=config.llm.base_url, model=config.llm.model, temperature=config.llm.temperature)
get_rag_llm = RAGLLMFactory().get_rag_llm

metagpt/rag/llm.py (new file, 48 lines added)
View file

@@ -0,0 +1,48 @@
"""RAG LLM."""
from typing import Any
from llama_index.core.llms import (
CompletionResponse,
CompletionResponseGen,
CustomLLM,
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.async_helper import run_coroutine_in_new_loop
class RAGLLM(CustomLLM):
"""LlamaIndex's LLM is different from MetaGPT's LLM.
Inherit CustomLLM from LlamaIndex, so that MetaGPT's LLM can be used by LlamaIndex.
"""
model_infer: BaseLLM
model_name: str = config.llm.model
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(model_name=self.model_name)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
@llm_completion_callback()
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
text = await self.model_infer.aask(msg=prompt, stream=False)
return CompletionResponse(text=text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
...
def get_rag_llm(model_infer: BaseLLM = None):
"""Get llm that can be used by LlamaIndex."""
return RAGLLM(model_infer=model_infer or LLM())
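Usage note (an illustrative sketch, not part of the diff): every name below comes from the new file above; only the prompt string is a placeholder. The wrapper behaves like any LlamaIndex LLM, so it can be called directly or passed wherever LlamaIndex expects an LLM instance.

from metagpt.rag.llm import get_rag_llm

llm = get_rag_llm()                # wraps the LLM built from MetaGPT's config
response = llm.complete("hello")   # sync entry point; runs the async aask() on a new event loop
print(response.text)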

View file

@@ -0,0 +1,22 @@
import asyncio
import threading
from typing import Any
def run_coroutine_in_new_loop(coroutine) -> Any:
"""Runs a coroutine in a new, separate event loop on a different thread.
This function is useful when trying to execute an async function from within a sync function but encountering the error `RuntimeError: This event loop is already running`.
"""
new_loop = asyncio.new_event_loop()
t = threading.Thread(target=lambda: new_loop.run_forever())
t.start()
future = asyncio.run_coroutine_threadsafe(coroutine, new_loop)
try:
return future.result()
finally:
new_loop.call_soon_threadsafe(new_loop.stop)
t.join()
new_loop.close()
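Usage note (an illustrative sketch, not part of the diff): fetch_answer and sync_caller are hypothetical names; the helper itself is the one defined above.

import asyncio

from metagpt.utils.async_helper import run_coroutine_in_new_loop

async def fetch_answer() -> str:
    # Hypothetical coroutine standing in for an async call such as aask().
    await asyncio.sleep(0.1)
    return "done"

def sync_caller() -> str:
    # Safe even when the calling thread already has a running event loop,
    # because the coroutine executes on a fresh loop in a separate thread.
    return run_coroutine_in_new_loop(fetch_answer())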

View file

@@ -14,9 +14,6 @@ llama-index-core==0.10.15
llama-index-embeddings-azure-openai==0.1.6
llama-index-embeddings-openai==0.1.5
llama-index-llms-azure-openai==0.1.4
llama-index-llms-gemini==0.1.4
llama-index-llms-ollama==0.1.2
llama-index-llms-openai==0.1.5
llama-index-readers-file==0.1.4
llama-index-retrievers-bm25==0.1.3
llama-index-vector-stores-faiss==0.1.1

View file

@@ -1,54 +0,0 @@
import pytest
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.gemini import Gemini
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.llm import RAGLLMFactory
class TestRAGLLMFactory:
@pytest.fixture(autouse=True)
def setup(self, mocker):
# Mock the config object for all tests in this class
self.mock_config = mocker.MagicMock()
self.mock_config.llm.api_type = LLMType.OPENAI
self.mock_config.llm.base_url = "http://example.com"
self.mock_config.llm.api_key = "test_api_key"
self.mock_config.llm.api_version = "v1"
self.mock_config.llm.model = "test_model"
self.mock_config.llm.max_token = 100
self.mock_config.llm.temperature = 0.5
mocker.patch("metagpt.rag.factories.llm.config", self.mock_config)
self.factory = RAGLLMFactory()
@pytest.mark.parametrize(
"llm_type,expected_class",
[
(LLMType.OPENAI, OpenAI),
(LLMType.AZURE, AzureOpenAI),
(LLMType.GEMINI, Gemini),
(LLMType.OLLAMA, Ollama),
],
)
def test_creates_correct_llm_instance(self, llm_type, expected_class, mocker):
# Mock the LLM constructors
mocker.patch.object(expected_class, "__init__", return_value=None)
instance = self.factory.get_rag_llm(key=llm_type)
assert isinstance(instance, expected_class)
expected_class.__init__.assert_called_once()
def test_uses_default_llm_type_when_no_key_provided(self, mocker):
# Assume the default API type is OPENAI for this test
mock = mocker.patch.object(OpenAI, "__init__", return_value=None)
instance = self.factory.get_rag_llm()
assert isinstance(instance, OpenAI)
mock.assert_called_once_with(
api_base=self.mock_config.llm.base_url,
api_key=self.mock_config.llm.api_key,
api_version=self.mock_config.llm.api_version,
model=self.mock_config.llm.model,
max_tokens=self.mock_config.llm.max_token,
temperature=self.mock_config.llm.temperature,
)