diff --git a/metagpt/rag/llm.py b/metagpt/rag/llm.py index 9c946698a..1cdbab14d 100644 --- a/metagpt/rag/llm.py +++ b/metagpt/rag/llm.py @@ -2,6 +2,7 @@ from typing import Any +from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW from llama_index.core.llms import ( CompletionResponse, CompletionResponseGen, @@ -9,11 +10,13 @@ from llama_index.core.llms import ( LLMMetadata, ) from llama_index.core.llms.callbacks import llm_completion_callback +from pydantic import Field from metagpt.config2 import config from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM from metagpt.utils.async_helper import run_coroutine_in_new_loop +from metagpt.utils.token_counter import TOKEN_MAX class RAGLLM(CustomLLM): @@ -22,13 +25,15 @@ class RAGLLM(CustomLLM): Inherit CustomLLM from llamaindex, making MetaGPT's LLM can be used by LlamaIndex. """ - model_infer: BaseLLM + model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.") + context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) + num_output: int = config.llm.max_token model_name: str = config.llm.model @property def metadata(self) -> LLMMetadata: """Get LLM metadata.""" - return LLMMetadata(model_name=self.model_name) + return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name) @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: