fix get rag llm error

This commit is contained in:
shenchucheng 2024-08-15 17:04:56 +08:00
parent 3f53d34cd0
commit b9089380bd

View file

@ -10,7 +10,7 @@ from llama_index.core.llms import (
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from pydantic import Field, model_validator
from pydantic import Field
from metagpt.config2 import Config
from metagpt.llm import LLM
@ -30,19 +30,30 @@ class RAGLLM(CustomLLM):
num_output: int = -1
model_name: str = ""
@model_validator(mode="after")
def update_from_config(self):
def __init__(
self,
model_infer: BaseLLM,
context_window: int = -1,
num_output: int = -1,
model_name: str = "",
*args,
**kwargs
):
super().__init__(*args, **kwargs)
config = Config.default()
if self.context_window < 0:
self.context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
if context_window < 0:
context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
if self.num_output < 0:
self.num_output = config.llm.max_token
if num_output < 0:
num_output = config.llm.max_token
if not self.model_name:
self.model_name = config.llm.model
if not model_name:
model_name = config.llm.model
return self
self.model_infer = model_infer
self.context_window = context_window
self.num_output = num_output
self.model_name = model_name
@property
def metadata(self) -> LLMMetadata: