diff --git a/metagpt/actions/di/write_analysis_code.py b/metagpt/actions/di/write_analysis_code.py index 97eb81def..185926e31 100644 --- a/metagpt/actions/di/write_analysis_code.py +++ b/metagpt/actions/di/write_analysis_code.py @@ -50,17 +50,15 @@ class WriteAnalysisCode(Action): ) working_memory = working_memory or [] - context = [Message(content=structual_prompt, role="user")] + working_memory - context = process_message(context) + context = process_message([Message(content=structual_prompt, role="user")] + working_memory) # LLM call - if not use_reflection: + if use_reflection: + code = await self._debug_with_reflection(context=context, working_memory=working_memory) + else: rsp = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs) code = CodeParser.parse_code(block=None, text=rsp) - else: - code = await self._debug_with_reflection(context=context, working_memory=working_memory) - return code diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py index c24c78a90..a8534b710 100644 --- a/metagpt/roles/di/data_interpreter.py +++ b/metagpt/roles/di/data_interpreter.py @@ -39,7 +39,7 @@ class DataInterpreter(Role): use_plan: bool = True use_reflection: bool = False execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True) - tools: Union[str, list[str]] = [] + tools: Union[str, list[str]] = [] # Use special symbol [""] to indicate use of all registered tools tool_recommender: ToolRecommender = None react_mode: Literal["plan_and_act", "react"] = "plan_and_act" max_react_loop: int = 10 # used for react mode @@ -53,6 +53,7 @@ class DataInterpreter(Role): if self.tools: self.tool_recommender = BM25ToolRecommender(tools=self.tools) self.set_actions([WriteAnalysisCode]) + self._set_state(0) return self @property @@ -140,13 +141,13 @@ class DataInterpreter(Role): async def _write_code( self, - counter, - plan_status="", - tool_info="", + counter: int, + plan_status: str = "", + tool_info: str = "", ): - todo = WriteAnalysisCode() + todo = self.rc.todo # todo is WriteAnalysisCode logger.info(f"ready to {todo.name}") - use_reflection = counter > 0 and self.use_reflection + use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial user_requirement = self.get_memories()[0].content @@ -176,7 +177,6 @@ class DataInterpreter(Role): code = await CheckData().run(self.planner.plan) if not code.strip(): return - success = False result, success = await self.execute_code.run(code) if success: print(result) diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py index 9b00a7379..69b9a4b5d 100644 --- a/metagpt/tools/tool_recommend.py +++ b/metagpt/tools/tool_recommend.py @@ -8,7 +8,7 @@ import numpy as np from pydantic import BaseModel, field_validator from rank_bm25 import BM25Okapi -from metagpt.actions import Action +from metagpt.llm import LLM from metagpt.logs import logger from metagpt.schema import Plan from metagpt.tools import TOOL_REGISTRY @@ -49,11 +49,6 @@ Recommend up to {topk} tools from 'Available Tools' that can help solve the 'Use """ -class RecommendTool(Action): - async def run(self, prompt): - return await self._aask(prompt) - - class ToolRecommender(BaseModel): """ The default ToolRecommender: @@ -67,6 +62,7 @@ class ToolRecommender(BaseModel): @field_validator("tools", mode="before") @classmethod def validate_tools(cls, v: list[str]) -> dict[str, Tool]: + # One can use special symbol [""] to indicate use of all registered tools if v == [""]: return TOOL_REGISTRY.get_all_tools() else: @@ -136,7 +132,7 @@ class ToolRecommender(BaseModel): available_tools=available_tools, topk=topk, ) - rsp = await RecommendTool().run(prompt) + rsp = await LLM().aask(prompt) rsp = CodeParser.parse_code(block=None, text=rsp) ranked_tools = json.loads(rsp) @@ -160,9 +156,11 @@ class TypeMatchToolRecommender(ToolRecommender): task_type = plan.current_task.task_type candidate_tools = TOOL_REGISTRY.get_tools_by_tag(task_type) candidate_tool_names = set(self.tools.keys()) & candidate_tools.keys() - recalled_tools = [candidate_tools[tool_name] for tool_name in candidate_tool_names] + recalled_tools = [candidate_tools[tool_name] for tool_name in candidate_tool_names][:topk] - return recalled_tools[:topk] + logger.info(f"Recalled tools: \n{[tool.name for tool in recalled_tools]}") + + return recalled_tools class BM25ToolRecommender(ToolRecommender): diff --git a/tests/metagpt/roles/di/test_data_interpreter.py b/tests/metagpt/roles/di/test_data_interpreter.py index f51f5bbfc..d25e5a099 100644 --- a/tests/metagpt/roles/di/test_data_interpreter.py +++ b/tests/metagpt/roles/di/test_data_interpreter.py @@ -22,6 +22,7 @@ async def test_interpreter(mocker, auto_run): assert len(finished_tasks[0].code) > 0 # check one task to see if code is recorded +@pytest.mark.asyncio async def test_interpreter_react_mode(mocker): mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True)) mocker.patch("builtins.input", return_value="confirm") diff --git a/tests/metagpt/tools/test_tool_convert.py b/tests/metagpt/tools/test_tool_convert.py index f85b84b71..061a619ce 100644 --- a/tests/metagpt/tools/test_tool_convert.py +++ b/tests/metagpt/tools/test_tool_convert.py @@ -45,7 +45,6 @@ class DummyClass: pass -# def dummy_fn(df: pd.DataFrame, s: str, k: int = 5, type: Literal["a", "b", "c"] = "a") -> dict: def dummy_fn( df: pd.DataFrame, s: str,