From e960ac8dc8eafca456e902890d4825b7622a2677 Mon Sep 17 00:00:00 2001 From: yzlin Date: Tue, 12 Mar 2024 16:43:24 +0800 Subject: [PATCH] update requirement and example, recover legacy code --- ...h_tools.py => machine_learning_complex.py} | 2 +- metagpt/strategy/task_type.py | 10 +++++---- metagpt/tools/tool_recommend.py | 22 ++++++++++++++++++- requirements.txt | 2 ++ 4 files changed, 30 insertions(+), 6 deletions(-) rename examples/di/{machine_learning_with_tools.py => machine_learning_complex.py} (92%) diff --git a/examples/di/machine_learning_with_tools.py b/examples/di/machine_learning_complex.py similarity index 92% rename from examples/di/machine_learning_with_tools.py rename to examples/di/machine_learning_complex.py index 42c0ef55b..42059ac4f 100644 --- a/examples/di/machine_learning_with_tools.py +++ b/examples/di/machine_learning_complex.py @@ -4,7 +4,7 @@ from metagpt.roles.di.data_interpreter import DataInterpreter async def main(requirement: str): - role = DataInterpreter(tools=[""]) + role = DataInterpreter(use_reflection=True, tools=[""]) await role.run(requirement) diff --git a/metagpt/strategy/task_type.py b/metagpt/strategy/task_type.py index 28a86f100..9eeeb79ce 100644 --- a/metagpt/strategy/task_type.py +++ b/metagpt/strategy/task_type.py @@ -19,29 +19,31 @@ class TaskTypeDef(BaseModel): class TaskType(Enum): + """By identifying specific types of tasks, we can inject human priors (guidance) to help task solving""" + EDA = TaskTypeDef( name="eda", desc="For performing exploratory data analysis", guidance=EDA_PROMPT, ) DATA_PREPROCESS = TaskTypeDef( - name="data_preprocess", + name="data preprocessing", desc="For preprocessing dataset in a data analysis or machine learning task ONLY," "general data operation doesn't fall into this type", guidance=DATA_PREPROCESS_PROMPT, ) FEATURE_ENGINEERING = TaskTypeDef( - name="feature_engineering", + name="feature engineering", desc="Only for creating new columns for input data.", 
guidance=FEATURE_ENGINEERING_PROMPT, ) MODEL_TRAIN = TaskTypeDef( - name="model_train", + name="model train", desc="Only for training model.", guidance=MODEL_TRAIN_PROMPT, ) MODEL_EVALUATE = TaskTypeDef( - name="model_evaluate", + name="model evaluate", desc="Only for evaluating model.", guidance=MODEL_EVALUATE_PROMPT, ) diff --git a/metagpt/tools/tool_recommend.py b/metagpt/tools/tool_recommend.py index 9e9bf4a01..9b00a7379 100644 --- a/metagpt/tools/tool_recommend.py +++ b/metagpt/tools/tool_recommend.py @@ -62,7 +62,7 @@ class ToolRecommender(BaseModel): """ tools: dict[str, Tool] = {} - force: bool = False + force: bool = False # whether to forcibly recommend the specified tools @field_validator("tools", mode="before") @classmethod @@ -145,6 +145,26 @@ class ToolRecommender(BaseModel): return list(valid_tools.values())[:topk] +class TypeMatchToolRecommender(ToolRecommender): + """ + A legacy ToolRecommender using task type matching at the recall stage: + 1. Recall: Find tools based on exact match between task type and tool tag; + 2. Rank: LLM rank, the same as the default ToolRecommender. 
+ """ + + async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]: + if not plan: + return list(self.tools.values())[:topk] + + # find tools based on exact match between task type and tool tag + task_type = plan.current_task.task_type + candidate_tools = TOOL_REGISTRY.get_tools_by_tag(task_type) + candidate_tool_names = set(self.tools.keys()) & candidate_tools.keys() + recalled_tools = [candidate_tools[tool_name] for tool_name in candidate_tool_names] + + return recalled_tools[:topk] + + class BM25ToolRecommender(ToolRecommender): """ A ToolRecommender using BM25 at the recall stage: diff --git a/requirements.txt b/requirements.txt index 64b174913..d0ee8c95c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -71,3 +71,5 @@ Pillow imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py qianfan==0.3.2 dashscope==1.14.1 +rank-bm25==0.2.2 # for tool recommendation +jieba==0.42.1 # for tool recommendation \ No newline at end of file