diff --git a/expo/MCTS.py b/expo/MCTS.py index 9d778e4ed..7e1d7c88a 100644 --- a/expo/MCTS.py +++ b/expo/MCTS.py @@ -235,7 +235,8 @@ class Node: score_dict = {k: normalize_score(v) for k, v in score_dict.items()} self.normalized_reward = score_dict - return score_dict + result_dict = role.get_solution() + return score_dict, result_dict class MCTS: @@ -281,7 +282,7 @@ class MCTS: mcts_logger.log("MCTS", f"Start simulating node {node.id}:") while node.children: node = random.choice(node.children) - reward = await node.run_node(role) + reward, result_dict = await node.run_node(role) mcts_logger.log("MCTS", f"Simulated node's reward: {reward}") return reward @@ -341,12 +342,17 @@ class MCTS: scores["test_raw"].append(node.raw_reward["test_score"]) return scores - async def search(self, state, rollouts, load_tree=False, reflection=False): + async def search(self, state, args): + reflection = args.reflection + load_tree = args.load_tree + rollouts = args.rollouts + from_scratch = args.from_scratch role, root = initialize_di_root_node(state, reflection=reflection) self.root_node = root self.instruction_generator = InstructionGenerator( - file_path=state["exp_pool_path"], use_fixed_insights=self.use_fixed_insights + state=state, use_fixed_insights=self.use_fixed_insights, from_scratch=from_scratch ) + await self.instruction_generator.initialize() tree_loaded = False if load_tree: diff --git a/expo/experimenter/mcts.py b/expo/experimenter/mcts.py index c063268c8..d212eb204 100644 --- a/expo/experimenter/mcts.py +++ b/expo/experimenter/mcts.py @@ -24,12 +24,7 @@ class MCTSExperimenter(Experimenter): mcts = Random(root_node=None, max_depth=5, use_fixed_insights=self.args.use_fixed_insights) else: mcts = MCTS(root_node=None, max_depth=5, use_fixed_insights=self.args.use_fixed_insights) - best_nodes = await mcts.search( - state=self.state, - reflection=self.args.reflection, - rollouts=self.args.rollouts, - load_tree=self.args.load_tree, - ) + best_nodes = await 
mcts.search(state=self.state, args=self.args) best_node = best_nodes["global_best"] dev_best_node = best_nodes["dev_best"] score_dict = best_nodes["scores"] diff --git a/expo/insights/instruction_generator.py b/expo/insights/instruction_generator.py index 330795730..7fa4d72ea 100644 --- a/expo/insights/instruction_generator.py +++ b/expo/insights/instruction_generator.py @@ -1,6 +1,7 @@ import json import os import random +from difflib import SequenceMatcher from expo.insights.solution_designer import SolutionDesigner from expo.utils import clean_json_from_rsp, load_data_config, mcts_logger @@ -33,11 +34,21 @@ DATA_CONFIG = load_data_config() class InstructionGenerator: data_config = DATA_CONFIG - def __init__(self, file_path, use_fixed_insights=False): - self.file_path = file_path + def __init__(self, state, use_fixed_insights, from_scratch): + self.state = state + self.file_path = state["exp_pool_path"] + self.dataset_info_path = f"{self.data_config['datasets_dir']}/{state['task']}/dataset_info.json" + with open(self.dataset_info_path, "r") as file: + self.dataset_info = json.load(file) self.use_fixed_insights = use_fixed_insights - self.analysis_pool = self.load_insight_pool(file_path, use_fixed_insights) self.proposer = SolutionDesigner() + self.from_scratch = from_scratch + + async def initialize(self): + if self.from_scratch: + self.insight_pool = await self.generate_solutions_from_scratch(self.dataset_info, self.state["task"]) + else: + self.insight_pool = self.load_insight_pool(self.file_path, self.use_fixed_insights) @staticmethod def load_json_data(json_dir): @@ -84,14 +95,14 @@ class InstructionGenerator: data.extend(fixed_insights) for item in data: if "task_id" not in item: - raise ValueError("task_id is not found in the analysis pool") + raise ValueError("task_id is not found in the insight_pool") if task_id: data = [item for item in data if int(item["task_id"]) == int(task_id)] return data async def generate_new_instructions(self, task_id, 
original_instruction, max_num, ext_info=None): -        data = self.analysis_pool +        data = self.insight_pool         new_instructions = []         if len(data) == 0:             mcts_logger.log("MCTS", f"No insights available for task {task_id}") @@ -108,6 +119,34 @@ class InstructionGenerator:             new_instructions.append(new_instruction)         return new_instructions  +    async def propose_new_insights(self, solution, score): +        new_insights = await self.proposer.propose_new_insights(solution, score) +        added_insights = self.add_insight(new_insights) +        return added_insights + +    async def generate_solutions_from_scratch(self, dataset_info, dataset_name): +        insight_pool = await self.proposer.generate_solutions(dataset_info, dataset_name, save_analysis_pool=False) +        return insight_pool + +    def add_insight(self, new_insights): +        added_insights = [] +        for new_insight in new_insights: +            if not self.is_similar_to_existing(new_insight): +                added_insights.append(new_insight) +                self.insight_pool.append(new_insight) +        return added_insights + +    def is_similar_to_existing(self, new_insight, similarity_threshold=0.8): +        for existing_insight in self.insight_pool: +            similarity = self.calculate_similarity(new_insight["Analysis"], existing_insight["Analysis"]) +            if similarity > similarity_threshold: +                return True +        return False + +    @staticmethod +    def calculate_similarity(text1, text2): +        return SequenceMatcher(None, text1, text2).ratio() + +    @staticmethod +    async def generate_new_instruction(original_instruction, insights, ext_info): +        prompt = CHANGE_INSTRUCTION.format(instruction=original_instruction, insights=insights) diff --git a/expo/insights/solution_designer.py b/expo/insights/solution_designer.py index 9968131ca..2336911db 100644 --- a/expo/insights/solution_designer.py +++ b/expo/insights/solution_designer.py @@ -70,6 +70,45 @@ Your model choices should be advanced enough to be helpful. 
``` """ + +INSIGHT_PROPOSAL_PROMPT = """ +You are an AI assistant tasked with analyzing a machine learning solution and proposing new insights to improve its performance. Given the current solution code and development score, suggest innovative approaches to enhance the model. + +Current Solution Code: +{solution_code} + +Development Score: {dev_score} + +Based on this information, propose 3-5 new insights across different aspects of the machine learning pipeline (Data Preprocessing, Feature Engineering, and Model Training). Your insights should be specific, actionable, and have the potential to improve the model's performance. + +Please format your response as a JSON array with the following structure: +[ + + {{ + "task_type": "Data Preprocessing", + "insights": [ + "insight1", + "insight2" + ] + }}, + {{ + "task_type": "Feature Engineering", + "insights": [ + "insight1", + "insight2" + ] + }}, + {{ + "task_type": "Model Training", + "insights": [ + "insight1", + "insight2" + ] + }} +] +""" + + KEY_DATASET_FEATURES = [ "NumberOfClasses", "NumberOfFeatures", @@ -86,7 +125,7 @@ TASK_TO_ID = {"EDA": 1, "Data Preprocessing": 2, "Feature Engineering": 3, "Mode class SolutionDesigner: data_dir: str = DATA_CONFIG["datasets_dir"] - async def generate_solutions(self, dataset_info, dataset_name): + async def generate_solutions(self, dataset_info, dataset_name, save_analysis_pool=True): llm = LLM() context = DATASET_INSIGHT_PROMPT.format( dataset=dataset_info["description"], @@ -96,8 +135,18 @@ class SolutionDesigner: rsp = await llm.aask(context) rsp = clean_json_from_rsp(rsp) analysis_pool = self.process_analysis_pool(json.loads(rsp)) - dataset_path = f"{self.data_dir}/{dataset_name}" - self.save_analysis_pool(dataset_path, analysis_pool) + if save_analysis_pool: + dataset_path = f"{self.data_dir}/{dataset_name}" + self.save_analysis_pool(dataset_path, analysis_pool) + return analysis_pool + + async def propose_new_insights(self, solution, score): + llm = LLM() + context = 
INSIGHT_PROPOSAL_PROMPT.format(solution_code=solution, dev_score=score) + rsp = await llm.aask(context) + rsp = clean_json_from_rsp(rsp) + new_insights = self.process_analysis_pool(json.loads(rsp)) + return new_insights def process_analysis_pool(self, insights_rsp): analysis_pool = [] diff --git a/expo/research_assistant.py b/expo/research_assistant.py index fb34ece38..0b53521a3 100644 --- a/expo/research_assistant.py +++ b/expo/research_assistant.py @@ -139,6 +139,11 @@ class ResearchAssistant(DataInterpreter): save_notebook(role=self, save_dir=self.role_dir, name=self.get_node_name()) return task_result + def get_solution(self): + codes = [task.code for task in self.planner.plan.tasks] + results = [task.result for task in self.planner.plan.tasks] + return {"codes": codes, "results": results} + def save_state(self, static_save=False): """ attribute: diff --git a/expo/run_experiment.py b/expo/run_experiment.py index 15be27d60..c43da12fd 100644 --- a/expo/run_experiment.py +++ b/expo/run_experiment.py @@ -32,6 +32,9 @@ def get_mcts_args(parser): parser.add_argument("--rollouts", type=int, default=5) parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true") parser.add_argument("--start_task_id", type=int, default=2) + parser.add_argument( + "--from_scratch", dest="from_scratch", action="store_true", help="Generate solutions from scratch" + ) def get_aug_exp_args(parser):