Replace the RAG LLM factory with a LlamaIndex CustomLLM adapter

This commit is contained in:
seehi 2024-03-08 20:19:28 +08:00
parent 4712b2136b
commit 9fe9a4a2d1
9 changed files with 87 additions and 142 deletions

View file

@ -29,11 +29,11 @@ from llama_index.core.schema import (
from metagpt.rag.factories import (
get_index,
get_rag_embedding,
get_rag_llm,
get_rankers,
get_retriever,
)
from metagpt.rag.interface import RAGObject
from metagpt.rag.llm import get_rag_llm
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.schema import (
BaseIndexConfig,

View file

@ -1,8 +1,7 @@
"""RAG factories"""
from metagpt.rag.factories.retriever import get_retriever
from metagpt.rag.factories.ranker import get_rankers
from metagpt.rag.factories.llm import get_rag_llm
from metagpt.rag.factories.embedding import get_rag_embedding
from metagpt.rag.factories.index import get_index
__all__ = ["get_retriever", "get_rankers", "get_rag_llm", "get_rag_embedding", "get_index"]
__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index"]

View file

@ -1,7 +1,4 @@
"""RAG LLM Factory.
The LLM of LlamaIndex and the LLM of MG are not the same.
"""
"""RAG Embedding Factory."""
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
@ -12,7 +9,7 @@ from metagpt.rag.factories.base import GenericFactory
class RAGEmbeddingFactory(GenericFactory):
"""Create LlamaIndex LLM with MG config."""
"""Create LlamaIndex Embedding with MetaGPT's config."""
def __init__(self):
creators = {

View file

@ -1,65 +0,0 @@
"""RAG LLM Factory.
The LLM of LlamaIndex and the LLM of MG are not the same.
"""
from llama_index.core.llms import LLM
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.gemini import Gemini
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory
class RAGLLMFactory(GenericFactory):
    """Build a LlamaIndex LLM instance from MetaGPT's configuration.

    LlamaIndex and MetaGPT use different LLM abstractions; this factory
    bridges them by constructing the matching LlamaIndex LLM from
    ``config.llm``.
    """

    def __init__(self):
        super().__init__(
            {
                LLMType.OPENAI: self._create_openai,
                LLMType.AZURE: self._create_azure,
                LLMType.GEMINI: self._create_gemini,
                LLMType.OLLAMA: self._create_ollama,
            }
        )

    def get_rag_llm(self, key: LLMType = None) -> LLM:
        """Return the LLM for *key*; defaults to ``config.llm.api_type``."""
        return super().get_instance(key or config.llm.api_type)

    def _create_openai(self):
        # OpenAI-compatible endpoint, driven entirely by MetaGPT's llm config.
        cfg = config.llm
        return OpenAI(
            api_base=cfg.base_url,
            api_key=cfg.api_key,
            api_version=cfg.api_version,
            model=cfg.model,
            max_tokens=cfg.max_token,
            temperature=cfg.temperature,
        )

    def _create_azure(self):
        # Azure names the model via a deployment, not a model id.
        cfg = config.llm
        return AzureOpenAI(
            azure_endpoint=cfg.base_url,
            api_key=cfg.api_key,
            api_version=cfg.api_version,
            deployment_name=cfg.model,
            max_tokens=cfg.max_token,
            temperature=cfg.temperature,
        )

    def _create_gemini(self):
        cfg = config.llm
        return Gemini(
            api_base=cfg.base_url,
            api_key=cfg.api_key,
            model_name=cfg.model,
            max_tokens=cfg.max_token,
            temperature=cfg.temperature,
        )

    def _create_ollama(self):
        # Ollama is a local server; only url, model and temperature apply.
        cfg = config.llm
        return Ollama(base_url=cfg.base_url, model=cfg.model, temperature=cfg.temperature)


# Module-level convenience entry point bound to a shared factory instance.
get_rag_llm = RAGLLMFactory().get_rag_llm

48
metagpt/rag/llm.py Normal file
View file

@ -0,0 +1,48 @@
"""RAG LLM."""
from typing import Any
from llama_index.core.llms import (
CompletionResponse,
CompletionResponseGen,
CustomLLM,
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.async_helper import run_coroutine_in_new_loop
class RAGLLM(CustomLLM):
    """Adapt MetaGPT's LLM to LlamaIndex's LLM interface.

    LlamaIndex's LLM is different from MetaGPT's LLM. Inheriting
    ``CustomLLM`` from LlamaIndex lets any MetaGPT ``BaseLLM`` be used by
    LlamaIndex components.
    """

    # The wrapped MetaGPT LLM that actually performs inference.
    model_infer: BaseLLM
    # Reported model name; defaults to MetaGPT's configured model.
    model_name: str = config.llm.model

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(model_name=self.model_name)

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # Sync bridge: run the async completion inside a fresh event loop so
        # the method is safe to call from synchronous LlamaIndex code.
        return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))

    @llm_completion_callback()
    async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
        """Complete *prompt* via the wrapped MetaGPT LLM (non-streaming)."""
        text = await self.model_infer.aask(msg=prompt, stream=False)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        # Fix: the original body was `...`, so the call returned None instead
        # of a generator and iterating the result raised TypeError. Emit the
        # full completion as a single chunk (no true token streaming).
        yield self.complete(prompt, **kwargs)
def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM:
    """Get llm that can be used by LlamaIndex.

    Args:
        model_infer: MetaGPT LLM to wrap; when None, falls back to ``LLM()``
            (the LLM built from the global MetaGPT config).

    Returns:
        A RAGLLM adapter usable wherever LlamaIndex expects an LLM.
    """
    return RAGLLM(model_infer=model_infer or LLM())