From 56e7a08a1c7304b5262da58a55546eb0244cb9c0 Mon Sep 17 00:00:00 2001 From: Yizhou Chi Date: Fri, 11 Oct 2024 14:48:56 +0800 Subject: [PATCH] insight pool is now able to dynamically increase --- expo/MCTS.py | 16 ++++++++++------ expo/insights/instruction_generator.py | 16 +++++++++------- expo/insights/solution_designer.py | 1 + 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/expo/MCTS.py b/expo/MCTS.py index 7de123572..9d778e4ed 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -155,18 +155,15 @@ class Node: role = role.model_copy() role.save_state(static_save=True) - async def expand(self, max_children, use_fixed_insights): + async def expand(self, max_children: int, instruction_generator: InstructionGenerator): if self.is_fully_expanded(): return - insight_geneartor = InstructionGenerator() role = self.load_role() original_instruction = role.get_next_instruction() - insights = await insight_geneartor.generate_new_instructions( + insights = await instruction_generator.generate_new_instructions( task_id=role.start_task_id + 1, original_instruction=original_instruction, max_num=max_children, - file_path=self.state["exp_pool_path"], - use_fixed_insights=use_fixed_insights, ) new_state = self.state.copy() new_state["start_task_id"] += 1 @@ -249,6 +246,8 @@ class MCTS: c_explore: float = 1.4 c_unvisited: float = 0.8 node_order: list = [] + # insight generator + instruction_generator: InstructionGenerator = None def __init__(self, root_node, max_depth, use_fixed_insights): self.root_node = root_node @@ -272,7 +271,7 @@ class MCTS: return max(all_children, key=uct) async def expand(self, node: Node, max_children=5): - await node.expand(max_children, self.use_fixed_insights) + await node.expand(max_children, self.instruction_generator) if node not in self.children or not self.children[node]: self.children[node] = node.children return node.children @@ -284,6 +283,7 @@ class MCTS: node = random.choice(node.children) reward = await node.run_node(role) mcts_logger.log("MCTS", f"Simulated node's reward: {reward}") + return reward def backpropagate(self, node: Node, reward): @@ -344,6 +344,10 @@ class MCTS: async def search(self, state, rollouts, load_tree=False, reflection=False): role, root = initialize_di_root_node(state, reflection=reflection) self.root_node = root + self.instruction_generator = InstructionGenerator( + file_path=state["exp_pool_path"], use_fixed_insights=self.use_fixed_insights + ) + tree_loaded = False if load_tree: tree_loaded = self.load_tree() diff --git a/expo/insights/instruction_generator.py b/expo/insights/instruction_generator.py index 07e5fb655..ae6c742fb 100644 --- a/expo/insights/instruction_generator.py +++ b/expo/insights/instruction_generator.py @@ -2,6 +2,7 @@ import json import os import random +from expo.insights.solution_designer import SolutionDesigner from expo.utils import clean_json_from_rsp, load_data_config, mcts_logger from metagpt.llm import LLM from metagpt.schema import Message @@ -32,6 +33,12 @@ DATA_CONFIG = load_data_config() class InstructionGenerator: data_config = DATA_CONFIG + def __init__(self, file_path, use_fixed_insights=False): + self.file_path = file_path + self.use_fixed_insights = use_fixed_insights + self.analysis_pool = self.load_analysis_pool(file_path, use_fixed_insights) + self.proposer = SolutionDesigner() + @staticmethod def load_json_data(json_dir): with open(json_dir, "r") as file: @@ -83,13 +90,8 @@ class InstructionGenerator: data = [item for item in data if int(item["task_id"]) == int(task_id)] return data - @staticmethod - async def generate_new_instructions( - task_id, original_instruction, max_num, file_path, ext_info=None, use_fixed_insights=False - ): - data = InstructionGenerator.load_analysis_pool( - file_path, task_id=task_id, use_fixed_insights=use_fixed_insights - ) + async def generate_new_instructions(self, task_id, original_instruction, max_num, ext_info=None): + data = self.analysis_pool new_instructions = [] if len(data) == 0: mcts_logger.log("MCTS", f"No insights available for task {task_id}") diff --git a/expo/insights/solution_designer.py b/expo/insights/solution_designer.py index b1fcf4188..9968131ca 100644 --- a/expo/insights/solution_designer.py +++ b/expo/insights/solution_designer.py @@ -21,6 +21,7 @@ The insights should be proposed based on the dataset description with different Each task type should have at least 5 insights. Make sure each method is diverse enough and can be implemented separately. Be specific about models' choices, ensemble and tuning techniques, and preprocessing & feature engineering techniques. +Your model choices should be advanced enough to be helpful. # Format ```json