diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 62de34ef4..1b3dcf5f0 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -132,7 +132,7 @@ class WriteCode(Action): code = await self.write_code(prompt) if not coding_context.code_doc: # avoid root_path pydantic ValidationError if use WriteCode alone - root_path = self.i_context.src_workspace if self.i_context.src_workspace else "" + root_path = self.context.src_workspace if self.context.src_workspace else "" coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path)) coding_context.code_doc.content = code return coding_context diff --git a/metagpt/config2.py b/metagpt/config2.py index cb5c22ac2..30d3818f6 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -121,12 +121,10 @@ class Config(CLIParams, YamlModel): return llm[0] return None - def get_llm_config(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> LLMConfig: + def get_llm_config(self, name: Optional[str] = None, provider: LLMType = None) -> LLMConfig: """Return a LLMConfig instance""" if provider: llm_configs = self.get_llm_configs_by_type(provider) - if name: - llm_configs = [c for c in llm_configs if c.name == name] if len(llm_configs) == 0: raise ValueError(f"Cannot find llm config with name {name} and provider {provider}") diff --git a/metagpt/context.py b/metagpt/context.py index e2bead828..35892f3f3 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -77,7 +77,7 @@ class Context(BaseModel): # self._llm = None # return self._llm - def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM: + def llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM: """Return a LLM instance, fixme: support cache""" # if self._llm is None: self._llm = create_llm_instance(self.config.get_llm_config(name, provider)) @@ -85,6 +85,14 @@ class Context(BaseModel): self._llm.cost_manager = self.cost_manager return self._llm + def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM: + """Return a LLM instance, fixme: support cache""" + # if self._llm is None: + llm = create_llm_instance(llm_config) + if llm.cost_manager is None: + llm.cost_manager = self.cost_manager + return llm + class ContextMixin(BaseModel): """Mixin class for context and config""" @@ -132,7 +140,7 @@ class ContextMixin(BaseModel): """Set llm""" self.set("_llm", llm, override) - def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM: + def use_llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM: """Use a LLM instance""" self._llm_config = self.config.get_llm_config(name, provider) self._llm = None @@ -165,9 +173,9 @@ class ContextMixin(BaseModel): @property def llm(self) -> BaseLLM: """Role llm: role llm > context llm""" - # print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}") + print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}") if self._llm_config and not self._llm: - self._llm = self.context.llm(self._llm_config.name, self._llm_config.provider) + self._llm = self.context.llm_with_cost_manager_from_llm_config(self._llm_config) return self._llm or self.context.llm() @llm.setter diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index cfc5863f4..792b89d90 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -19,8 +19,8 @@ from metagpt.const import ( TEST_OUTPUTS_FILE_REPO, ) from metagpt.context import CONTEXT +from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.openai_api import OpenAILLM as LLM from metagpt.schema import CodingContext, Document from metagpt.utils.common import aread from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE @@ -32,7 +32,7 @@ async def test_write_code(): filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。") ) doc = Document(content=ccontext.model_dump_json()) - write_code = WriteCode(context=doc) + write_code = WriteCode(i_context=doc) code = await write_code.run() logger.info(code.model_dump_json()) @@ -86,7 +86,7 @@ async def test_write_code_deps(): ) coding_doc = Document(root_path="snake1", filename="game.py", content=ccontext.json()) - action = WriteCode(context=coding_doc) + action = WriteCode(i_context=coding_doc) rsp = await action.run() assert rsp assert rsp.code_doc.content