From b9089380bdc8722581d8a9056c3456479a0161f3 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 15 Aug 2024 17:04:56 +0800 Subject: [PATCH] fix get rag llm error --- metagpt/rag/factories/llm.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 5d27cde3a..59f6db4d9 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -10,7 +10,7 @@ from llama_index.core.llms import ( LLMMetadata, ) from llama_index.core.llms.callbacks import llm_completion_callback -from pydantic import Field, model_validator +from pydantic import Field from metagpt.config2 import Config from metagpt.llm import LLM @@ -30,19 +30,30 @@ class RAGLLM(CustomLLM): num_output: int = -1 model_name: str = "" - @model_validator(mode="after") - def update_from_config(self): + def __init__( + self, + model_infer: BaseLLM, + context_window: int = -1, + num_output: int = -1, + model_name: str = "", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) config = Config.default() - if self.context_window < 0: - self.context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) + if context_window < 0: + context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) - if self.num_output < 0: - self.num_output = config.llm.max_token + if num_output < 0: + num_output = config.llm.max_token - if not self.model_name: - self.model_name = config.llm.model + if not model_name: + model_name = config.llm.model - return self + self.model_infer = model_infer + self.context_window = context_window + self.num_output = num_output + self.model_name = model_name @property def metadata(self) -> LLMMetadata: