diff --git a/data/inference/make_datasets/utils.py b/data/inference/make_datasets/utils.py index 081c1bc1f..284f8d976 100644 --- a/data/inference/make_datasets/utils.py +++ b/data/inference/make_datasets/utils.py @@ -1,5 +1,4 @@ import re -import re def extract_diff(response): diff --git a/data/inference/run.py b/data/inference/run.py index 96d9cc082..a3f3c54aa 100644 --- a/data/inference/run.py +++ b/data/inference/run.py @@ -8,11 +8,9 @@ original_argv = sys.argv.copy() try: # 设置你想要传递给脚本的命令行参数 - sys.argv = ['run_api.py', '--dataset_name_or_path', 'princeton-nlp/SWE-bench_oracle', '--output_dir', - './outputs'] + sys.argv = ["run_api.py", "--dataset_name_or_path", "princeton-nlp/SWE-bench_oracle", "--output_dir", "./outputs"] # 执行脚本 - runpy.run_path(path_name='run_api.py', run_name='__main__') + runpy.run_path(path_name="run_api.py", run_name="__main__") finally: # 恢复原始的sys.argv以避免对后续代码的潜在影响 sys.argv = original_argv - diff --git a/data/inference/run_api.py b/data/inference/run_api.py index 9202d6a42..7882f13e7 100644 --- a/data/inference/run_api.py +++ b/data/inference/run_api.py @@ -10,12 +10,12 @@ from make_datasets.utils import extract_diff from tenacity import retry, stop_after_attempt, wait_random_exponential from tqdm.auto import tqdm +from data.inference.const import SCIKIT_LEARN_IDS from metagpt.config2 import config from metagpt.logs import logger from metagpt.roles.di.data_interpreter import DataInterpreter from metagpt.utils import count_string_tokens from metagpt.utils.recovery_util import save_history -from data.inference.const import SCIKIT_LEARN_IDS # Replace with your own MAX_TOKEN = 128000 @@ -71,7 +71,7 @@ async def openai_inference( for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"): di = DataInterpreter(use_reflection=use_reflection) instance_id = datum["instance_id"] - + if instance_id in existing_ids: continue output_dict = {"instance_id": instance_id} diff --git a/metagpt/roles/di/data_interpreter.py b/metagpt/roles/di/data_interpreter.py index 11be96dcd..0e2cce309 100644 --- a/metagpt/roles/di/data_interpreter.py +++ b/metagpt/roles/di/data_interpreter.py @@ -43,66 +43,66 @@ class DataInterpreter(Role): tool_recommender: ToolRecommender = None react_mode: Literal["plan_and_act", "react"] = "plan_and_act" max_react_loop: int = 10 # used for react mode - + @model_validator(mode="after") def set_plan_and_tool(self) -> "Interpreter": self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop, auto_run=self.auto_run) self.use_plan = ( - self.react_mode == "plan_and_act" + self.react_mode == "plan_and_act" ) # create a flag for convenience, overwrite any passed-in value if self.tools: self.tool_recommender = BM25ToolRecommender(tools=self.tools) self.set_actions([WriteAnalysisCode]) self._set_state(0) return self - + @property def working_memory(self): return self.rc.working_memory - + async def _think(self) -> bool: """Useful in 'react' mode. Use LLM to decide whether and what to do next.""" user_requirement = self.get_memories()[0].content context = self.working_memory.get() - + if not context: # just started the run, we need action certainly self.working_memory.add(self.get_memories()[0]) # add user requirement to working memory self._set_state(0) return True - + prompt = REACT_THINK_PROMPT.format(user_requirement=user_requirement, context=context) rsp = await self.llm.aask(prompt) rsp_dict = json.loads(CodeParser.parse_code(block=None, text=rsp)) self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant")) need_action = rsp_dict["state"] self._set_state(0) if need_action else self._set_state(-1) - + return need_action - + async def _act(self) -> Message: """Useful in 'react' mode. Return a Message conforming to Role._act interface.""" code, _, _ = await self._write_and_exec_code() return Message(content=code, role="assistant", cause_by=WriteAnalysisCode) - + async def _plan_and_act(self) -> Message: rsp = await super()._plan_and_act() await self.execute_code.terminate() return rsp - + async def _act_on_task(self, current_task: Task) -> TaskResult: """Useful in 'plan_and_act' mode. Wrap the output in a TaskResult for review and confirmation.""" code, result, is_success = await self._write_and_exec_code() task_result = TaskResult(code=code, result=result, is_success=is_success) return task_result - + async def _write_and_exec_code(self, max_retry: int = 3): counter = 0 success = False - + # plan info plan_status = self.planner.get_plan_status() if self.use_plan else "" - + # tool info if self.tools: context = ( @@ -112,46 +112,48 @@ class DataInterpreter(Role): tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan) else: tool_info = "" - + # data info await self._check_data() - + while not success and counter < max_retry: ### write code ### code, cause_by = await self._write_code(counter, plan_status, tool_info) - + self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by)) - + ### execute code ### - import pdb;pdb.set_trace() + import pdb + + pdb.set_trace() result, success = await self.execute_code.run(code) print(result) - + self.working_memory.add(Message(content=result, role="user", cause_by=ExecuteNbCode)) - + ### process execution result ### counter += 1 - + if not success and counter >= max_retry: logger.info("coding failed!") review, _ = await self.planner.ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER) if ReviewConst.CHANGE_WORDS[0] in review: counter = 0 # redo the task again with help of human suggestions - + return code, result, success - + async def _write_code( - self, - counter: int, - plan_status: str = "", - tool_info: str = "", + self, + counter: int, + plan_status: str = "", + tool_info: str = "", ): todo = self.rc.todo # todo is WriteAnalysisCode logger.info(f"ready to {todo.name}") use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial - + user_requirement = self.get_memories()[0].content - + code = await todo.run( user_requirement=user_requirement, plan_status=plan_status, @@ -159,19 +161,19 @@ class DataInterpreter(Role): working_memory=self.working_memory.get(), use_reflection=use_reflection, ) - + return code, todo - + async def _check_data(self): if ( - not self.use_plan - or not self.planner.plan.get_finished_tasks() - or self.planner.plan.current_task.task_type - not in [ - TaskType.DATA_PREPROCESS.type_name, - TaskType.FEATURE_ENGINEERING.type_name, - TaskType.MODEL_TRAIN.type_name, - ] + not self.use_plan + or not self.planner.plan.get_finished_tasks() + or self.planner.plan.current_task.task_type + not in [ + TaskType.DATA_PREPROCESS.type_name, + TaskType.FEATURE_ENGINEERING.type_name, + TaskType.MODEL_TRAIN.type_name, + ] ): return logger.info("Check updated data")