Merge branch 'fix-rag' into 'mgx_ops'

Fix rag

See merge request pub/MetaGPT!317
This commit is contained in:
张雷 2024-08-15 13:12:09 +00:00
commit aa8e2fa8c3
5 changed files with 34 additions and 16 deletions

View file

@ -29,7 +29,7 @@ class RAGEmbeddingFactory(GenericFactory):
LLMType.AZURE: self._create_azure,
}
super().__init__(creators)
self.config = config if self.config else Config.default()
self.config = config if config else Config.default()
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
"""Key is EmbeddingType."""

View file

@ -10,7 +10,7 @@ from llama_index.core.llms import (
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from pydantic import Field, model_validator
from pydantic import Field
from metagpt.config2 import Config
from metagpt.llm import LLM
@ -30,19 +30,30 @@ class RAGLLM(CustomLLM):
num_output: int = -1
model_name: str = ""
@model_validator(mode="after")
def update_from_config(self):
def __init__(
self,
model_infer: BaseLLM,
context_window: int = -1,
num_output: int = -1,
model_name: str = "",
*args,
**kwargs
):
super().__init__(*args, **kwargs)
config = Config.default()
if self.context_window < 0:
self.context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
if context_window < 0:
context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
if self.num_output < 0:
self.num_output = config.llm.max_token
if num_output < 0:
num_output = config.llm.max_token
if not self.model_name:
self.model_name = config.llm.model
if not model_name:
model_name = config.llm.model
return self
self.model_infer = model_infer
self.context_window = context_window
self.num_output = num_output
self.model_name = model_name
@property
def metadata(self) -> LLMMetadata: