Mirror of https://github.com/FoundationAgents/MetaGPT.git (synced 2026-05-03 04:42:38 +02:00)
replace rag llm factory with llamaindex custom llm

parent 4712b2136b
commit 9fe9a4a2d1

9 changed files with 87 additions and 142 deletions
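In short: the provider-specific LlamaIndex LLM factory (OpenAI, AzureOpenAI, Gemini, Ollama built from MetaGPT's config) is removed, and RAG components instead wrap MetaGPT's own LLM in a LlamaIndex `CustomLLM` adapter. A minimal usage sketch of the new path; the `Settings` wiring and the prompt are illustrative and not part of this diff:

```python
# A minimal sketch (not from this diff): wrap MetaGPT's configured LLM for LlamaIndex.
from llama_index.core import Settings

from metagpt.rag.llm import get_rag_llm

# get_rag_llm() returns a RAGLLM (a llama_index CustomLLM) backed by metagpt.llm.LLM().
rag_llm = get_rag_llm()

# Any LlamaIndex component that accepts an llm can now reuse MetaGPT's provider stack.
Settings.llm = rag_llm
print(rag_llm.complete("Reply with one word: ping").text)
```

The old factory returned provider-native LlamaIndex LLMs; the new adapter delegates every completion to MetaGPT's `BaseLLM.aask`, so RAG inherits whatever provider MetaGPT is configured with.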
@@ -4,6 +4,7 @@ import asyncio
 from pydantic import BaseModel
 
 from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
+from metagpt.logs import logger
 from metagpt.rag.engines import SimpleEngine
 from metagpt.rag.schema import (
     BM25RetrieverConfig,
@@ -85,10 +86,10 @@ class RAGExample:
         travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
         travel_filepath = TRAVEL_DOC_PATH
 
-        print("[Before add docs]")
+        logger.info("[Before add docs]")
         await self.rag_pipeline(question=travel_question, print_title=False)
 
-        print("[After add docs]")
+        logger.info("[After add docs]")
         self.engine.add_docs([travel_filepath])
         await self.rag_pipeline(question=travel_question, print_title=False)
 
@@ -110,19 +111,19 @@ class RAGExample:
         player = Player(name="Mike")
         question = f"{player.rag_key()}"
 
-        print("[Before add objs]")
+        logger.info("[Before add objs]")
         await self._retrieve_and_print(question)
 
-        print("[After add objs]")
+        logger.info("[After add objs]")
         self.engine.add_objs([player])
         nodes = await self._retrieve_and_print(question)
 
-        print("[Object Detail]")
+        logger.info("[Object Detail]")
         try:
             player: Player = nodes[0].metadata["obj"]
-            print(player.name)
+            logger.info(player.name)
         except Exception as e:
-            print(f"ERROR: nodes is empty, llm don't answer correctly, exception: {e}")
+            logger.info(f"ERROR: nodes is empty, llm don't answer correctly, exception: {e}")
 
     async def rag_ini_objs(self):
         """This example show how to from objs, will print something like:
@@ -162,20 +163,20 @@ class RAGExample:
 
     @staticmethod
     def _print_title(title):
-        print(f"{'#'*50} {title} {'#'*50}")
+        logger.info(f"{'#'*30} {title} {'#'*30}")
 
     @staticmethod
     def _print_result(result, state="Retrieve"):
         """print retrieve or query result"""
-        print(f"{state} Result:")
+        logger.info(f"{state} Result:")
 
         if state == "Retrieve":
             for i, node in enumerate(result):
-                print(f"{i}. {node.text[:10]}..., {node.score}")
-            print()
+                logger.info(f"{i}. {node.text[:10]}..., {node.score}")
+            logger.info("")
             return
 
-        print(f"{result}\n")
+        logger.info(f"{result}\n")
 
     async def _retrieve_and_print(self, question):
         nodes = await self.engine.aretrieve(question)
@@ -29,11 +29,11 @@ from llama_index.core.schema import (
 from metagpt.rag.factories import (
     get_index,
     get_rag_embedding,
-    get_rag_llm,
     get_rankers,
     get_retriever,
 )
 from metagpt.rag.interface import RAGObject
+from metagpt.rag.llm import get_rag_llm
 from metagpt.rag.retrievers.base import ModifiableRAGRetriever
 from metagpt.rag.schema import (
     BaseIndexConfig,
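The adapter keeps the public name `get_rag_llm`, so the engine only needs the new import; what changes is the optional argument (an `LLMType` key before, a MetaGPT `BaseLLM` instance now). A hedged before/after sketch of the call site; neither snippet appears verbatim in this diff:

```python
# Illustrative call-site comparison only.

# Before: the factory keyed on the provider type and returned a native LlamaIndex LLM.
# from metagpt.rag.factories import get_rag_llm
# llm = get_rag_llm()            # or get_rag_llm(LLMType.AZURE) -> llama_index OpenAI/AzureOpenAI/...

# After: the adapter wraps whichever MetaGPT BaseLLM it is given (default: LLM() from config).
from metagpt.rag.llm import get_rag_llm

llm = get_rag_llm()              # -> RAGLLM wrapping metagpt.llm.LLM()
```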
@@ -1,8 +1,7 @@
 """RAG factories"""
 from metagpt.rag.factories.retriever import get_retriever
 from metagpt.rag.factories.ranker import get_rankers
-from metagpt.rag.factories.llm import get_rag_llm
 from metagpt.rag.factories.embedding import get_rag_embedding
 from metagpt.rag.factories.index import get_index
 
-__all__ = ["get_retriever", "get_rankers", "get_rag_llm", "get_rag_embedding", "get_index"]
+__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index"]
@@ -1,7 +1,4 @@
-"""RAG LLM Factory.
-
-The LLM of LlamaIndex and the LLM of MG are not the same.
-"""
+"""RAG Embedding Factory."""
 from llama_index.core.embeddings import BaseEmbedding
 from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
 from llama_index.embeddings.openai import OpenAIEmbedding

@@ -12,7 +9,7 @@ from metagpt.rag.factories.base import GenericFactory
 
 
 class RAGEmbeddingFactory(GenericFactory):
-    """Create LlamaIndex LLM with MG config."""
+    """Create LlamaIndex Embedding with MetaGPT's config."""
 
     def __init__(self):
         creators = {
@@ -1,65 +0,0 @@
-"""RAG LLM Factory.
-
-The LLM of LlamaIndex and the LLM of MG are not the same.
-"""
-from llama_index.core.llms import LLM
-from llama_index.llms.azure_openai import AzureOpenAI
-from llama_index.llms.gemini import Gemini
-from llama_index.llms.ollama import Ollama
-from llama_index.llms.openai import OpenAI
-
-from metagpt.config2 import config
-from metagpt.configs.llm_config import LLMType
-from metagpt.rag.factories.base import GenericFactory
-
-
-class RAGLLMFactory(GenericFactory):
-    """Create LlamaIndex LLM with MG config."""
-
-    def __init__(self):
-        creators = {
-            LLMType.OPENAI: self._create_openai,
-            LLMType.AZURE: self._create_azure,
-            LLMType.GEMINI: self._create_gemini,
-            LLMType.OLLAMA: self._create_ollama,
-        }
-        super().__init__(creators)
-
-    def get_rag_llm(self, key: LLMType = None) -> LLM:
-        """Key is LLMType, default use config.llm.api_type."""
-        return super().get_instance(key or config.llm.api_type)
-
-    def _create_openai(self):
-        return OpenAI(
-            api_base=config.llm.base_url,
-            api_key=config.llm.api_key,
-            api_version=config.llm.api_version,
-            model=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_azure(self):
-        return AzureOpenAI(
-            azure_endpoint=config.llm.base_url,
-            api_key=config.llm.api_key,
-            api_version=config.llm.api_version,
-            deployment_name=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_gemini(self):
-        return Gemini(
-            api_base=config.llm.base_url,
-            api_key=config.llm.api_key,
-            model_name=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_ollama(self):
-        return Ollama(base_url=config.llm.base_url, model=config.llm.model, temperature=config.llm.temperature)
-
-
-get_rag_llm = RAGLLMFactory().get_rag_llm
48  metagpt/rag/llm.py  Normal file
@@ -0,0 +1,48 @@
+"""RAG LLM."""
+from typing import Any
+
+from llama_index.core.llms import (
+    CompletionResponse,
+    CompletionResponseGen,
+    CustomLLM,
+    LLMMetadata,
+)
+from llama_index.core.llms.callbacks import llm_completion_callback
+
+from metagpt.config2 import config
+from metagpt.llm import LLM
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.utils.async_helper import run_coroutine_in_new_loop
+
+
+class RAGLLM(CustomLLM):
+    """LlamaIndex's LLM is different from MetaGPT's LLM.
+
+    Inherit CustomLLM from llamaindex, making MetaGPT's LLM can be used by LlamaIndex.
+    """
+
+    model_infer: BaseLLM
+    model_name: str = config.llm.model
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(model_name=self.model_name)
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
+
+    @llm_completion_callback()
+    async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
+        text = await self.model_infer.aask(msg=prompt, stream=False)
+        return CompletionResponse(text=text)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        ...
+
+
+def get_rag_llm(model_infer: BaseLLM = None):
+    """Get llm that can be used by LlamaIndex."""
+    return RAGLLM(model_infer=model_infer or LLM())
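A rough sketch of driving the adapter directly from async code (a configured MetaGPT provider is assumed; the prompt is made up). The sync `complete()` bridges to `acomplete()` via `run_coroutine_in_new_loop`, while `stream_complete` is stubbed out with `...`, so streaming is effectively unsupported by this adapter:

```python
# Sketch only: call the adapter directly (assumes MetaGPT's LLM config is set up).
import asyncio

from metagpt.rag.llm import get_rag_llm


async def main() -> None:
    rag_llm = get_rag_llm()  # RAGLLM wrapping metagpt.llm.LLM()
    resp = await rag_llm.acomplete("Summarize retrieval-augmented generation in one sentence.")
    print(resp.text)


if __name__ == "__main__":
    asyncio.run(main())
```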
22  metagpt/utils/async_helper.py  Normal file
@@ -0,0 +1,22 @@
+import asyncio
+import threading
+from typing import Any
+
+
+def run_coroutine_in_new_loop(coroutine) -> Any:
+    """Runs a coroutine in a new, separate event loop on a different thread.
+
+    This function is useful when try to execute an async function within a sync function, but encounter the error `RuntimeError: This event loop is already running`.
+    """
+    new_loop = asyncio.new_event_loop()
+    t = threading.Thread(target=lambda: new_loop.run_forever())
+    t.start()
+
+    future = asyncio.run_coroutine_threadsafe(coroutine, new_loop)
+
+    try:
+        return future.result()
+    finally:
+        new_loop.call_soon_threadsafe(new_loop.stop)
+        t.join()
+        new_loop.close()
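A self-contained sketch of the problem this helper solves: calling an async API from synchronous code that is itself running under an event loop. The function names below are made up:

```python
# Sketch: bridge a coroutine into sync code without "This event loop is already running".
import asyncio

from metagpt.utils.async_helper import run_coroutine_in_new_loop


async def fetch_answer() -> str:
    await asyncio.sleep(0.1)  # stand-in for an awaited LLM call
    return "42"


def sync_entry_point() -> str:
    # Runs the coroutine on a fresh loop in a worker thread and blocks for the result.
    return run_coroutine_in_new_loop(fetch_answer())


async def main() -> None:
    # A loop is already running here, yet the sync call still works, because the
    # coroutine executes on its own loop in another thread.
    print(sync_entry_point())


asyncio.run(main())
```

This is the situation `RAGLLM.complete` is in when LlamaIndex invokes it synchronously inside an async pipeline.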
@@ -14,9 +14,6 @@ llama-index-core==0.10.15
 llama-index-embeddings-azure-openai==0.1.6
 llama-index-embeddings-openai==0.1.5
 llama-index-llms-azure-openai==0.1.4
-llama-index-llms-gemini==0.1.4
-llama-index-llms-ollama==0.1.2
-llama-index-llms-openai==0.1.5
 llama-index-readers-file==0.1.4
 llama-index-retrievers-bm25==0.1.3
 llama-index-vector-stores-faiss==0.1.1
@@ -1,54 +0,0 @@
-import pytest
-from llama_index.llms.azure_openai import AzureOpenAI
-from llama_index.llms.gemini import Gemini
-from llama_index.llms.ollama import Ollama
-from llama_index.llms.openai import OpenAI
-
-from metagpt.configs.llm_config import LLMType
-from metagpt.rag.factories.llm import RAGLLMFactory
-
-
-class TestRAGLLMFactory:
-    @pytest.fixture(autouse=True)
-    def setup(self, mocker):
-        # Mock the config object for all tests in this class
-        self.mock_config = mocker.MagicMock()
-        self.mock_config.llm.api_type = LLMType.OPENAI
-        self.mock_config.llm.base_url = "http://example.com"
-        self.mock_config.llm.api_key = "test_api_key"
-        self.mock_config.llm.api_version = "v1"
-        self.mock_config.llm.model = "test_model"
-        self.mock_config.llm.max_token = 100
-        self.mock_config.llm.temperature = 0.5
-        mocker.patch("metagpt.rag.factories.llm.config", self.mock_config)
-        self.factory = RAGLLMFactory()
-
-    @pytest.mark.parametrize(
-        "llm_type,expected_class",
-        [
-            (LLMType.OPENAI, OpenAI),
-            (LLMType.AZURE, AzureOpenAI),
-            (LLMType.GEMINI, Gemini),
-            (LLMType.OLLAMA, Ollama),
-        ],
-    )
-    def test_creates_correct_llm_instance(self, llm_type, expected_class, mocker):
-        # Mock the LLM constructors
-        mocker.patch.object(expected_class, "__init__", return_value=None)
-        instance = self.factory.get_rag_llm(key=llm_type)
-        assert isinstance(instance, expected_class)
-        expected_class.__init__.assert_called_once()
-
-    def test_uses_default_llm_type_when_no_key_provided(self, mocker):
-        # Assume the default API type is OPENAI for this test
-        mock = mocker.patch.object(OpenAI, "__init__", return_value=None)
-        instance = self.factory.get_rag_llm()
-        assert isinstance(instance, OpenAI)
-        mock.assert_called_once_with(
-            api_base=self.mock_config.llm.base_url,
-            api_key=self.mock_config.llm.api_key,
-            api_version=self.mock_config.llm.api_version,
-            model=self.mock_config.llm.model,
-            max_tokens=self.mock_config.llm.max_token,
-            temperature=self.mock_config.llm.temperature,
-        )
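The deleted test targeted the removed factory. A rough sketch of an equivalent test for the new adapter, not part of this commit: it assumes pytest-asyncio and pytest-mock (as the removed test did) and uses pydantic's `construct()` only to skip field validation of the mocked `model_infer`:

```python
# Hypothetical replacement test for RAGLLM (illustrative, not from this commit).
import pytest

from metagpt.rag.llm import RAGLLM


class TestRAGLLM:
    @pytest.mark.asyncio
    async def test_acomplete_delegates_to_metagpt_llm(self, mocker):
        # A MagicMock stands in for MetaGPT's BaseLLM; construct() bypasses validation.
        mock_infer = mocker.MagicMock()
        mock_infer.aask = mocker.AsyncMock(return_value="ok")
        rag_llm = RAGLLM.construct(model_infer=mock_infer, model_name="test_model")

        resp = await rag_llm.acomplete("hello")

        assert resp.text == "ok"
        mock_infer.aask.assert_awaited_once_with(msg="hello", stream=False)
```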