replace rag llm factory with llamaindex custom llm
Commit 9fe9a4a2d1 (parent 4712b2136b).
9 changed files with 87 additions and 142 deletions.
@@ -29,11 +29,11 @@ from llama_index.core.schema import (
 from metagpt.rag.factories import (
     get_index,
     get_rag_embedding,
-    get_rag_llm,
     get_rankers,
     get_retriever,
 )
 from metagpt.rag.interface import RAGObject
+from metagpt.rag.llm import get_rag_llm
 from metagpt.rag.retrievers.base import ModifiableRAGRetriever
 from metagpt.rag.schema import (
     BaseIndexConfig,
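The call sites keep their shape; only the import's origin moves out of the factories package. A minimal wiring sketch, in which everything except the two import paths is an assumption rather than code from the diff:

    # Hypothetical wiring sketch; only the import paths appear in the diff.
    from metagpt.rag.factories import get_rag_embedding
    from metagpt.rag.llm import get_rag_llm  # moved out of the factories package

    llm = get_rag_llm()              # LlamaIndex-compatible wrapper, defined below
    embedding = get_rag_embedding()  # still built by the factories package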
metagpt/rag/factories/__init__.py
@@ -1,8 +1,7 @@
 """RAG factories"""
 from metagpt.rag.factories.retriever import get_retriever
 from metagpt.rag.factories.ranker import get_rankers
-from metagpt.rag.factories.llm import get_rag_llm
 from metagpt.rag.factories.embedding import get_rag_embedding
 from metagpt.rag.factories.index import get_index
 
-__all__ = ["get_retriever", "get_rankers", "get_rag_llm", "get_rag_embedding", "get_index"]
+__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index"]
metagpt/rag/factories/embedding.py
@@ -1,7 +1,4 @@
-"""RAG LLM Factory.
-
-The LLM of LlamaIndex and the LLM of MG are not the same.
-"""
+"""RAG Embedding Factory."""
 from llama_index.core.embeddings import BaseEmbedding
 from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
 from llama_index.embeddings.openai import OpenAIEmbedding
@@ -12,7 +9,7 @@ from metagpt.rag.factories.base import GenericFactory
 
 
 class RAGEmbeddingFactory(GenericFactory):
-    """Create LlamaIndex LLM with MG config."""
+    """Create LlamaIndex Embedding with MetaGPT's config."""
 
     def __init__(self):
         creators = {
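Both factories share the same registry pattern: the constructor hands a dict of key-to-creator callables to GenericFactory, and get_instance later dispatches on the key. A minimal sketch of that pattern, assuming GenericFactory internals that are not shown in this diff:

    # Sketch of the creators-dict dispatch; GenericFactory's real
    # implementation in metagpt/rag/factories/base.py may differ.
    class GenericFactorySketch:
        def __init__(self, creators: dict):
            self._creators = creators  # config key -> zero-arg creator method

        def get_instance(self, key):
            creator = self._creators.get(key)
            if creator is None:
                raise ValueError(f"creator not found for key: {key}")
            return creator()  # build the LlamaIndex object on demand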
metagpt/rag/factories/llm.py (deleted)
@@ -1,65 +0,0 @@
-"""RAG LLM Factory.
-
-The LLM of LlamaIndex and the LLM of MG are not the same.
-"""
-from llama_index.core.llms import LLM
-from llama_index.llms.azure_openai import AzureOpenAI
-from llama_index.llms.gemini import Gemini
-from llama_index.llms.ollama import Ollama
-from llama_index.llms.openai import OpenAI
-
-from metagpt.config2 import config
-from metagpt.configs.llm_config import LLMType
-from metagpt.rag.factories.base import GenericFactory
-
-
-class RAGLLMFactory(GenericFactory):
-    """Create LlamaIndex LLM with MG config."""
-
-    def __init__(self):
-        creators = {
-            LLMType.OPENAI: self._create_openai,
-            LLMType.AZURE: self._create_azure,
-            LLMType.GEMINI: self._create_gemini,
-            LLMType.OLLAMA: self._create_ollama,
-        }
-        super().__init__(creators)
-
-    def get_rag_llm(self, key: LLMType = None) -> LLM:
-        """Key is LLMType, default use config.llm.api_type."""
-        return super().get_instance(key or config.llm.api_type)
-
-    def _create_openai(self):
-        return OpenAI(
-            api_base=config.llm.base_url,
-            api_key=config.llm.api_key,
-            api_version=config.llm.api_version,
-            model=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_azure(self):
-        return AzureOpenAI(
-            azure_endpoint=config.llm.base_url,
-            api_key=config.llm.api_key,
-            api_version=config.llm.api_version,
-            deployment_name=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_gemini(self):
-        return Gemini(
-            api_base=config.llm.base_url,
-            api_key=config.llm.api_key,
-            model_name=config.llm.model,
-            max_tokens=config.llm.max_token,
-            temperature=config.llm.temperature,
-        )
-
-    def _create_ollama(self):
-        return Ollama(base_url=config.llm.base_url, model=config.llm.model, temperature=config.llm.temperature)
-
-
-get_rag_llm = RAGLLMFactory().get_rag_llm
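The factory above dispatched on LLMType and could only build the four providers it enumerated; its replacement wraps MetaGPT's own provider layer instead, so no per-provider creators are needed. A hedged before/after contrast, not code from the commit:

    # Before: provider-specific construction, limited to the enumerated keys.
    #     from metagpt.rag.factories import get_rag_llm
    #     llm = get_rag_llm(LLMType.GEMINI)  # a llama_index Gemini from config
    # After: one provider-agnostic wrapper (see metagpt/rag/llm.py below).
    #     from metagpt.rag.llm import get_rag_llm
    #     llm = get_rag_llm()  # RAGLLM delegating to metagpt.llm.LLM()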
metagpt/rag/llm.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+"""RAG LLM."""
+from typing import Any
+
+from llama_index.core.llms import (
+    CompletionResponse,
+    CompletionResponseGen,
+    CustomLLM,
+    LLMMetadata,
+)
+from llama_index.core.llms.callbacks import llm_completion_callback
+
+from metagpt.config2 import config
+from metagpt.llm import LLM
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.utils.async_helper import run_coroutine_in_new_loop
+
+
+class RAGLLM(CustomLLM):
+    """LlamaIndex's LLM is different from MetaGPT's LLM.
+
+    Inherit CustomLLM from LlamaIndex, so that MetaGPT's LLM can be used by LlamaIndex.
+    """
+
+    model_infer: BaseLLM
+    model_name: str = config.llm.model
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata(model_name=self.model_name)
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
+
+    @llm_completion_callback()
+    async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
+        text = await self.model_infer.aask(msg=prompt, stream=False)
+        return CompletionResponse(text=text)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        ...
+
+
+def get_rag_llm(model_infer: BaseLLM = None):
+    """Get llm that can be used by LlamaIndex."""
+    return RAGLLM(model_infer=model_infer or LLM())
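A short usage sketch for the new wrapper, assuming a MetaGPT config with valid credentials; the prompt string is illustrative:

    from metagpt.rag.llm import get_rag_llm

    rag_llm = get_rag_llm()                  # wraps metagpt.llm.LLM() by default
    print(rag_llm.metadata.model_name)       # reported from config.llm.model
    print(rag_llm.complete("Say hi.").text)  # sync entry point bridging to aask()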
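run_coroutine_in_new_loop is imported but not shown in this diff. A plausible implementation of such a helper, offered as an assumption rather than the repository's actual metagpt/utils/async_helper.py, runs the coroutine on a fresh event loop in a worker thread so the synchronous complete() works even when the calling thread already has a running loop:

    # Hypothetical sketch of a run_coroutine_in_new_loop-style helper.
    import asyncio
    from concurrent.futures import ThreadPoolExecutor


    def run_coroutine_in_new_loop(coro):
        """Run a coroutine to completion from synchronous code."""
        with ThreadPoolExecutor(max_workers=1) as pool:
            # asyncio.run creates and tears down a brand-new event loop in
            # the worker thread, avoiding "event loop is already running".
            return pool.submit(asyncio.run, coro).result()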